下面为你 无缝集成 JWT(JSON Web Token)认证 到现有的 FastAPI 情感分析 API 中,确保只有授权用户才能调用敏感接口(如批量分析),同时保持 Gradio Web UI 无需认证(方便内部使用)。
✅ 目标
🔒 保护 API 接口:/predict 和 /predict/batch 需要有效 JWT
🌐 Web UI 不受影响:Gradio 仍可公开访问(或另配 Nginx 基础认证)
🗝️ 支持 token 发放:提供 /token 接口用于登录获取 token
⏱️ Token 过期机制:默认 1 小时有效期
第一步:安装依赖
txt

编辑

# requirements.txt 新增
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
💡 python-jose 用于 JWT 编解码,passlib 用于密码哈希
第二步:安全配置(core/security.py)
python

编辑

# core/security.py
from datetime import datetime, timedelta
from typing import Optional
from jose import JWTError, jwt
from passlib.context import CryptContext
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer

# 密钥配置(生产环境请从环境变量读取!)
SECRET_KEY = "your-super-secret-jwt-key-change-in-production"  # ← 必须更换!
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# 模拟用户数据库(实际应查数据库)
fake_users_db = {
    "admin": {
        "username": "admin",
        "hashed_password": pwd_context.hash("secret123"),  # 密码: secret123
    }
}

def verify_password(plain_password, hashed_password):
    return pwd_context.verify(plain_password, hashed_password)

def get_user(username: str):
    if username in fake_users_db:
        return fake_users_db[username]

def authenticate_user(username: str, password: str):
    user = get_user(username)
    if not user or not verify_password(password, user["hashed_password"]):
        return False
    return user

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    to_encode = data.copy()
    expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15))
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

async def get_current_user(token: str = Depends(oauth2_scheme)):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无效的认证凭证",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = get_user(username)
    if user is None:
        raise credentials_exception
    return user
🔐 重要提示:
SECRET_KEY 必须是强随机字符串(可用 openssl rand -hex 32 生成)
生产环境应从 os.getenv("JWT_SECRET") 读取
第三步:更新 FastAPI 路由(api/main.py)
python

编辑

# api/main.py(更新部分)
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.security import OAuth2PasswordRequestForm
from core.security import (
    create_access_token,
    get_current_user,
    ACCESS_TOKEN_EXPIRE_MINUTES
)
from datetime import timedelta
from core.predictor import predictor
from core.schemas import SingleRequest, SingleResponse, BatchRequest, BatchResponse

app = FastAPI(
    title="中文情感分析 API (JWT 认证)",
    description="需要 JWT 认证的情感分析服务",
    version="1.1.0"
)

# ==================== 新增:Token 获取接口 ===================
@app.post("/token", summary="获取访问令牌")
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    user = authenticate_user(form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码错误",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": form_data.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}

# ==================== 受保护的预测接口 ===================
@app.post("/predict", response_model=SingleResponse, dependencies=[Depends(get_current_user)])
async def predict_single(request: SingleRequest):
    """单条文本情感分析(需认证)"""
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")
    # ... [预测逻辑不变] ...

@app.post("/predict/batch", response_model=BatchResponse, dependencies=[Depends(get_current_user)])
async def predict_batch(request: BatchRequest):
    """批量文本情感分析(需认证)"""
    if not request.texts:
        raise HTTPException(status_code=400, detail="文本列表不能为空")
    if len(request.texts) > 100:
        raise HTTPException(status_code=400, detail="批量大小不能超过 100")
    # ... [预测逻辑不变] ...

@app.get("/health")
async def health_check():
    """健康检查(公开)"""
    return {"status": "ok"}
✅ 关键点:
使用 dependencies=[Depends(get_current_user)] 保护路由
/token 和 /health 保持公开
第四步:Docker 环境变量支持(安全最佳实践)
🛡️ .env 文件(不要提交到 Git!)
env

编辑

# .env
JWT_SECRET=your-very-long-random-secret-key-here-32-bytes-or-more
ADMIN_PASSWORD=your-admin-password
🔧 更新 core/security.py(从环境变量读取)
python

编辑

import os
SECRET_KEY = os.getenv("JWT_SECRET", "fallback-dev-key")
fake_users_db = {
    "admin": {
        "username": "admin",
        "hashed_password": pwd_context.hash(os.getenv("ADMIN_PASSWORD", "secret123")),
    }
}
🐳 更新 docker-compose.yml
yaml

编辑

# docker-compose.yml
version: '3'

services:
  sentiment-api:
    build: .
    ports:
      - "8000:8000"
      - "7861:7861"
    volumes:
      - ./models:/app/models
      - ./.env:/app/.env  # 挂载环境变量
    env_file:
      - .env
    command: >
      sh -c "
        uvicorn api.main:app --host 0.0.0.0 --port 8000 &
        python web/app.py
      "
▶️ 使用流程
. 获取 Token
bash

编辑

curl -X POST "http://localhost:8000/token" \
  -H "Content-Type: application/x-www-form-urlencoded" \
  -d "username=admin&password=secret123"
响应:
json

编辑

{
  "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
  "token_type": "bearer"
}
. 调用受保护 API
bash

编辑

TOKEN="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."

curl -X POST "http://localhost:8000/predict/batch" \
  -H "Authorization: Bearer $TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"texts": ["今天真开心", "太糟糕了"], "model": "bert"}'
❌ 无 token 或无效 token → 返回 401 Unauthorized
🌐 Gradio Web UI 保持开放
Gradio 仍运行在 7861 端口,无需任何认证
适合内部团队快速测试
如需保护 Web UI,建议在 Nginx 层加基础认证
🔒 安全增强建议(生产环境)
表格
措施    说明
HTTPS    用 Nginx + Let's Encrypt 强制 HTTPS
速率限制    用 slowapi 限制 API 调用频率
真实用户数据库    替换 fake_users_db 为 PostgreSQL/MySQL
Refresh Token    支持长期登录(本例为简化省略)
✅ 最终效果
表格
接口    是否需要认证    用途
POST /token    ❌ 公开    获取 JWT
GET /health    ❌ 公开    健康检查
POST /predict    ✅ 需 JWT    单条分析
POST /predict/batch    ✅ 需 JWT    批量分析
http://:7861    ❌ 公开    Web UI
现在你的系统既 安全 又 易用:
外部开发者通过 API + JWT 安全调用
内部人员通过 Web UI 快速验证

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐