从核心能力到可交付服务,保姆级实现商用级RESTful接口,覆盖文档管理、流式对话、权限鉴权全场景,附完整工程化代码、接口规范、踩坑避坑指南


前言:从demo到商用,接口是必经的核心桥梁

前五篇我们完成了生产级RAG系统的全链路核心能力建设:

  • 第1篇:完成了分层架构设计与项目骨架搭建
  • 第2篇:实现了全格式文档预处理与智能分块
  • 第3篇:完成了Embedding模型选型与Milvus向量库搭建
  • 第4篇:优化了多路召回+重排序的全链路检索引擎
  • 第5篇:实现了体系化Prompt工程、LLM调用封装与幻觉抑制方案

到这里,我们已经实现了端到端的RAG问答核心能力,但它还只是一个跑在本地的代码工程,无法对接前端页面、无法被第三方系统集成、无法做权限管控、无法上线交付给用户使用

而这一切的核心桥梁,就是标准化的后端接口

很多新手在这一步最容易犯的错:

  • 随便写几个裸接口,没有任何参数校验、异常处理,非法参数一进来服务直接崩溃
  • 把业务逻辑全写在接口里,代码臃肿混乱,后期维护成本极高
  • 没有鉴权、没有租户隔离,谁都能调用,A租户能删B租户的文档,出现严重的数据泄露
  • 不做统一响应格式,前端对接要为每个接口写单独的处理逻辑,对接成本极高
  • 流式对话接口实现错误,要么阻塞服务,要么前端无法正常解析
  • 没有接口文档,前后端联调全靠口口相传,效率极低
  • 没有日志、没有监控,接口出了问题根本找不到原因

一个核心结论必须记死
商用级接口的核心,从来不是“能跑通”,而是稳定、安全、规范、可维护、可观测。一个不合格的接口,会让你前面五篇做的所有核心能力都白费,上线后只会收到用户的投诉和bug反馈。

本篇我们就严格遵循商用级开发规范,基于FastAPI实现完整的标准化后端接口,完全对齐前五篇的项目架构,所有代码可直接复制运行,跟着做完,你的RAG系统就能直接对接前端、上线交付。


一、先立规矩:商用级后端接口的7大设计铁律

在写代码之前,我们必须先明确商用级接口的设计原则,这是保证接口质量的核心,新手必须严格遵守:

1. 分层解耦原则

接口层绝对不能写业务逻辑,只负责3件事:

  • 接收请求参数,做格式与合法性校验
  • 调用对应Service层的业务逻辑
  • 封装处理结果,统一格式返回给前端
    业务逻辑100%收敛到Service层,数据操作100%收敛到DB层,层与层之间通过标准化接口交互,修改某一层的代码,不会影响其他层。

2. RESTful规范原则

严格遵循RESTful API设计规范,用正确的HTTP方法表达语义:

  • GET:查询资源,不修改任何数据
  • POST:创建资源,比如上传文档、创建会话
  • PUT:全量更新资源
  • DELETE:删除资源
    URL语义清晰,版本化管理,比如/api/v1/document,后续升级不影响老调用方。

3. 统一响应原则

所有接口必须返回完全一致的JSON格式,无论成功还是失败,前端只需要一套逻辑就能处理所有接口的响应,大幅降低对接成本。
统一响应格式如下:

{
    "code": 200,          // 状态码:200成功,非200失败
    "message": "success", // 提示信息:成功返回success,失败返回错误原因
    "data": {}            // 数据体:成功返回对应数据,失败返回null
}

4. 强校验原则

所有入参必须做严格校验,包括:

  • 类型校验:字符串、数字、布尔值严格区分
  • 格式校验:邮箱、手机号、ID格式、文件格式
  • 范围校验:数值范围、长度限制、文件大小限制
  • 合法性校验:租户权限、资源归属、越权操作拦截
    非法参数必须在接口层直接拦截,绝对不能传入Service层,避免服务崩溃。

5. 安全优先原则

商用接口的安全是底线,必须做到:

  • 所有业务接口必须鉴权,禁止裸接口对外暴露
  • 严格的租户隔离,每个用户只能操作自己租户下的资源
  • 敏感数据脱敏,密钥、Token、密码绝对不能明文返回或打印日志
  • 接口限流,防止恶意刷接口、CC攻击
  • 大文件上传格式与大小限制,防止恶意文件上传

6. 可观测性原则

所有接口必须有完整的日志记录,包括:

  • 请求唯一ID,全链路追踪
  • 请求用户、租户ID、接口地址、HTTP方法
  • 请求入参(敏感信息脱敏)
  • 响应状态码、响应耗时
  • 异常信息、堆栈跟踪(仅开发环境返回,生产环境不暴露)
    出问题能通过日志快速定位、快速复盘。

7. 兼容性原则

接口必须做版本控制,/api/v1//api/v2/,升级接口时必须保证向下兼容,绝对不能修改已上线接口的入参、出参结构,避免导致前端或第三方系统崩溃。


二、前置准备:项目结构补全与依赖更新

我们完全对齐第一篇的项目骨架,补全接口开发所需的目录和文件,保证整个项目的结构一致性。

2.1 最终项目目录结构(接口开发补全版)

production-rag-system/
├── app/
│   ├── api/                          # 接口层核心目录
│   │   ├── deps.py                   # 接口依赖项:鉴权、参数校验、租户隔离
│   │   ├── common.py                 # 通用接口:健康检查、系统信息
│   │   └── v1/                       # v1版本接口,后续升级可新增v2
│   │       ├── router.py             # v1接口路由总入口
│   │       ├── auth.py               # 认证接口:登录、注册、Token刷新
│   │       ├── document.py           # 文档管理接口:上传、删除、列表查询
│   │       ├── chat.py               # 对话问答接口:同步/流式问答、会话管理
│   │       └── user.py               # 用户/租户基础接口
│   ├── core/                         # 核心引擎层(前五篇已实现)
│   ├── service/                      # 业务服务层(前五篇已实现,本次补充)
│   ├── models/                       # ORM数据模型:SQLAlchemy表结构定义
│   │   ├── __init__.py
│   │   ├── user.py                   # 用户/租户模型
│   │   ├── document.py               # 文档元数据模型
│   │   └── chat.py                   # 对话会话/历史模型
│   ├── schemas/                      # Pydantic模型:入参/出参校验、类型定义
│   │   ├── __init__.py
│   │   ├── common.py                 # 通用响应模型
│   │   ├── auth.py                   # 认证相关入参/出参
│   │   ├── document.py               # 文档相关入参/出参
│   │   └── chat.py                   # 对话相关入参/出参
│   ├── db/                           # 数据库连接层(第一篇已实现,本次补充)
│   │   ├── __init__.py
│   │   ├── relational_db.py          # PostgreSQL连接、会话管理
│   │   ├── vector_db.py              # Milvus连接(第三篇已实现)
│   │   └── cache.py                  # Redis连接(第一篇已实现)
│   ├── config/                       # 配置管理(第一篇已实现)
│   ├── utils/                        # 通用工具函数:JWT、加密、脱敏
│   │   ├── __init__.py
│   │   ├── security.py               # 密码加密、JWT生成/校验
│   │   └── logger.py                 # 日志工具
│   └── main.py                       # 项目启动入口(本次补充完整)
├── tests/                            # 单元测试
├── deploy/                           # 部署配置
├── logs/                             # 日志文件
├── .env.example                      # 环境变量示例
├── requirements.txt                  # 依赖包(本次补充)
└── README.md

2.2 补充依赖包

更新requirements.txt,新增接口开发所需的依赖:

# 原有依赖(前五篇已包含,此处省略)
# ...

# 本篇新增:FastAPI接口开发核心依赖
fastapi>=0.115.0
uvicorn[standard]>=0.30.0
pydantic>=2.9.0
pydantic-settings>=2.5.0

# 本篇新增:ORM与数据库
sqlalchemy>=2.0.30
psycopg2-binary>=2.9.9
alembic>=1.13.0

# 本篇新增:认证与安全
python-jose[cryptography]>=3.3.0
passlib[bcrypt]>=1.7.4
python-multipart>=0.0.12

# 本篇新增:接口限流
slowapi>=0.1.9
limits>=3.12.0

# 本篇新增:工具类
python-dotenv>=1.0.0
loguru>=0.7.2

2.3 环境变量补充

更新.env.example,新增认证、数据库相关配置:

# 原有配置(前五篇已包含,此处省略)
# ...

# 本篇新增:JWT认证配置
JWT_SECRET_KEY=你的JWT密钥,用随机字符串生成,至少32位
JWT_ALGORITHM=HS256
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=1440  # Token有效期24小时

# 本篇新增:文件上传配置
MAX_UPLOAD_FILE_SIZE=10485760  # 最大文件大小10MB
ALLOWED_FILE_EXTENSIONS=pdf,docx,doc,xlsx,xls,pptx,ppt,txt,md

# 本篇新增:超级管理员配置
SUPER_ADMIN_USERNAME=admin
SUPER_ADMIN_PASSWORD=admin123456

三、核心基础模块实现

在写业务接口之前,我们先把通用的基础模块实现,包括统一响应模型、全局异常处理、ORM模型、Pydantic校验模型、JWT鉴权依赖项,这些是所有接口的基础。

3.1 通用响应模型与全局异常处理

新建app/schemas/common.py,定义统一响应模型:

from typing import Generic, TypeVar, Optional
from pydantic import BaseModel, Field

# 泛型类型,支持任意数据类型
T = TypeVar("T")

class ResponseModel(BaseModel, Generic[T]):
    """统一响应模型"""
    code: int = Field(200, description="状态码,200为成功,非200为失败")
    message: str = Field("success", description="响应信息,成功为success,失败为错误原因")
    data: Optional[T] = Field(None, description="响应数据体")

    @classmethod
    def success(cls, data: T = None, message: str = "success") -> "ResponseModel[T]":
        """成功响应"""
        return cls(code=200, message=message, data=data)

    @classmethod
    def error(cls, code: int = 400, message: str = "请求失败", data: T = None) -> "ResponseModel[T]":
        """失败响应"""
        return cls(code=code, message=message, data=data)

# 通用分页请求模型
class PageQuery(BaseModel):
    page: int = Field(1, ge=1, description="页码,从1开始")
    page_size: int = Field(10, ge=1, le=100, description="每页条数,1-100")

# 通用分页响应模型
class PageResult(BaseModel, Generic[T]):
    total: int = Field(description="总条数")
    page: int = Field(description="当前页码")
    page_size: int = Field(description="每页条数")
    total_page: int = Field(description="总页数")
    list: List[T] = Field(description="数据列表")

新建app/utils/exception.py,定义全局自定义异常与异常处理器:

from fastapi import Request, status
from fastapi.responses import JSONResponse
from loguru import logger
from app.schemas.common import ResponseModel

# 自定义业务异常
class BusinessException(Exception):
    def __init__(self, code: int = 400, message: str = "业务处理失败"):
        self.code = code
        self.message = message

# 全局异常处理器
async def global_exception_handler(request: Request, exc: Exception):
    """全局异常捕获,统一返回格式"""
    # 生成请求唯一ID,用于日志追踪
    request_id = request.headers.get("X-Request-ID", "unknown")
    # 获取请求路径和方法
    path = request.url.path
    method = request.method

    # 自定义业务异常
    if isinstance(exc, BusinessException):
        logger.warning(f"[{request_id}] 业务异常:{method} {path} | code={exc.code} | message={exc.message}")
        return JSONResponse(
            status_code=200,
            content=ResponseModel.error(code=exc.code, message=exc.message).model_dump()
        )

    # 其他未知异常
    logger.error(f"[{request_id}] 系统异常:{method} {path} | 异常信息:{str(exc)}", exc_info=True)
    return JSONResponse(
        status_code=200,
        content=ResponseModel.error(code=500, message="系统内部错误,请联系管理员").model_dump()
    )

# 404异常处理器
async def not_found_exception_handler(request: Request, exc):
    return JSONResponse(
        status_code=200,
        content=ResponseModel.error(code=404, message="请求的接口不存在").model_dump()
    )

# 权限异常处理器
async def permission_denied_exception_handler(request: Request, exc):
    return JSONResponse(
        status_code=200,
        content=ResponseModel.error(code=403, message="权限不足,无法访问该资源").model_dump()
    )

3.2 ORM数据模型与数据库连接

更新app/db/relational_db.py,完善PostgreSQL连接与SQLAlchemy会话管理:

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from loguru import logger
from app.config.settings import settings

# 构建数据库连接URL
SQLALCHEMY_DATABASE_URL = f"postgresql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}"

# 创建数据库引擎
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    pool_pre_ping=True,  # 自动检测失效连接
    pool_recycle=3600,   # 连接回收时间
    pool_size=20,        # 连接池大小
    max_overflow=10      # 最大溢出连接数
)

# 会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# 基类,所有ORM模型都继承这个类
Base = declarative_base()

# 数据库依赖项,用于接口中获取数据库会话
def get_db() -> Session:
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

# 初始化数据库表
def init_db():
    """项目启动时初始化数据库表"""
    try:
        # 导入所有模型,确保创建表
        from app.models.user import User, Tenant
        from app.models.document import Document
        from app.models.chat import ChatSession, ChatMessage
        # 创建所有表
        Base.metadata.create_all(bind=engine)
        logger.info("数据库表初始化成功")
    except Exception as e:
        logger.error(f"数据库表初始化失败:{str(e)}", exc_info=True)
        raise e

新建app/models/user.py,定义用户与租户ORM模型:

from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey
from sqlalchemy.sql import func
from app.db.relational_db import Base

class Tenant(Base):
    """租户模型,多租户隔离核心"""
    __tablename__ = "sys_tenant"

    id = Column(Integer, primary_key=True, autoincrement=True, comment="租户ID")
    tenant_name = Column(String(64), nullable=False, unique=True, comment="租户名称")
    tenant_code = Column(String(32), nullable=False, unique=True, comment="租户编码")
    status = Column(Integer, default=1, comment="状态:1启用 0禁用")
    created_at = Column(DateTime, default=func.now(), comment="创建时间")
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), comment="更新时间")

class User(Base):
    """用户模型"""
    __tablename__ = "sys_user"

    id = Column(Integer, primary_key=True, autoincrement=True, comment="用户ID")
    username = Column(String(32), nullable=False, unique=True, comment="用户名")
    password = Column(String(128), nullable=False, comment="加密后的密码")
    nickname = Column(String(64), default="", comment="用户昵称")
    email = Column(String(64), default="", comment="邮箱")
    phone = Column(String(16), default="", comment="手机号")
    tenant_id = Column(Integer, ForeignKey("sys_tenant.id"), nullable=False, comment="所属租户ID")
    is_admin = Column(Boolean, default=False, comment="是否为租户管理员")
    is_super_admin = Column(Boolean, default=False, comment="是否为超级管理员")
    status = Column(Integer, default=1, comment="状态:1启用 0禁用")
    created_at = Column(DateTime, default=func.now(), comment="创建时间")
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), comment="更新时间")

新建app/models/document.py,定义文档元数据模型:

from sqlalchemy import Column, Integer, String, BigInteger, DateTime, ForeignKey, Text
from sqlalchemy.sql import func
from app.db.relational_db import Base

class Document(Base):
    """文档元数据模型"""
    __tablename__ = "biz_document"

    id = Column(Integer, primary_key=True, autoincrement=True, comment="文档ID")
    document_id = Column(String(64), nullable=False, unique=True, comment="文档唯一ID,对应向量库中的document_id")
    file_name = Column(String(256), nullable=False, comment="原始文件名")
    file_ext = Column(String(16), nullable=False, comment="文件后缀")
    file_size = Column(BigInteger, default=0, comment="文件大小,单位字节")
    file_path = Column(String(512), nullable=False, comment="文件存储路径")
    chunk_count = Column(Integer, default=0, comment="分块数量")
    tenant_id = Column(Integer, ForeignKey("sys_tenant.id"), nullable=False, comment="所属租户ID")
    created_by = Column(Integer, ForeignKey("sys_user.id"), nullable=False, comment="上传人ID")
    status = Column(Integer, default=0, comment="状态:0处理中 1处理完成 2处理失败")
    failed_reason = Column(Text, default="", comment="处理失败原因")
    created_at = Column(DateTime, default=func.now(), comment="创建时间")
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), comment="更新时间")

新建app/models/chat.py,定义对话会话与消息模型:

from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Text, Boolean
from sqlalchemy.sql import func
from app.db.relational_db import Base

class ChatSession(Base):
    """对话会话模型"""
    __tablename__ = "biz_chat_session"

    id = Column(Integer, primary_key=True, autoincrement=True, comment="会话ID")
    session_id = Column(String(64), nullable=False, unique=True, comment="会话唯一ID")
    session_name = Column(String(128), default="新对话", comment="会话名称")
    tenant_id = Column(Integer, ForeignKey("sys_tenant.id"), nullable=False, comment="所属租户ID")
    created_by = Column(Integer, ForeignKey("sys_user.id"), nullable=False, comment="创建人ID")
    document_ids = Column(String(512), default="", comment="关联的文档ID列表,逗号分隔")
    is_deleted = Column(Boolean, default=False, comment="是否删除")
    created_at = Column(DateTime, default=func.now(), comment="创建时间")
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), comment="更新时间")

class ChatMessage(Base):
    """对话消息模型,存储对话历史"""
    __tablename__ = "biz_chat_message"

    id = Column(Integer, primary_key=True, autoincrement=True, comment="消息ID")
    session_id = Column(String(64), ForeignKey("biz_chat_session.session_id"), nullable=False, comment="所属会话ID")
    role = Column(String(16), nullable=False, comment="角色:user/assistant")
    content = Column(Text, nullable=False, comment="消息内容")
    tenant_id = Column(Integer, ForeignKey("sys_tenant.id"), nullable=False, comment="所属租户ID")
    user_id = Column(Integer, ForeignKey("sys_user.id"), nullable=False, comment="发送人ID")
    token_usage = Column(Integer, default=0, comment="本次消息消耗的Token数量")
    latency = Column(Integer, default=0, comment="响应耗时,单位毫秒")
    is_deleted = Column(Boolean, default=False, comment="是否删除")
    created_at = Column(DateTime, default=func.now(), comment="创建时间")

3.3 安全工具与JWT鉴权

新建app/utils/security.py,实现密码加密与JWT生成/校验:

from datetime import datetime, timedelta
from typing import Optional
from jose import JWTError, jwt
from passlib.context import CryptContext
from app.config.settings import settings

# 密码加密上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

def get_password_hash(password: str) -> str:
    """生成密码哈希"""
    return pwd_context.hash(password)

def verify_password(plain_password: str, hashed_password: str) -> bool:
    """验证密码"""
    return pwd_context.verify(plain_password, hashed_password)

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """生成JWT访问令牌"""
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
    return encoded_jwt

def verify_access_token(token: str) -> Optional[dict]:
    """校验JWT令牌,返回令牌中的数据,校验失败返回None"""
    try:
        payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
        return payload
    except JWTError:
        return None

新建app/api/deps.py,实现接口依赖项,包括鉴权、租户隔离、数据库会话:

from typing import Annotated, Optional
from fastapi import Depends, Header, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from loguru import logger

from app.db.relational_db import get_db
from app.utils.security import verify_access_token
from app.models.user import User, Tenant
from app.utils.exception import BusinessException

# Bearer Token认证
security = HTTPBearer(auto_error=False)

# 通用类型注解
DBSession = Annotated[Session, Depends(get_db)]
TokenCredentials = Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]

async def get_token_payload(credentials: TokenCredentials) -> dict:
    """获取Token中的载荷数据,校验Token有效性"""
    if not credentials:
        raise BusinessException(code=401, message="请先登录,未提供Token")
    token = credentials.credentials
    payload = verify_access_token(token)
    if not payload:
        raise BusinessException(code=401, message="Token无效或已过期,请重新登录")
    return payload

async def get_current_user(
    payload: Annotated[dict, Depends(get_token_payload)],
    db: DBSession
) -> User:
    """获取当前登录用户"""
    user_id = payload.get("user_id")
    tenant_id = payload.get("tenant_id")
    if not user_id or not tenant_id:
        raise BusinessException(code=401, message="Token数据异常,请重新登录")
    # 查询用户
    user = db.query(User).filter(User.id == user_id, User.tenant_id == tenant_id, User.status == 1).first()
    if not user:
        raise BusinessException(code=401, message="用户不存在或已被禁用")
    # 校验租户状态
    tenant = db.query(Tenant).filter(Tenant.id == tenant_id, Tenant.status == 1).first()
    if not tenant:
        raise BusinessException(code=403, message="所属租户已被禁用")
    return user

# 注解类型,当前登录用户
CurrentUser = Annotated[User, Depends(get_current_user)]

async def check_tenant_permission(
    target_tenant_id: int,
    current_user: CurrentUser
) -> bool:
    """校验租户权限,普通用户只能操作自己租户的资源,超级管理员可以操作所有"""
    if current_user.is_super_admin:
        return True
    if current_user.tenant_id != target_tenant_id:
        raise BusinessException(code=403, message="权限不足,无法操作其他租户的资源")
    return True

3.4 补充配置文件

更新app/config/settings.py,新增本篇所需的配置项:

# 原有配置(第一篇已包含,此处省略)
# ...

# 本篇新增:JWT认证配置
JWT_SECRET_KEY: str = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-it-in-production")
JWT_ALGORITHM: str = os.getenv("JWT_ALGORITHM", "HS256")
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", 1440))

# 本篇新增:文件上传配置
MAX_UPLOAD_FILE_SIZE: int = int(os.getenv("MAX_UPLOAD_FILE_SIZE", 10485760))  # 10MB
ALLOWED_FILE_EXTENSIONS: List[str] = os.getenv("ALLOWED_FILE_EXTENSIONS", "pdf,docx,doc,xlsx,xls,pptx,ppt,txt,md").split(",")

# 本篇新增:超级管理员配置
SUPER_ADMIN_USERNAME: str = os.getenv("SUPER_ADMIN_USERNAME", "admin")
SUPER_ADMIN_PASSWORD: str = os.getenv("SUPER_ADMIN_PASSWORD", "admin123456")

3.5 更新项目启动入口

更新app/main.py,补全完整的项目启动配置,包括中间件、异常处理、路由注册、数据库初始化:

import os
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from loguru import logger

from app.config.settings import settings
from app.db.relational_db import init_db
from app.utils.exception import (
    global_exception_handler,
    not_found_exception_handler,
    permission_denied_exception_handler
)
from app.api.v1.router import api_v1_router
from app.api.common import common_router

# 初始化限流器
limiter = Limiter(key_func=get_remote_address)

# 应用生命周期管理
@asynccontextmanager
async def lifespan(app: FastAPI):
    """应用启动与关闭时的生命周期管理"""
    logger.info("=== 开始启动RAG知识库系统 ===")
    # 初始化数据库
    init_db()
    # 初始化超级管理员
    from app.service.auth_service import init_super_admin
    init_super_admin()
    # 初始化核心引擎(避免懒加载导致首次请求卡顿)
    from app.service.document_service import document_process_service
    from app.service.chat_service import chat_service
    logger.info("=== RAG知识库系统启动成功 ===")
    yield
    # 应用关闭时的清理工作
    logger.info("=== 开始关闭RAG知识库系统 ===")
    from app.core.embedding.vector_db import get_vector_db_manager
    get_vector_db_manager().close()
    logger.info("=== RAG知识库系统关闭成功 ===")

# 初始化FastAPI应用
app = FastAPI(
    title=settings.PROJECT_NAME,
    description="生产级RAG知识库系统 - 商用级API文档",
    version=settings.PROJECT_VERSION,
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan
)

# 绑定限流器
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# 注册全局异常处理器
app.add_exception_handler(Exception, global_exception_handler)
app.add_exception_handler(404, not_found_exception_handler)
app.add_exception_handler(403, permission_denied_exception_handler)

# 注册中间件
# Gzip压缩,减少传输体积
app.add_middleware(GZipMiddleware, minimum_size=1000)
# CORS跨域中间件
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"]
)

# 注册路由
app.include_router(common_router, prefix="/api/common", tags=["通用接口"])
app.include_router(api_v1_router, prefix="/api/v1", tags=["V1版本接口"])

# 根路径健康检查
@app.get("/", tags=["系统"], summary="根路径健康检查")
async def root():
    return {
        "status": "healthy",
        "project_name": settings.PROJECT_NAME,
        "version": settings.PROJECT_VERSION,
        "docs": "/docs",
        "redoc": "/redoc"
    }

# 项目启动入口
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        "main:app",
        host=settings.SERVER_HOST,
        port=settings.SERVER_PORT,
        reload=settings.DEBUG,
        workers=settings.WORKERS
    )

四、业务服务层补充

接口层只负责参数接收和返回,业务逻辑必须在Service层实现,我们补充前五篇没有的业务逻辑,包括用户认证、文档元数据管理、对话会话管理。

4.1 认证服务

新建app/service/auth_service.py,实现用户认证与初始化逻辑:

from sqlalchemy.orm import Session
from loguru import logger
from datetime import timedelta

from app.db.relational_db import SessionLocal
from app.models.user import User, Tenant
from app.utils.security import get_password_hash, verify_password, create_access_token
from app.config.settings import settings
from app.utils.exception import BusinessException

class AuthService:
    def init_super_admin(self):
        """初始化超级管理员,项目启动时执行"""
        db = SessionLocal()
        try:
            # 检查超级管理员是否已存在
            super_admin = db.query(User).filter(User.username == settings.SUPER_ADMIN_USERNAME, User.is_super_admin == True).first()
            if super_admin:
                logger.info("超级管理员已存在,无需初始化")
                return
            # 创建默认租户
            default_tenant = Tenant(
                tenant_name="默认租户",
                tenant_code="default",
                status=1
            )
            db.add(default_tenant)
            db.flush()
            # 创建超级管理员
            hashed_password = get_password_hash(settings.SUPER_ADMIN_PASSWORD)
            super_admin = User(
                username=settings.SUPER_ADMIN_USERNAME,
                password=hashed_password,
                nickname="超级管理员",
                tenant_id=default_tenant.id,
                is_admin=True,
                is_super_admin=True,
                status=1
            )
            db.add(super_admin)
            db.commit()
            logger.info("超级管理员初始化成功")
        except Exception as e:
            db.rollback()
            logger.error(f"超级管理员初始化失败:{str(e)}", exc_info=True)
        finally:
            db.close()

    def login(self, username: str, password: str, db: Session) -> dict:
        """用户登录,返回Token和用户信息"""
        # 查询用户
        user = db.query(User).filter(User.username == username, User.status == 1).first()
        if not user:
            raise BusinessException(code=400, message="用户名或密码错误")
        # 验证密码
        if not verify_password(password, user.password):
            raise BusinessException(code=400, message="用户名或密码错误")
        # 校验租户状态
        tenant = db.query(Tenant).filter(Tenant.id == user.tenant_id, Tenant.status == 1).first()
        if not tenant:
            raise BusinessException(code=403, message="所属租户已被禁用")
        # 生成Token
        access_token_expires = timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
        access_token = create_access_token(
            data={
                "user_id": user.id,
                "tenant_id": user.tenant_id,
                "username": user.username,
                "is_admin": user.is_admin,
                "is_super_admin": user.is_super_admin
            },
            expires_delta=access_token_expires
        )
        # 返回用户信息和Token
        return {
            "access_token": access_token,
            "token_type": "bearer",
            "expires_in": settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60,
            "user_info": {
                "id": user.id,
                "username": user.username,
                "nickname": user.nickname,
                "email": user.email,
                "phone": user.phone,
                "tenant_id": user.tenant_id,
                "is_admin": user.is_admin,
                "is_super_admin": user.is_super_admin
            }
        }

# 全局单例
auth_service = AuthService()

# 项目启动时初始化超级管理员的方法
def init_super_admin():
    auth_service.init_super_admin()

4.2 文档服务补充

更新app/service/document_service.py,补充文档元数据管理、列表查询、文件存储逻辑:

# 原有导入(前五篇已包含,此处省略)
# ...
import os
import uuid
from fastapi import UploadFile
from sqlalchemy.orm import Session
from loguru import logger
from app.models.document import Document
from app.config.settings import settings
from app.utils.exception import BusinessException

# 原有代码(前五篇已实现,此处省略)
# ...

class DocumentProcessService:
    # 原有方法(前五篇已实现,此处省略)
    # ...

    def __init__(self):
        # 原有初始化代码
        # ...
        # 新增:文件存储目录
        self.upload_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data/upload")
        os.makedirs(self.upload_dir, exist_ok=True)

    async def save_upload_file(self, file: UploadFile, tenant_id: int, user_id: int, db: Session) -> Document:
        """保存上传的文件,创建文档元数据"""
        # 1. 校验文件格式
        file_ext = os.path.splitext(file.filename)[1].lower().replace(".", "")
        if file_ext not in settings.ALLOWED_FILE_EXTENSIONS:
            raise BusinessException(code=400, message=f"不支持的文件格式,支持的格式:{','.join(settings.ALLOWED_FILE_EXTENSIONS)}")
        # 2. 校验文件大小
        file_content = await file.read()
        file_size = len(file_content)
        if file_size > settings.MAX_UPLOAD_FILE_SIZE:
            raise BusinessException(code=400, message=f"文件大小超过限制,最大支持{settings.MAX_UPLOAD_FILE_SIZE/1024/1024}MB")
        # 3. 重置文件指针
        await file.seek(0)
        # 4. 生成唯一文档ID
        document_id = uuid.uuid4().hex
        # 5. 生成文件存储路径
        file_name = file.filename
        save_file_name = f"{document_id}_{file_name}"
        save_path = os.path.join(self.upload_dir, save_file_name)
        # 6. 保存文件到本地
        with open(save_path, "wb") as f:
            f.write(file_content)
        # 7. 创建文档元数据
        document = Document(
            document_id=document_id,
            file_name=file_name,
            file_ext=file_ext,
            file_size=file_size,
            file_path=save_path,
            tenant_id=tenant_id,
            created_by=user_id,
            status=0  # 处理中
        )
        db.add(document)
        db.commit()
        db.refresh(document)
        logger.info(f"文件保存成功,文档ID:{document_id},文件名:{file_name}")
        return document

    def get_document_list(self, tenant_id: int, page: int, page_size: int, db: Session, document_name: str = None):
        """获取租户下的文档列表,分页查询"""
        query = db.query(Document).filter(Document.tenant_id == tenant_id, Document.is_deleted == False)
        # 按文件名模糊查询
        if document_name:
            query = query.filter(Document.file_name.like(f"%{document_name}%"))
        # 总条数
        total = query.count()
        # 分页查询
        offset = (page - 1) * page_size
        list = query.order_by(Document.created_at.desc()).offset(offset).limit(page_size).all()
        # 总页数
        total_page = (total + page_size - 1) // page_size
        return {
            "total": total,
            "page": page,
            "page_size": page_size,
            "total_page": total_page,
            "list": list
        }

    def get_document_by_id(self, document_id: str, tenant_id: int, db: Session) -> Document:
        """根据文档ID获取文档信息,校验租户权限"""
        document = db.query(Document).filter(Document.document_id == document_id, Document.tenant_id == tenant_id).first()
        if not document:
            raise BusinessException(code=404, message="文档不存在")
        return document

    def delete_document(self, document_id: str, tenant_id: int, db: Session) -> bool:
        """删除文档,同时删除向量库、关键词索引和本地文件"""
        document = self.get_document_by_id(document_id, tenant_id, db)
        try:
            # 1. 删除向量库中的数据
            self.vector_db.delete_by_document_id(document_id, str(tenant_id))
            # 2. 删除关键词索引中的数据
            self.keyword_retriever.delete_by_document_id(document_id, str(tenant_id))
            # 3. 删除本地文件
            if os.path.exists(document.file_path):
                os.remove(document.file_path)
            # 4. 软删除数据库中的元数据
            document.is_deleted = True
            db.commit()
            logger.info(f"文档删除成功,文档ID:{document_id}")
            return True
        except Exception as e:
            db.rollback()
            logger.error(f"文档删除失败,文档ID:{document_id},错误信息:{str(e)}", exc_info=True)
            raise BusinessException(code=500, message="文档删除失败")

# 全局单例
document_process_service = DocumentProcessService()

4.3 对话服务补充

更新app/service/chat_service.py,补充会话管理、对话历史持久化逻辑:

# 原有导入(前五篇已包含,此处省略)
# ...
import uuid
from sqlalchemy.orm import Session
from app.models.chat import ChatSession, ChatMessage
from app.utils.exception import BusinessException

class ChatService:
    # 原有方法(前五篇已实现,此处省略)
    # ...

    def create_session(self, tenant_id: int, user_id: int, session_name: str = "新对话", document_ids: List[str] = None) -> ChatSession:
        """创建新的对话会话"""
        session_id = uuid.uuid4().hex
        document_ids_str = ",".join(document_ids) if document_ids else ""
        session = ChatSession(
            session_id=session_id,
            session_name=session_name,
            tenant_id=tenant_id,
            created_by=user_id,
            document_ids=document_ids_str
        )
        return session

    def get_session_list(self, tenant_id: int, user_id: int, page: int, page_size: int, db: Session):
        """获取用户的会话列表,分页查询"""
        query = db.query(ChatSession).filter(
            ChatSession.tenant_id == tenant_id,
            ChatSession.created_by == user_id,
            ChatSession.is_deleted == False
        )
        total = query.count()
        offset = (page - 1) * page_size
        list = query.order_by(ChatSession.updated_at.desc()).offset(offset).limit(page_size).all()
        total_page = (total + page_size - 1) // page_size
        return {
            "total": total,
            "page": page,
            "page_size": page_size,
            "total_page": total_page,
            "list": list
        }

    def get_session_by_id(self, session_id: str, tenant_id: int, user_id: int, db: Session) -> ChatSession:
        """获取会话详情,校验权限"""
        session = db.query(ChatSession).filter(
            ChatSession.session_id == session_id,
            ChatSession.tenant_id == tenant_id,
            ChatSession.created_by == user_id,
            ChatSession.is_deleted == False
        ).first()
        if not session:
            raise BusinessException(code=404, message="会话不存在")
        return session

    def get_chat_history(self, session_id: str, tenant_id: int, db: Session) -> List[ChatMessage]:
        """获取会话的对话历史"""
        messages = db.query(ChatMessage).filter(
            ChatMessage.session_id == session_id,
            ChatMessage.tenant_id == tenant_id,
            ChatMessage.is_deleted == False
        ).order_by(ChatMessage.created_at.asc()).all()
        return messages

    def save_chat_message(self, session_id: str, role: str, content: str, tenant_id: int, user_id: int, token_usage: int, latency: int, db: Session) -> ChatMessage:
        """保存对话消息到数据库"""
        message = ChatMessage(
            session_id=session_id,
            role=role,
            content=content,
            tenant_id=tenant_id,
            user_id=user_id,
            token_usage=token_usage,
            latency=latency
        )
        db.add(message)
        db.commit()
        db.refresh(message)
        return message

# 全局单例
chat_service = ChatService()

五、核心业务接口实现

现在我们实现完整的业务接口,分为认证接口、文档管理接口、对话问答接口,所有接口都严格遵循前面的设计规范。

5.1 通用接口

新建app/api/common.py,实现健康检查、系统信息等通用接口:

from fastapi import APIRouter
from slowapi import limiter
from app.schemas.common import ResponseModel

common_router = APIRouter()

@common_router.get("/health", summary="系统健康检查", description="用于监控系统健康状态")
@limiter.limit("100/minute")
async def health_check():
    return ResponseModel.success(data={
        "status": "healthy",
        "timestamp": int(time.time() * 1000)
    })

@common_router.get("/info", summary="系统信息", description="获取系统版本、名称等基础信息")
async def get_system_info():
    from app.config.settings import settings
    return ResponseModel.success(data={
        "project_name": settings.PROJECT_NAME,
        "version": settings.PROJECT_VERSION,
        "allowed_file_extensions": settings.ALLOWED_FILE_EXTENSIONS,
        "max_upload_file_size": settings.MAX_UPLOAD_FILE_SIZE
    })

5.2 V1版本路由总入口

新建app/api/v1/router.py,统一管理V1版本的所有接口路由:

from fastapi import APIRouter
from app.api.v1.auth import auth_router
from app.api.v1.document import document_router
from app.api.v1.chat import chat_router
from app.api.v1.user import user_router

# V1版本路由总入口
api_v1_router = APIRouter()

# 注册子路由
api_v1_router.include_router(auth_router, prefix="/auth", tags=["认证管理"])
api_v1_router.include_router(document_router, prefix="/document", tags=["文档管理"])
api_v1_router.include_router(chat_router, prefix="/chat", tags=["对话问答"])
api_v1_router.include_router(user_router, prefix="/user", tags=["用户管理"])

5.3 认证接口

新建app/api/v1/auth.py,实现登录、Token刷新接口:

from fastapi import APIRouter, Depends
from app.schemas.common import ResponseModel
from app.schemas.auth import LoginRequest, LoginResponse
from app.api.deps import DBSession, get_current_user, CurrentUser
from app.service.auth_service import auth_service

auth_router = APIRouter()

@auth_router.post("/login", summary="用户登录", description="用户登录,获取访问Token", response_model=ResponseModel[LoginResponse])
async def login(data: LoginRequest, db: DBSession):
    result = auth_service.login(data.username, data.password, db)
    return ResponseModel.success(data=result, message="登录成功")

@auth_router.get("/me", summary="获取当前用户信息", description="获取当前登录用户的详细信息")
async def get_current_user_info(current_user: CurrentUser):
    return ResponseModel.success(data={
        "id": current_user.id,
        "username": current_user.username,
        "nickname": current_user.nickname,
        "email": current_user.email,
        "phone": current_user.phone,
        "tenant_id": current_user.tenant_id,
        "is_admin": current_user.is_admin,
        "is_super_admin": current_user.is_super_admin
    })

@auth_router.post("/logout", summary="用户登出", description="用户登出,前端清除Token即可")
async def logout():
    return ResponseModel.success(message="登出成功")

对应的Pydantic模型,新建app/schemas/auth.py

from pydantic import BaseModel, Field
from typing import Optional

class LoginRequest(BaseModel):
    username: str = Field(..., description="用户名", min_length=3, max_length=32)
    password: str = Field(..., description="密码", min_length=6, max_length=32)

class UserInfo(BaseModel):
    id: int
    username: str
    nickname: str
    email: str
    phone: str
    tenant_id: int
    is_admin: bool
    is_super_admin: bool

class LoginResponse(BaseModel):
    access_token: str
    token_type: str
    expires_in: int
    user_info: UserInfo

5.4 文档管理接口

新建app/api/v1/document.py,实现文档上传、列表查询、删除、详情查询接口:

import asyncio
from fastapi import APIRouter, UploadFile, File, Query, Depends, BackgroundTasks
from fastapi.concurrency import run_in_threadpool

from app.schemas.common import ResponseModel, PageQuery, PageResult
from app.schemas.document import DocumentInfoResponse
from app.api.deps import DBSession, CurrentUser, check_tenant_permission
from app.service.document_service import document_process_service
from app.utils.exception import BusinessException

document_router = APIRouter()

@document_router.post("/upload", summary="上传文档", description="上传文档,系统自动进行解析、分块、向量化,存入知识库")
async def upload_document(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="上传的文档文件"),
    current_user: CurrentUser = Depends(),
    db: DBSession = Depends()
):
    # 1. 保存文件,创建元数据
    document = await document_process_service.save_upload_file(
        file=file,
        tenant_id=current_user.tenant_id,
        user_id=current_user.id,
        db=db
    )
    # 2. 后台异步处理文档,避免阻塞接口
    async def process_document_background():
        try:
            # 在线程池中运行同步的文档处理方法,避免阻塞async事件循环
            await run_in_threadpool(
                document_process_service.process_document,
                file_path=document.file_path,
                tenant_id=str(document.tenant_id),
                upload_user_id=document.created_by,
                chunk_strategy="hybrid",
                chunk_size=500,
                chunk_overlap=50
            )
            # 更新文档状态为处理完成
            document.status = 1
            document.chunk_count = len(chunks)
            db.commit()
        except Exception as e:
            # 更新文档状态为处理失败
            document.status = 2
            document.failed_reason = str(e)
            db.commit()
    # 添加后台任务
    background_tasks.add_task(process_document_background)
    # 立即返回,无需等待处理完成
    return ResponseModel.success(
        data={
            "document_id": document.document_id,
            "file_name": document.file_name,
            "status": document.status,
            "message": "文档上传成功,正在后台处理中,请稍后刷新查看状态"
        },
        message="文档上传成功"
    )

@document_router.get("/list", summary="获取文档列表", description="分页获取当前租户下的文档列表", response_model=ResponseModel[PageResult[DocumentInfoResponse]])
async def get_document_list(
    page_query: PageQuery = Depends(),
    document_name: str = Query(None, description="文档名称,模糊查询"),
    current_user: CurrentUser = Depends(),
    db: DBSession = Depends()
):
    result = document_process_service.get_document_list(
        tenant_id=current_user.tenant_id,
        page=page_query.page,
        page_size=page_query.page_size,
        db=db,
        document_name=document_name
    )
    return ResponseModel.success(data=result)

@document_router.get("/{document_id}", summary="获取文档详情", description="根据文档ID获取文档详细信息")
async def get_document_detail(
    document_id: str,
    current_user: CurrentUser = Depends(),
    db: DBSession = Depends()
):
    document = document_process_service.get_document_by_id(document_id, current_user.tenant_id, db)
    return ResponseModel.success(data=document)

@document_router.delete("/{document_id}", summary="删除文档", description="删除文档,同时删除知识库中的向量数据")
async def delete_document(
    document_id: str,
    current_user: CurrentUser = Depends(),
    db: DBSession = Depends()
):
    document_process_service.delete_document(document_id, current_user.tenant_id, db)
    return ResponseModel.success(message="文档删除成功")

对应的Pydantic模型,新建app/schemas/document.py

from pydantic import BaseModel
from datetime import datetime
from typing import Optional

class DocumentInfoResponse(BaseModel):
    id: int
    document_id: str
    file_name: str
    file_ext: str
    file_size: int
    chunk_count: int
    status: int
    failed_reason: str
    created_by: int
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True

5.5 对话问答接口(核心重点)

新建app/api/v1/chat.py,实现会话管理、同步问答、流式问答接口,流式接口是重点,采用SSE标准格式,前端可直接对接

import json
import time
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from typing import List, Optional
from loguru import logger

from app.schemas.common import ResponseModel, PageQuery, PageResult
from app.schemas.chat import ChatRequest, ChatResponse, SessionInfoResponse
from app.api.deps import DBSession, CurrentUser
from app.service.chat_service import chat_service
from app.core.llm.prompt_template import get_prompt_manager

chat_router = APIRouter()
prompt_manager = get_prompt_manager()

@chat_router.post("/session/create", summary="创建对话会话", description="创建新的对话会话,可指定关联的文档")
async def create_chat_session(
    session_name: str = Query("新对话", description="会话名称"),
    document_ids: Optional[str] = Query(None, description="关联的文档ID,多个用逗号分隔"),
    current_user: CurrentUser = Depends(),
    db: DBSession = Depends()
):
    document_id_list = document_ids.split(",") if document_ids else None
    session = chat_service.create_session(
        tenant_id=current_user.tenant_id,
        user_id=current_user.id,
        session_name=session_name,
        document_ids=document_id_list
    )
    db.add(session)
    db.commit()
    db.refresh(session)
    return ResponseModel.success(data=session, message="会话创建成功")

@chat_router.get("/session/list", summary="获取会话列表", description="分页获取当前用户的对话会话列表")
async def get_session_list(
    page_query: PageQuery = Depends(),
    current_user: CurrentUser = Depends(),
    db: DBSession = Depends()
):
    result = chat_service.get_session_list(
        tenant_id=current_user.tenant_id,
        user_id=current_user.id,
        page=page_query.page,
        page_size=page_query.page_size,
        db=db
    )
    return ResponseModel.success(data=result)

@chat_router.get("/session/{session_id}", summary="获取会话详情", description="获取会话的详细信息和对话历史")
async def get_session_detail(
    session_id: str,
    current_user: CurrentUser = Depends(),
    db: DBSession = Depends()
):
    session = chat_service.get_session_by_id(session_id, current_user.tenant_id, current_user.id, db)
    history = chat_service.get_chat_history(session_id, current_user.tenant_id, db)
    return ResponseModel.success(data={
        "session_info": session,
        "chat_history": history
    })

@chat_router.delete("/session/{session_id}", summary="删除会话", description="删除对话会话,软删除")
async def delete_session(
    session_id: str,
    current_user: CurrentUser = Depends(),
    db: DBSession = Depends()
):
    session = chat_service.get_session_by_id(session_id, current_user.tenant_id, current_user.id, db)
    session.is_deleted = True
    db.commit()
    return ResponseModel.success(message="会话删除成功")

@chat_router.post("/sync", summary="同步问答接口", description="非流式问答,等待完整答案生成后返回", response_model=ResponseModel[ChatResponse])
async def chat_sync(
    data: ChatRequest,
    current_user: CurrentUser = Depends(),
    db: DBSession = Depends()
):
    start_time = time.time()
    # 1. 获取会话信息
    session = chat_service.get_session_by_id(data.session_id, current_user.tenant_id, current_user.id, db)
    # 2. 获取对话历史
    history_messages = chat_service.get_chat_history(data.session_id, current_user.tenant_id, db)
    # 3. 格式化对话历史
    formatted_history = prompt_manager.format_chat_history([
        {"role": msg.role, "content": msg.content} for msg in history_messages
    ])
    # 4. 解析关联的文档ID
    document_ids = session.document_ids.split(",") if session.document_ids else None
    # 5. 调用RAG问答服务
    chat_result = await run_in_threadpool(
        chat_service.chat,
        user_query=data.query,
        tenant_id=str(current_user.tenant_id),
        document_ids=document_ids,
        chat_history=formatted_history,
        enable_citation=True
    )
    # 6. 计算耗时
    latency = int((time.time() - start_time) * 1000)
    # 7. 保存用户提问和助手回答到数据库
    chat_service.save_chat_message(
        session_id=data.session_id,
        role="user",
        content=data.query,
        tenant_id=current_user.tenant_id,
        user_id=current_user.id,
        token_usage=0,
        latency=0,
        db=db
    )
    chat_service.save_chat_message(
        session_id=data.session_id,
        role="assistant",
        content=chat_result["answer"],
        tenant_id=current_user.tenant_id,
        user_id=current_user.id,
        token_usage=chat_result["usage"]["total_tokens"],
        latency=latency,
        db=db
    )
    # 8. 更新会话最后更新时间
    session.updated_at = func.now()
    db.commit()
    # 9. 返回结果
    return ResponseModel.success(data=chat_result, message="问答成功")

@chat_router.post("/stream", summary="流式问答接口", description="SSE流式输出,逐字返回答案,前端用EventSource对接")
async def chat_stream(
    data: ChatRequest,
    current_user: CurrentUser = Depends(),
    db: DBSession = Depends()
):
    start_time = time.time()
    # 1. 获取会话信息
    session = chat_service.get_session_by_id(data.session_id, current_user.tenant_id, current_user.id, db)
    # 2. 获取对话历史
    history_messages = chat_service.get_chat_history(data.session_id, current_user.tenant_id, db)
    formatted_history = prompt_manager.format_chat_history([
        {"role": msg.role, "content": msg.content} for msg in history_messages
    ])
    document_ids = session.document_ids.split(",") if session.document_ids else None
    # 3. 保存用户提问
    chat_service.save_chat_message(
        session_id=data.session_id,
        role="user",
        content=data.query,
        tenant_id=current_user.tenant_id,
        user_id=current_user.id,
        token_usage=0,
        latency=0,
        db=db
    )
    # 4. 异步流式生成器
    async def event_generator():
        full_answer = ""
        token_usage = 0
        try:
            # 先发送开始事件
            yield f"data: {json.dumps({'type': 'start', 'message': '开始生成答案'})}\n\n"
            # 获取流式生成器
            stream_generator = await run_in_threadpool(
                chat_service.stream_chat,
                user_query=data.query,
                tenant_id=str(current_user.tenant_id),
                document_ids=document_ids,
                chat_history=formatted_history
            )
            # 逐字返回
            for chunk in stream_generator:
                if chunk.startswith("[ERROR]"):
                    yield f"data: {json.dumps({'type': 'error', 'message': chunk})}\n\n"
                    break
                full_answer += chunk
                yield f"data: {json.dumps({'type': 'chunk', 'content': chunk})}\n\n"
            # 发送结束事件
            latency = int((time.time() - start_time) * 1000)
            yield f"data: {json.dumps({'type': 'end', 'full_answer': full_answer, 'latency': latency})}\n\n"
            # 保存助手回答到数据库
            chat_service.save_chat_message(
                session_id=data.session_id,
                role="assistant",
                content=full_answer,
                tenant_id=current_user.tenant_id,
                user_id=current_user.id,
                token_usage=token_usage,
                latency=latency,
                db=db
            )
            # 更新会话时间
            session.updated_at = func.now()
            db.commit()
        except Exception as e:
            logger.error(f"流式问答失败:{str(e)}", exc_info=True)
            yield f"data: {json.dumps({'type': 'error', 'message': f'生成失败:{str(e)}'})}\n\n"

    # 返回SSE流式响应
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"  # 禁用Nginx缓冲,保证流式输出正常
        }
    )

对应的Pydantic模型,新建app/schemas/chat.py

from pydantic import BaseModel, Field
from datetime import datetime
from typing import List, Optional, Dict, Any

class ChatRequest(BaseModel):
    session_id: str = Field(..., description="会话ID")
    query: str = Field(..., description="用户提问", min_length=1, max_length=2000)
    stream: bool = Field(True, description="是否流式输出")

class CitationSource(BaseModel):
    citation_number: int
    text: str
    document_id: str
    file_name: str
    chunk_id: str
    heading: str

class ChatResponse(BaseModel):
    answer: str
    original_answer: str
    has_reference: bool
    citation_sources: List[CitationSource]
    usage: Dict[str, int]
    latency: float

class SessionInfoResponse(BaseModel):
    id: int
    session_id: str
    session_name: str
    document_ids: str
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True

六、接口测试与文档

FastAPI自带了自动生成的接口文档,项目启动后,直接访问以下地址即可:

  • Swagger UI交互式文档:http://localhost:8000/docs
  • Redoc文档:http://localhost:8000/redoc

6.1 项目启动步骤

  1. 安装依赖:pip install -r requirements.txt
  2. 复制环境变量文件:cp .env.example .env,修改对应的配置,尤其是PostgreSQL连接信息、JWT密钥
  3. 启动PostgreSQL、Redis、Milvus服务(第三篇的Docker Compose已包含)
  4. 启动项目:python app/main.py
  5. 访问接口文档:http://localhost:8000/docs,即可看到所有接口,直接在页面上测试

6.2 核心接口测试流程

  1. 调用/api/v1/auth/login接口,用默认的超级管理员账号admin/admin123456登录,获取Token
  2. 点击Swagger UI右上角的Authorize,输入Bearer 你的Token,完成鉴权
  3. 调用/api/v1/document/upload接口,上传一个测试文档,比如PDF、Word
  4. 等待文档后台处理完成,调用/api/v1/document/list接口,查看文档状态,确认处理完成
  5. 调用/api/v1/chat/session/create接口,创建对话会话,关联刚才上传的文档ID
  6. 调用/api/v1/chat/sync同步接口,或者/api/v1/chat/stream流式接口,输入问题,即可得到RAG问答结果

七、踩坑记录&避坑指南:新手接口开发必踩的8个大坑

坑1:同步代码阻塞async事件循环,服务卡死

踩坑场景:新手直接在async接口里调用文档处理、LLM调用等耗时的同步方法,导致FastAPI的事件循环被阻塞,所有请求都被卡住,服务直接卡死。
避坑方案:用fastapi.concurrency.run_in_threadpool把同步方法放到线程池中运行,避免阻塞事件循环,本篇的流式接口和文档上传接口都已经做了这个处理。

坑2:流式接口跨域/缓冲问题,前端收不到流式数据

踩坑场景:流式接口在本地测试正常,部署到服务器上,前端用Nginx反向代理后,收不到流式数据,必须等完整内容生成后才一次性返回。
避坑方案:流式响应必须添加X-Accel-Buffering: no响应头,禁用Nginx的缓冲,同时设置Cache-Control: no-cacheConnection: keep-alive,本篇的流式接口已经添加了这些头。

坑3:大文件上传超时/内存溢出

踩坑场景:新手上传大文件时,直接把整个文件读入内存,导致内存溢出,或者接口超时。
避坑方案

  • 限制文件大小,本篇的配置里设置了最大10MB,可根据业务调整
  • 用异步分块读取文件,不要一次性读入内存
  • 大文件用分块上传方案,前端把文件切成小块,后端合并
  • 延长接口超时时间,或者用后台异步处理

坑4:全局异常处理不生效

踩坑场景:新手写了全局异常处理器,但异常还是直接抛出,没有被捕获,返回默认的FastAPI错误格式。
避坑方案

  • 异常处理器必须用app.add_exception_handler注册到FastAPI实例上
  • 自定义异常必须继承自Exception,不能用HTTPException
  • 依赖项中抛出的异常,必须在全局异常处理器中注册对应的处理函数

坑5:租户隔离没做好,越权访问

踩坑场景:新手只在接口里校验了用户登录,没有校验租户权限,A租户的用户可以通过修改文档ID,访问、删除B租户的文档,出现严重的数据泄露。
避坑方案

  • 所有资源操作,必须校验租户ID,普通用户只能操作自己租户的资源
  • 数据库查询时,必须带上tenant_id过滤条件,不能只靠ID查询
  • 本篇的check_tenant_permission依赖项已经实现了权限校验,所有接口必须调用

坑6:接口没有参数校验,非法参数导致服务崩溃

踩坑场景:新手接口的入参没有做校验,用户传入非法的参数,比如超长的字符串、负数的页码,导致数据库查询报错、服务崩溃。
避坑方案:所有入参必须用Pydantic模型做严格校验,包括类型、长度、范围、格式,非法参数在接口层直接拦截,不会传入业务层。

坑7:敏感信息泄露,日志/返回里明文打印密钥、Token、密码

踩坑场景:新手在日志里打印完整的Token、用户密码、API密钥,或者在接口返回里把用户的密码哈希返回给前端,出现严重的敏感信息泄露。
避坑方案

  • 敏感信息必须脱敏,日志里只打印掩码后的内容
  • 用户密码绝对不能返回给前端,哪怕是哈希后的
  • API密钥、JWT密钥绝对不能打印到日志里,也不能硬编码在代码里,必须放在环境变量中

坑8:接口没有限流,被恶意刷接口导致服务崩溃

踩坑场景:接口没有限流,被恶意用户用脚本疯狂刷接口,导致LLM账单爆炸、服务资源耗尽、直接崩溃。
避坑方案:用SlowAPI实现接口限流,本篇已经集成了限流器,核心接口必须添加限流规则,比如登录接口限制10/minute,问答接口限制60/minute


八、下一篇预告

本篇我们完成了生产级RAG系统的标准化后端接口开发,现在我们的系统已经可以对接前端页面、被第三方系统集成,具备了上线交付的基础能力。

下一篇预告:第7篇《用户权限与多租户系统设计,商用必备》,我会手把手带你完成:

  1. 商用级多租户系统的核心设计,租户隔离的3种实现方案
  2. RBAC角色权限体系设计,超级管理员、租户管理员、普通用户的权限管控
  3. 细粒度的资源权限控制,文档、会话的共享与权限分配
  4. 完整的用户、租户、角色管理接口,与本篇的接口无缝衔接
  5. 商用场景的租户计费、配额管理方案

结尾互动

本篇是《30天做一个生产级RAG知识库系统》全系列的第六篇,我们完成了商用级标准化接口的全量开发,现在你的RAG系统已经从一个本地demo,变成了一个可交付、可集成、可上线的商用服务。

最后想问一下大家:

  • 你在开发后端接口的时候,遇到过最头疼的问题是什么?是流式接口对接、权限管控,还是服务性能问题?
  • 你在对接前端的时候,遇到过哪些跨域、数据格式的坑?

欢迎在评论区留言,我会在后续的文章中,针对性地给你解决方案!

如果觉得这个系列对你有帮助,欢迎点赞、收藏、关注,后续的文章会第一时间推送给你,跟着更完,就能上线属于你的商用级RAG系统!

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐