【30天做一个生产级RAG知识库系统】第6篇:后端接口开发,基于FastAPI实现标准化商用接口
前五篇我们完成了生产级RAG系统的全链路核心能力建设:到这里,我们已经实现了端到端的RAG问答核心能力,但它还只是一个跑在本地的代码工程,无法对接前端页面、无法被第三方系统集成、无法做权限管控、无法上线交付给用户使用。而这一切的核心桥梁,就是标准化的后端接口。很多新手在这一步最容易犯的错:一个核心结论必须记死:商用级接口的核心,从来不是“能跑通”,而是稳定、安全、规范、可维护、可观测。一个不合格的
从核心能力到可交付服务,保姆级实现商用级RESTful接口,覆盖文档管理、流式对话、权限鉴权全场景,附完整工程化代码、接口规范、踩坑避坑指南
前言:从demo到商用,接口是必经的核心桥梁
前五篇我们完成了生产级RAG系统的全链路核心能力建设:
- 第1篇:完成了分层架构设计与项目骨架搭建
- 第2篇:实现了全格式文档预处理与智能分块
- 第3篇:完成了Embedding模型选型与Milvus向量库搭建
- 第4篇:优化了多路召回+重排序的全链路检索引擎
- 第5篇:实现了体系化Prompt工程、LLM调用封装与幻觉抑制方案
到这里,我们已经实现了端到端的RAG问答核心能力,但它还只是一个跑在本地的代码工程,无法对接前端页面、无法被第三方系统集成、无法做权限管控、无法上线交付给用户使用。
而这一切的核心桥梁,就是标准化的后端接口。
很多新手在这一步最容易犯的错:
- 随便写几个裸接口,没有任何参数校验、异常处理,非法参数一进来服务直接崩溃
- 把业务逻辑全写在接口里,代码臃肿混乱,后期维护成本极高
- 没有鉴权、没有租户隔离,谁都能调用,A租户能删B租户的文档,出现严重的数据泄露
- 不做统一响应格式,前端对接要为每个接口写单独的处理逻辑,对接成本极高
- 流式对话接口实现错误,要么阻塞服务,要么前端无法正常解析
- 没有接口文档,前后端联调全靠口口相传,效率极低
- 没有日志、没有监控,接口出了问题根本找不到原因
一个核心结论必须记死:
商用级接口的核心,从来不是“能跑通”,而是稳定、安全、规范、可维护、可观测。一个不合格的接口,会让你前面五篇做的所有核心能力都白费,上线后只会收到用户的投诉和bug反馈。
本篇我们就严格遵循商用级开发规范,基于FastAPI实现完整的标准化后端接口,完全对齐前五篇的项目架构,所有代码可直接复制运行,跟着做完,你的RAG系统就能直接对接前端、上线交付。
一、先立规矩:商用级后端接口的7大设计铁律
在写代码之前,我们必须先明确商用级接口的设计原则,这是保证接口质量的核心,新手必须严格遵守:
1. 分层解耦原则
接口层绝对不能写业务逻辑,只负责3件事:
- 接收请求参数,做格式与合法性校验
- 调用对应Service层的业务逻辑
- 封装处理结果,统一格式返回给前端
业务逻辑100%收敛到Service层,数据操作100%收敛到DB层,层与层之间通过标准化接口交互,修改某一层的代码,不会影响其他层。
2. RESTful规范原则
严格遵循RESTful API设计规范,用正确的HTTP方法表达语义:
GET:查询资源,不修改任何数据POST:创建资源,比如上传文档、创建会话PUT:全量更新资源DELETE:删除资源
URL语义清晰,版本化管理,比如/api/v1/document,后续升级不影响老调用方。
3. 统一响应原则
所有接口必须返回完全一致的JSON格式,无论成功还是失败,前端只需要一套逻辑就能处理所有接口的响应,大幅降低对接成本。
统一响应格式如下:
{
"code": 200, // 状态码:200成功,非200失败
"message": "success", // 提示信息:成功返回success,失败返回错误原因
"data": {} // 数据体:成功返回对应数据,失败返回null
}
4. 强校验原则
所有入参必须做严格校验,包括:
- 类型校验:字符串、数字、布尔值严格区分
- 格式校验:邮箱、手机号、ID格式、文件格式
- 范围校验:数值范围、长度限制、文件大小限制
- 合法性校验:租户权限、资源归属、越权操作拦截
非法参数必须在接口层直接拦截,绝对不能传入Service层,避免服务崩溃。
5. 安全优先原则
商用接口的安全是底线,必须做到:
- 所有业务接口必须鉴权,禁止裸接口对外暴露
- 严格的租户隔离,每个用户只能操作自己租户下的资源
- 敏感数据脱敏,密钥、Token、密码绝对不能明文返回或打印日志
- 接口限流,防止恶意刷接口、CC攻击
- 大文件上传格式与大小限制,防止恶意文件上传
6. 可观测性原则
所有接口必须有完整的日志记录,包括:
- 请求唯一ID,全链路追踪
- 请求用户、租户ID、接口地址、HTTP方法
- 请求入参(敏感信息脱敏)
- 响应状态码、响应耗时
- 异常信息、堆栈跟踪(仅开发环境返回,生产环境不暴露)
出问题能通过日志快速定位、快速复盘。
7. 兼容性原则
接口必须做版本控制,/api/v1/、/api/v2/,升级接口时必须保证向下兼容,绝对不能修改已上线接口的入参、出参结构,避免导致前端或第三方系统崩溃。
二、前置准备:项目结构补全与依赖更新
我们完全对齐第一篇的项目骨架,补全接口开发所需的目录和文件,保证整个项目的结构一致性。
2.1 最终项目目录结构(接口开发补全版)
production-rag-system/
├── app/
│ ├── api/ # 接口层核心目录
│ │ ├── deps.py # 接口依赖项:鉴权、参数校验、租户隔离
│ │ ├── common.py # 通用接口:健康检查、系统信息
│ │ └── v1/ # v1版本接口,后续升级可新增v2
│ │ ├── router.py # v1接口路由总入口
│ │ ├── auth.py # 认证接口:登录、注册、Token刷新
│ │ ├── document.py # 文档管理接口:上传、删除、列表查询
│ │ ├── chat.py # 对话问答接口:同步/流式问答、会话管理
│ │ └── user.py # 用户/租户基础接口
│ ├── core/ # 核心引擎层(前五篇已实现)
│ ├── service/ # 业务服务层(前五篇已实现,本次补充)
│ ├── models/ # ORM数据模型:SQLAlchemy表结构定义
│ │ ├── __init__.py
│ │ ├── user.py # 用户/租户模型
│ │ ├── document.py # 文档元数据模型
│ │ └── chat.py # 对话会话/历史模型
│ ├── schemas/ # Pydantic模型:入参/出参校验、类型定义
│ │ ├── __init__.py
│ │ ├── common.py # 通用响应模型
│ │ ├── auth.py # 认证相关入参/出参
│ │ ├── document.py # 文档相关入参/出参
│ │ └── chat.py # 对话相关入参/出参
│ ├── db/ # 数据库连接层(第一篇已实现,本次补充)
│ │ ├── __init__.py
│ │ ├── relational_db.py # PostgreSQL连接、会话管理
│ │ ├── vector_db.py # Milvus连接(第三篇已实现)
│ │ └── cache.py # Redis连接(第一篇已实现)
│ ├── config/ # 配置管理(第一篇已实现)
│ ├── utils/ # 通用工具函数:JWT、加密、脱敏
│ │ ├── __init__.py
│ │ ├── security.py # 密码加密、JWT生成/校验
│ │ └── logger.py # 日志工具
│ └── main.py # 项目启动入口(本次补充完整)
├── tests/ # 单元测试
├── deploy/ # 部署配置
├── logs/ # 日志文件
├── .env.example # 环境变量示例
├── requirements.txt # 依赖包(本次补充)
└── README.md
2.2 补充依赖包
更新requirements.txt,新增接口开发所需的依赖:
# 原有依赖(前五篇已包含,此处省略)
# ...
# 本篇新增:FastAPI接口开发核心依赖
fastapi>=0.115.0
uvicorn[standard]>=0.30.0
pydantic>=2.9.0
pydantic-settings>=2.5.0
# 本篇新增:ORM与数据库
sqlalchemy>=2.0.30
psycopg2-binary>=2.9.9
alembic>=1.13.0
# 本篇新增:认证与安全
python-jose[cryptography]>=3.3.0
passlib[bcrypt]>=1.7.4
python-multipart>=0.0.12
# 本篇新增:接口限流
slowapi>=0.1.9
limits>=3.12.0
# 本篇新增:工具类
python-dotenv>=1.0.0
loguru>=0.7.2
2.3 环境变量补充
更新.env.example,新增认证、数据库相关配置:
# 原有配置(前五篇已包含,此处省略)
# ...
# 本篇新增:JWT认证配置
JWT_SECRET_KEY=你的JWT密钥,用随机字符串生成,至少32位
JWT_ALGORITHM=HS256
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=1440 # Token有效期24小时
# 本篇新增:文件上传配置
MAX_UPLOAD_FILE_SIZE=10485760 # 最大文件大小10MB
ALLOWED_FILE_EXTENSIONS=pdf,docx,doc,xlsx,xls,pptx,ppt,txt,md
# 本篇新增:超级管理员配置
SUPER_ADMIN_USERNAME=admin
SUPER_ADMIN_PASSWORD=admin123456
三、核心基础模块实现
在写业务接口之前,我们先把通用的基础模块实现,包括统一响应模型、全局异常处理、ORM模型、Pydantic校验模型、JWT鉴权依赖项,这些是所有接口的基础。
3.1 通用响应模型与全局异常处理
新建app/schemas/common.py,定义统一响应模型:
from typing import Generic, TypeVar, Optional
from pydantic import BaseModel, Field
# 泛型类型,支持任意数据类型
T = TypeVar("T")
class ResponseModel(BaseModel, Generic[T]):
"""统一响应模型"""
code: int = Field(200, description="状态码,200为成功,非200为失败")
message: str = Field("success", description="响应信息,成功为success,失败为错误原因")
data: Optional[T] = Field(None, description="响应数据体")
@classmethod
def success(cls, data: T = None, message: str = "success") -> "ResponseModel[T]":
"""成功响应"""
return cls(code=200, message=message, data=data)
@classmethod
def error(cls, code: int = 400, message: str = "请求失败", data: T = None) -> "ResponseModel[T]":
"""失败响应"""
return cls(code=code, message=message, data=data)
# 通用分页请求模型
class PageQuery(BaseModel):
page: int = Field(1, ge=1, description="页码,从1开始")
page_size: int = Field(10, ge=1, le=100, description="每页条数,1-100")
# 通用分页响应模型
class PageResult(BaseModel, Generic[T]):
total: int = Field(description="总条数")
page: int = Field(description="当前页码")
page_size: int = Field(description="每页条数")
total_page: int = Field(description="总页数")
list: List[T] = Field(description="数据列表")
新建app/utils/exception.py,定义全局自定义异常与异常处理器:
from fastapi import Request, status
from fastapi.responses import JSONResponse
from loguru import logger
from app.schemas.common import ResponseModel
# 自定义业务异常
class BusinessException(Exception):
def __init__(self, code: int = 400, message: str = "业务处理失败"):
self.code = code
self.message = message
# 全局异常处理器
async def global_exception_handler(request: Request, exc: Exception):
"""全局异常捕获,统一返回格式"""
# 生成请求唯一ID,用于日志追踪
request_id = request.headers.get("X-Request-ID", "unknown")
# 获取请求路径和方法
path = request.url.path
method = request.method
# 自定义业务异常
if isinstance(exc, BusinessException):
logger.warning(f"[{request_id}] 业务异常:{method} {path} | code={exc.code} | message={exc.message}")
return JSONResponse(
status_code=200,
content=ResponseModel.error(code=exc.code, message=exc.message).model_dump()
)
# 其他未知异常
logger.error(f"[{request_id}] 系统异常:{method} {path} | 异常信息:{str(exc)}", exc_info=True)
return JSONResponse(
status_code=200,
content=ResponseModel.error(code=500, message="系统内部错误,请联系管理员").model_dump()
)
# 404异常处理器
async def not_found_exception_handler(request: Request, exc):
return JSONResponse(
status_code=200,
content=ResponseModel.error(code=404, message="请求的接口不存在").model_dump()
)
# 权限异常处理器
async def permission_denied_exception_handler(request: Request, exc):
return JSONResponse(
status_code=200,
content=ResponseModel.error(code=403, message="权限不足,无法访问该资源").model_dump()
)
3.2 ORM数据模型与数据库连接
更新app/db/relational_db.py,完善PostgreSQL连接与SQLAlchemy会话管理:
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from loguru import logger
from app.config.settings import settings
# 构建数据库连接URL
SQLALCHEMY_DATABASE_URL = f"postgresql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}"
# 创建数据库引擎
engine = create_engine(
SQLALCHEMY_DATABASE_URL,
pool_pre_ping=True, # 自动检测失效连接
pool_recycle=3600, # 连接回收时间
pool_size=20, # 连接池大小
max_overflow=10 # 最大溢出连接数
)
# 会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 基类,所有ORM模型都继承这个类
Base = declarative_base()
# 数据库依赖项,用于接口中获取数据库会话
def get_db() -> Session:
db = SessionLocal()
try:
yield db
finally:
db.close()
# 初始化数据库表
def init_db():
"""项目启动时初始化数据库表"""
try:
# 导入所有模型,确保创建表
from app.models.user import User, Tenant
from app.models.document import Document
from app.models.chat import ChatSession, ChatMessage
# 创建所有表
Base.metadata.create_all(bind=engine)
logger.info("数据库表初始化成功")
except Exception as e:
logger.error(f"数据库表初始化失败:{str(e)}", exc_info=True)
raise e
新建app/models/user.py,定义用户与租户ORM模型:
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey
from sqlalchemy.sql import func
from app.db.relational_db import Base
class Tenant(Base):
"""租户模型,多租户隔离核心"""
__tablename__ = "sys_tenant"
id = Column(Integer, primary_key=True, autoincrement=True, comment="租户ID")
tenant_name = Column(String(64), nullable=False, unique=True, comment="租户名称")
tenant_code = Column(String(32), nullable=False, unique=True, comment="租户编码")
status = Column(Integer, default=1, comment="状态:1启用 0禁用")
created_at = Column(DateTime, default=func.now(), comment="创建时间")
updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), comment="更新时间")
class User(Base):
"""用户模型"""
__tablename__ = "sys_user"
id = Column(Integer, primary_key=True, autoincrement=True, comment="用户ID")
username = Column(String(32), nullable=False, unique=True, comment="用户名")
password = Column(String(128), nullable=False, comment="加密后的密码")
nickname = Column(String(64), default="", comment="用户昵称")
email = Column(String(64), default="", comment="邮箱")
phone = Column(String(16), default="", comment="手机号")
tenant_id = Column(Integer, ForeignKey("sys_tenant.id"), nullable=False, comment="所属租户ID")
is_admin = Column(Boolean, default=False, comment="是否为租户管理员")
is_super_admin = Column(Boolean, default=False, comment="是否为超级管理员")
status = Column(Integer, default=1, comment="状态:1启用 0禁用")
created_at = Column(DateTime, default=func.now(), comment="创建时间")
updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), comment="更新时间")
新建app/models/document.py,定义文档元数据模型:
from sqlalchemy import Column, Integer, String, BigInteger, DateTime, ForeignKey, Text
from sqlalchemy.sql import func
from app.db.relational_db import Base
class Document(Base):
"""文档元数据模型"""
__tablename__ = "biz_document"
id = Column(Integer, primary_key=True, autoincrement=True, comment="文档ID")
document_id = Column(String(64), nullable=False, unique=True, comment="文档唯一ID,对应向量库中的document_id")
file_name = Column(String(256), nullable=False, comment="原始文件名")
file_ext = Column(String(16), nullable=False, comment="文件后缀")
file_size = Column(BigInteger, default=0, comment="文件大小,单位字节")
file_path = Column(String(512), nullable=False, comment="文件存储路径")
chunk_count = Column(Integer, default=0, comment="分块数量")
tenant_id = Column(Integer, ForeignKey("sys_tenant.id"), nullable=False, comment="所属租户ID")
created_by = Column(Integer, ForeignKey("sys_user.id"), nullable=False, comment="上传人ID")
status = Column(Integer, default=0, comment="状态:0处理中 1处理完成 2处理失败")
failed_reason = Column(Text, default="", comment="处理失败原因")
created_at = Column(DateTime, default=func.now(), comment="创建时间")
updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), comment="更新时间")
新建app/models/chat.py,定义对话会话与消息模型:
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Text, Boolean
from sqlalchemy.sql import func
from app.db.relational_db import Base
class ChatSession(Base):
"""对话会话模型"""
__tablename__ = "biz_chat_session"
id = Column(Integer, primary_key=True, autoincrement=True, comment="会话ID")
session_id = Column(String(64), nullable=False, unique=True, comment="会话唯一ID")
session_name = Column(String(128), default="新对话", comment="会话名称")
tenant_id = Column(Integer, ForeignKey("sys_tenant.id"), nullable=False, comment="所属租户ID")
created_by = Column(Integer, ForeignKey("sys_user.id"), nullable=False, comment="创建人ID")
document_ids = Column(String(512), default="", comment="关联的文档ID列表,逗号分隔")
is_deleted = Column(Boolean, default=False, comment="是否删除")
created_at = Column(DateTime, default=func.now(), comment="创建时间")
updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), comment="更新时间")
class ChatMessage(Base):
"""对话消息模型,存储对话历史"""
__tablename__ = "biz_chat_message"
id = Column(Integer, primary_key=True, autoincrement=True, comment="消息ID")
session_id = Column(String(64), ForeignKey("biz_chat_session.session_id"), nullable=False, comment="所属会话ID")
role = Column(String(16), nullable=False, comment="角色:user/assistant")
content = Column(Text, nullable=False, comment="消息内容")
tenant_id = Column(Integer, ForeignKey("sys_tenant.id"), nullable=False, comment="所属租户ID")
user_id = Column(Integer, ForeignKey("sys_user.id"), nullable=False, comment="发送人ID")
token_usage = Column(Integer, default=0, comment="本次消息消耗的Token数量")
latency = Column(Integer, default=0, comment="响应耗时,单位毫秒")
is_deleted = Column(Boolean, default=False, comment="是否删除")
created_at = Column(DateTime, default=func.now(), comment="创建时间")
3.3 安全工具与JWT鉴权
新建app/utils/security.py,实现密码加密与JWT生成/校验:
from datetime import datetime, timedelta
from typing import Optional
from jose import JWTError, jwt
from passlib.context import CryptContext
from app.config.settings import settings
# 密码加密上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_password_hash(password: str) -> str:
"""生成密码哈希"""
return pwd_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证密码"""
return pwd_context.verify(plain_password, hashed_password)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""生成JWT访问令牌"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
return encoded_jwt
def verify_access_token(token: str) -> Optional[dict]:
"""校验JWT令牌,返回令牌中的数据,校验失败返回None"""
try:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
return payload
except JWTError:
return None
新建app/api/deps.py,实现接口依赖项,包括鉴权、租户隔离、数据库会话:
from typing import Annotated, Optional
from fastapi import Depends, Header, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from loguru import logger
from app.db.relational_db import get_db
from app.utils.security import verify_access_token
from app.models.user import User, Tenant
from app.utils.exception import BusinessException
# Bearer Token认证
security = HTTPBearer(auto_error=False)
# 通用类型注解
DBSession = Annotated[Session, Depends(get_db)]
TokenCredentials = Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]
async def get_token_payload(credentials: TokenCredentials) -> dict:
"""获取Token中的载荷数据,校验Token有效性"""
if not credentials:
raise BusinessException(code=401, message="请先登录,未提供Token")
token = credentials.credentials
payload = verify_access_token(token)
if not payload:
raise BusinessException(code=401, message="Token无效或已过期,请重新登录")
return payload
async def get_current_user(
payload: Annotated[dict, Depends(get_token_payload)],
db: DBSession
) -> User:
"""获取当前登录用户"""
user_id = payload.get("user_id")
tenant_id = payload.get("tenant_id")
if not user_id or not tenant_id:
raise BusinessException(code=401, message="Token数据异常,请重新登录")
# 查询用户
user = db.query(User).filter(User.id == user_id, User.tenant_id == tenant_id, User.status == 1).first()
if not user:
raise BusinessException(code=401, message="用户不存在或已被禁用")
# 校验租户状态
tenant = db.query(Tenant).filter(Tenant.id == tenant_id, Tenant.status == 1).first()
if not tenant:
raise BusinessException(code=403, message="所属租户已被禁用")
return user
# 注解类型,当前登录用户
CurrentUser = Annotated[User, Depends(get_current_user)]
async def check_tenant_permission(
target_tenant_id: int,
current_user: CurrentUser
) -> bool:
"""校验租户权限,普通用户只能操作自己租户的资源,超级管理员可以操作所有"""
if current_user.is_super_admin:
return True
if current_user.tenant_id != target_tenant_id:
raise BusinessException(code=403, message="权限不足,无法操作其他租户的资源")
return True
3.4 补充配置文件
更新app/config/settings.py,新增本篇所需的配置项:
# 原有配置(第一篇已包含,此处省略)
# ...
# 本篇新增:JWT认证配置
JWT_SECRET_KEY: str = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-it-in-production")
JWT_ALGORITHM: str = os.getenv("JWT_ALGORITHM", "HS256")
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", 1440))
# 本篇新增:文件上传配置
MAX_UPLOAD_FILE_SIZE: int = int(os.getenv("MAX_UPLOAD_FILE_SIZE", 10485760)) # 10MB
ALLOWED_FILE_EXTENSIONS: List[str] = os.getenv("ALLOWED_FILE_EXTENSIONS", "pdf,docx,doc,xlsx,xls,pptx,ppt,txt,md").split(",")
# 本篇新增:超级管理员配置
SUPER_ADMIN_USERNAME: str = os.getenv("SUPER_ADMIN_USERNAME", "admin")
SUPER_ADMIN_PASSWORD: str = os.getenv("SUPER_ADMIN_PASSWORD", "admin123456")
3.5 更新项目启动入口
更新app/main.py,补全完整的项目启动配置,包括中间件、异常处理、路由注册、数据库初始化:
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from loguru import logger
from app.config.settings import settings
from app.db.relational_db import init_db
from app.utils.exception import (
global_exception_handler,
not_found_exception_handler,
permission_denied_exception_handler
)
from app.api.v1.router import api_v1_router
from app.api.common import common_router
# 初始化限流器
limiter = Limiter(key_func=get_remote_address)
# 应用生命周期管理
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用启动与关闭时的生命周期管理"""
logger.info("=== 开始启动RAG知识库系统 ===")
# 初始化数据库
init_db()
# 初始化超级管理员
from app.service.auth_service import init_super_admin
init_super_admin()
# 初始化核心引擎(避免懒加载导致首次请求卡顿)
from app.service.document_service import document_process_service
from app.service.chat_service import chat_service
logger.info("=== RAG知识库系统启动成功 ===")
yield
# 应用关闭时的清理工作
logger.info("=== 开始关闭RAG知识库系统 ===")
from app.core.embedding.vector_db import get_vector_db_manager
get_vector_db_manager().close()
logger.info("=== RAG知识库系统关闭成功 ===")
# 初始化FastAPI应用
app = FastAPI(
title=settings.PROJECT_NAME,
description="生产级RAG知识库系统 - 商用级API文档",
version=settings.PROJECT_VERSION,
docs_url="/docs",
redoc_url="/redoc",
lifespan=lifespan
)
# 绑定限流器
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# 注册全局异常处理器
app.add_exception_handler(Exception, global_exception_handler)
app.add_exception_handler(404, not_found_exception_handler)
app.add_exception_handler(403, permission_denied_exception_handler)
# 注册中间件
# Gzip压缩,减少传输体积
app.add_middleware(GZipMiddleware, minimum_size=1000)
# CORS跨域中间件
app.add_middleware(
CORSMiddleware,
allow_origins=settings.ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["*"]
)
# 注册路由
app.include_router(common_router, prefix="/api/common", tags=["通用接口"])
app.include_router(api_v1_router, prefix="/api/v1", tags=["V1版本接口"])
# 根路径健康检查
@app.get("/", tags=["系统"], summary="根路径健康检查")
async def root():
return {
"status": "healthy",
"project_name": settings.PROJECT_NAME,
"version": settings.PROJECT_VERSION,
"docs": "/docs",
"redoc": "/redoc"
}
# 项目启动入口
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"main:app",
host=settings.SERVER_HOST,
port=settings.SERVER_PORT,
reload=settings.DEBUG,
workers=settings.WORKERS
)
四、业务服务层补充
接口层只负责参数接收和返回,业务逻辑必须在Service层实现,我们补充前五篇没有的业务逻辑,包括用户认证、文档元数据管理、对话会话管理。
4.1 认证服务
新建app/service/auth_service.py,实现用户认证与初始化逻辑:
from sqlalchemy.orm import Session
from loguru import logger
from datetime import timedelta
from app.db.relational_db import SessionLocal
from app.models.user import User, Tenant
from app.utils.security import get_password_hash, verify_password, create_access_token
from app.config.settings import settings
from app.utils.exception import BusinessException
class AuthService:
def init_super_admin(self):
"""初始化超级管理员,项目启动时执行"""
db = SessionLocal()
try:
# 检查超级管理员是否已存在
super_admin = db.query(User).filter(User.username == settings.SUPER_ADMIN_USERNAME, User.is_super_admin == True).first()
if super_admin:
logger.info("超级管理员已存在,无需初始化")
return
# 创建默认租户
default_tenant = Tenant(
tenant_name="默认租户",
tenant_code="default",
status=1
)
db.add(default_tenant)
db.flush()
# 创建超级管理员
hashed_password = get_password_hash(settings.SUPER_ADMIN_PASSWORD)
super_admin = User(
username=settings.SUPER_ADMIN_USERNAME,
password=hashed_password,
nickname="超级管理员",
tenant_id=default_tenant.id,
is_admin=True,
is_super_admin=True,
status=1
)
db.add(super_admin)
db.commit()
logger.info("超级管理员初始化成功")
except Exception as e:
db.rollback()
logger.error(f"超级管理员初始化失败:{str(e)}", exc_info=True)
finally:
db.close()
def login(self, username: str, password: str, db: Session) -> dict:
"""用户登录,返回Token和用户信息"""
# 查询用户
user = db.query(User).filter(User.username == username, User.status == 1).first()
if not user:
raise BusinessException(code=400, message="用户名或密码错误")
# 验证密码
if not verify_password(password, user.password):
raise BusinessException(code=400, message="用户名或密码错误")
# 校验租户状态
tenant = db.query(Tenant).filter(Tenant.id == user.tenant_id, Tenant.status == 1).first()
if not tenant:
raise BusinessException(code=403, message="所属租户已被禁用")
# 生成Token
access_token_expires = timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={
"user_id": user.id,
"tenant_id": user.tenant_id,
"username": user.username,
"is_admin": user.is_admin,
"is_super_admin": user.is_super_admin
},
expires_delta=access_token_expires
)
# 返回用户信息和Token
return {
"access_token": access_token,
"token_type": "bearer",
"expires_in": settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60,
"user_info": {
"id": user.id,
"username": user.username,
"nickname": user.nickname,
"email": user.email,
"phone": user.phone,
"tenant_id": user.tenant_id,
"is_admin": user.is_admin,
"is_super_admin": user.is_super_admin
}
}
# 全局单例
auth_service = AuthService()
# 项目启动时初始化超级管理员的方法
def init_super_admin():
auth_service.init_super_admin()
4.2 文档服务补充
更新app/service/document_service.py,补充文档元数据管理、列表查询、文件存储逻辑:
# 原有导入(前五篇已包含,此处省略)
# ...
import os
import uuid
from fastapi import UploadFile
from sqlalchemy.orm import Session
from loguru import logger
from app.models.document import Document
from app.config.settings import settings
from app.utils.exception import BusinessException
# 原有代码(前五篇已实现,此处省略)
# ...
class DocumentProcessService:
# 原有方法(前五篇已实现,此处省略)
# ...
def __init__(self):
# 原有初始化代码
# ...
# 新增:文件存储目录
self.upload_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data/upload")
os.makedirs(self.upload_dir, exist_ok=True)
async def save_upload_file(self, file: UploadFile, tenant_id: int, user_id: int, db: Session) -> Document:
"""保存上传的文件,创建文档元数据"""
# 1. 校验文件格式
file_ext = os.path.splitext(file.filename)[1].lower().replace(".", "")
if file_ext not in settings.ALLOWED_FILE_EXTENSIONS:
raise BusinessException(code=400, message=f"不支持的文件格式,支持的格式:{','.join(settings.ALLOWED_FILE_EXTENSIONS)}")
# 2. 校验文件大小
file_content = await file.read()
file_size = len(file_content)
if file_size > settings.MAX_UPLOAD_FILE_SIZE:
raise BusinessException(code=400, message=f"文件大小超过限制,最大支持{settings.MAX_UPLOAD_FILE_SIZE/1024/1024}MB")
# 3. 重置文件指针
await file.seek(0)
# 4. 生成唯一文档ID
document_id = uuid.uuid4().hex
# 5. 生成文件存储路径
file_name = file.filename
save_file_name = f"{document_id}_{file_name}"
save_path = os.path.join(self.upload_dir, save_file_name)
# 6. 保存文件到本地
with open(save_path, "wb") as f:
f.write(file_content)
# 7. 创建文档元数据
document = Document(
document_id=document_id,
file_name=file_name,
file_ext=file_ext,
file_size=file_size,
file_path=save_path,
tenant_id=tenant_id,
created_by=user_id,
status=0 # 处理中
)
db.add(document)
db.commit()
db.refresh(document)
logger.info(f"文件保存成功,文档ID:{document_id},文件名:{file_name}")
return document
def get_document_list(self, tenant_id: int, page: int, page_size: int, db: Session, document_name: str = None):
"""获取租户下的文档列表,分页查询"""
query = db.query(Document).filter(Document.tenant_id == tenant_id, Document.is_deleted == False)
# 按文件名模糊查询
if document_name:
query = query.filter(Document.file_name.like(f"%{document_name}%"))
# 总条数
total = query.count()
# 分页查询
offset = (page - 1) * page_size
list = query.order_by(Document.created_at.desc()).offset(offset).limit(page_size).all()
# 总页数
total_page = (total + page_size - 1) // page_size
return {
"total": total,
"page": page,
"page_size": page_size,
"total_page": total_page,
"list": list
}
def get_document_by_id(self, document_id: str, tenant_id: int, db: Session) -> Document:
"""根据文档ID获取文档信息,校验租户权限"""
document = db.query(Document).filter(Document.document_id == document_id, Document.tenant_id == tenant_id).first()
if not document:
raise BusinessException(code=404, message="文档不存在")
return document
def delete_document(self, document_id: str, tenant_id: int, db: Session) -> bool:
"""删除文档,同时删除向量库、关键词索引和本地文件"""
document = self.get_document_by_id(document_id, tenant_id, db)
try:
# 1. 删除向量库中的数据
self.vector_db.delete_by_document_id(document_id, str(tenant_id))
# 2. 删除关键词索引中的数据
self.keyword_retriever.delete_by_document_id(document_id, str(tenant_id))
# 3. 删除本地文件
if os.path.exists(document.file_path):
os.remove(document.file_path)
# 4. 软删除数据库中的元数据
document.is_deleted = True
db.commit()
logger.info(f"文档删除成功,文档ID:{document_id}")
return True
except Exception as e:
db.rollback()
logger.error(f"文档删除失败,文档ID:{document_id},错误信息:{str(e)}", exc_info=True)
raise BusinessException(code=500, message="文档删除失败")
# 全局单例
document_process_service = DocumentProcessService()
4.3 对话服务补充
更新app/service/chat_service.py,补充会话管理、对话历史持久化逻辑:
# 原有导入(前五篇已包含,此处省略)
# ...
import uuid
from sqlalchemy.orm import Session
from app.models.chat import ChatSession, ChatMessage
from app.utils.exception import BusinessException
class ChatService:
# 原有方法(前五篇已实现,此处省略)
# ...
def create_session(self, tenant_id: int, user_id: int, session_name: str = "新对话", document_ids: List[str] = None) -> ChatSession:
"""创建新的对话会话"""
session_id = uuid.uuid4().hex
document_ids_str = ",".join(document_ids) if document_ids else ""
session = ChatSession(
session_id=session_id,
session_name=session_name,
tenant_id=tenant_id,
created_by=user_id,
document_ids=document_ids_str
)
return session
def get_session_list(self, tenant_id: int, user_id: int, page: int, page_size: int, db: Session):
"""获取用户的会话列表,分页查询"""
query = db.query(ChatSession).filter(
ChatSession.tenant_id == tenant_id,
ChatSession.created_by == user_id,
ChatSession.is_deleted == False
)
total = query.count()
offset = (page - 1) * page_size
list = query.order_by(ChatSession.updated_at.desc()).offset(offset).limit(page_size).all()
total_page = (total + page_size - 1) // page_size
return {
"total": total,
"page": page,
"page_size": page_size,
"total_page": total_page,
"list": list
}
def get_session_by_id(self, session_id: str, tenant_id: int, user_id: int, db: Session) -> ChatSession:
"""获取会话详情,校验权限"""
session = db.query(ChatSession).filter(
ChatSession.session_id == session_id,
ChatSession.tenant_id == tenant_id,
ChatSession.created_by == user_id,
ChatSession.is_deleted == False
).first()
if not session:
raise BusinessException(code=404, message="会话不存在")
return session
def get_chat_history(self, session_id: str, tenant_id: int, db: Session) -> List[ChatMessage]:
"""获取会话的对话历史"""
messages = db.query(ChatMessage).filter(
ChatMessage.session_id == session_id,
ChatMessage.tenant_id == tenant_id,
ChatMessage.is_deleted == False
).order_by(ChatMessage.created_at.asc()).all()
return messages
def save_chat_message(self, session_id: str, role: str, content: str, tenant_id: int, user_id: int, token_usage: int, latency: int, db: Session) -> ChatMessage:
"""保存对话消息到数据库"""
message = ChatMessage(
session_id=session_id,
role=role,
content=content,
tenant_id=tenant_id,
user_id=user_id,
token_usage=token_usage,
latency=latency
)
db.add(message)
db.commit()
db.refresh(message)
return message
# 全局单例
chat_service = ChatService()
五、核心业务接口实现
现在我们实现完整的业务接口,分为认证接口、文档管理接口、对话问答接口,所有接口都严格遵循前面的设计规范。
5.1 通用接口
新建app/api/common.py,实现健康检查、系统信息等通用接口:
from fastapi import APIRouter
from slowapi import limiter
from app.schemas.common import ResponseModel
common_router = APIRouter()
@common_router.get("/health", summary="系统健康检查", description="用于监控系统健康状态")
@limiter.limit("100/minute")
async def health_check():
return ResponseModel.success(data={
"status": "healthy",
"timestamp": int(time.time() * 1000)
})
@common_router.get("/info", summary="系统信息", description="获取系统版本、名称等基础信息")
async def get_system_info():
from app.config.settings import settings
return ResponseModel.success(data={
"project_name": settings.PROJECT_NAME,
"version": settings.PROJECT_VERSION,
"allowed_file_extensions": settings.ALLOWED_FILE_EXTENSIONS,
"max_upload_file_size": settings.MAX_UPLOAD_FILE_SIZE
})
5.2 V1版本路由总入口
新建app/api/v1/router.py,统一管理V1版本的所有接口路由:
from fastapi import APIRouter
from app.api.v1.auth import auth_router
from app.api.v1.document import document_router
from app.api.v1.chat import chat_router
from app.api.v1.user import user_router
# V1版本路由总入口
api_v1_router = APIRouter()
# 注册子路由
api_v1_router.include_router(auth_router, prefix="/auth", tags=["认证管理"])
api_v1_router.include_router(document_router, prefix="/document", tags=["文档管理"])
api_v1_router.include_router(chat_router, prefix="/chat", tags=["对话问答"])
api_v1_router.include_router(user_router, prefix="/user", tags=["用户管理"])
5.3 认证接口
新建app/api/v1/auth.py,实现登录、Token刷新接口:
from fastapi import APIRouter, Depends
from app.schemas.common import ResponseModel
from app.schemas.auth import LoginRequest, LoginResponse
from app.api.deps import DBSession, get_current_user, CurrentUser
from app.service.auth_service import auth_service
auth_router = APIRouter()
@auth_router.post("/login", summary="用户登录", description="用户登录,获取访问Token", response_model=ResponseModel[LoginResponse])
async def login(data: LoginRequest, db: DBSession):
result = auth_service.login(data.username, data.password, db)
return ResponseModel.success(data=result, message="登录成功")
@auth_router.get("/me", summary="获取当前用户信息", description="获取当前登录用户的详细信息")
async def get_current_user_info(current_user: CurrentUser):
return ResponseModel.success(data={
"id": current_user.id,
"username": current_user.username,
"nickname": current_user.nickname,
"email": current_user.email,
"phone": current_user.phone,
"tenant_id": current_user.tenant_id,
"is_admin": current_user.is_admin,
"is_super_admin": current_user.is_super_admin
})
@auth_router.post("/logout", summary="用户登出", description="用户登出,前端清除Token即可")
async def logout():
return ResponseModel.success(message="登出成功")
对应的Pydantic模型,新建app/schemas/auth.py:
from pydantic import BaseModel, Field
from typing import Optional
class LoginRequest(BaseModel):
username: str = Field(..., description="用户名", min_length=3, max_length=32)
password: str = Field(..., description="密码", min_length=6, max_length=32)
class UserInfo(BaseModel):
id: int
username: str
nickname: str
email: str
phone: str
tenant_id: int
is_admin: bool
is_super_admin: bool
class LoginResponse(BaseModel):
access_token: str
token_type: str
expires_in: int
user_info: UserInfo
5.4 文档管理接口
新建app/api/v1/document.py,实现文档上传、列表查询、删除、详情查询接口:
import asyncio
from fastapi import APIRouter, UploadFile, File, Query, Depends, BackgroundTasks
from fastapi.concurrency import run_in_threadpool
from app.schemas.common import ResponseModel, PageQuery, PageResult
from app.schemas.document import DocumentInfoResponse
from app.api.deps import DBSession, CurrentUser, check_tenant_permission
from app.service.document_service import document_process_service
from app.utils.exception import BusinessException
document_router = APIRouter()
@document_router.post("/upload", summary="上传文档", description="上传文档,系统自动进行解析、分块、向量化,存入知识库")
async def upload_document(
background_tasks: BackgroundTasks,
file: UploadFile = File(..., description="上传的文档文件"),
current_user: CurrentUser = Depends(),
db: DBSession = Depends()
):
# 1. 保存文件,创建元数据
document = await document_process_service.save_upload_file(
file=file,
tenant_id=current_user.tenant_id,
user_id=current_user.id,
db=db
)
# 2. 后台异步处理文档,避免阻塞接口
async def process_document_background():
try:
# 在线程池中运行同步的文档处理方法,避免阻塞async事件循环
await run_in_threadpool(
document_process_service.process_document,
file_path=document.file_path,
tenant_id=str(document.tenant_id),
upload_user_id=document.created_by,
chunk_strategy="hybrid",
chunk_size=500,
chunk_overlap=50
)
# 更新文档状态为处理完成
document.status = 1
document.chunk_count = len(chunks)
db.commit()
except Exception as e:
# 更新文档状态为处理失败
document.status = 2
document.failed_reason = str(e)
db.commit()
# 添加后台任务
background_tasks.add_task(process_document_background)
# 立即返回,无需等待处理完成
return ResponseModel.success(
data={
"document_id": document.document_id,
"file_name": document.file_name,
"status": document.status,
"message": "文档上传成功,正在后台处理中,请稍后刷新查看状态"
},
message="文档上传成功"
)
@document_router.get("/list", summary="获取文档列表", description="分页获取当前租户下的文档列表", response_model=ResponseModel[PageResult[DocumentInfoResponse]])
async def get_document_list(
page_query: PageQuery = Depends(),
document_name: str = Query(None, description="文档名称,模糊查询"),
current_user: CurrentUser = Depends(),
db: DBSession = Depends()
):
result = document_process_service.get_document_list(
tenant_id=current_user.tenant_id,
page=page_query.page,
page_size=page_query.page_size,
db=db,
document_name=document_name
)
return ResponseModel.success(data=result)
@document_router.get("/{document_id}", summary="获取文档详情", description="根据文档ID获取文档详细信息")
async def get_document_detail(
document_id: str,
current_user: CurrentUser = Depends(),
db: DBSession = Depends()
):
document = document_process_service.get_document_by_id(document_id, current_user.tenant_id, db)
return ResponseModel.success(data=document)
@document_router.delete("/{document_id}", summary="删除文档", description="删除文档,同时删除知识库中的向量数据")
async def delete_document(
document_id: str,
current_user: CurrentUser = Depends(),
db: DBSession = Depends()
):
document_process_service.delete_document(document_id, current_user.tenant_id, db)
return ResponseModel.success(message="文档删除成功")
对应的Pydantic模型,新建app/schemas/document.py:
from pydantic import BaseModel
from datetime import datetime
from typing import Optional
class DocumentInfoResponse(BaseModel):
id: int
document_id: str
file_name: str
file_ext: str
file_size: int
chunk_count: int
status: int
failed_reason: str
created_by: int
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
5.5 对话问答接口(核心重点)
新建app/api/v1/chat.py,实现会话管理、同步问答、流式问答接口,流式接口是重点,采用SSE标准格式,前端可直接对接:
import json
import time
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from typing import List, Optional
from loguru import logger
from app.schemas.common import ResponseModel, PageQuery, PageResult
from app.schemas.chat import ChatRequest, ChatResponse, SessionInfoResponse
from app.api.deps import DBSession, CurrentUser
from app.service.chat_service import chat_service
from app.core.llm.prompt_template import get_prompt_manager
chat_router = APIRouter()
prompt_manager = get_prompt_manager()
@chat_router.post("/session/create", summary="创建对话会话", description="创建新的对话会话,可指定关联的文档")
async def create_chat_session(
session_name: str = Query("新对话", description="会话名称"),
document_ids: Optional[str] = Query(None, description="关联的文档ID,多个用逗号分隔"),
current_user: CurrentUser = Depends(),
db: DBSession = Depends()
):
document_id_list = document_ids.split(",") if document_ids else None
session = chat_service.create_session(
tenant_id=current_user.tenant_id,
user_id=current_user.id,
session_name=session_name,
document_ids=document_id_list
)
db.add(session)
db.commit()
db.refresh(session)
return ResponseModel.success(data=session, message="会话创建成功")
@chat_router.get("/session/list", summary="获取会话列表", description="分页获取当前用户的对话会话列表")
async def get_session_list(
page_query: PageQuery = Depends(),
current_user: CurrentUser = Depends(),
db: DBSession = Depends()
):
result = chat_service.get_session_list(
tenant_id=current_user.tenant_id,
user_id=current_user.id,
page=page_query.page,
page_size=page_query.page_size,
db=db
)
return ResponseModel.success(data=result)
@chat_router.get("/session/{session_id}", summary="获取会话详情", description="获取会话的详细信息和对话历史")
async def get_session_detail(
session_id: str,
current_user: CurrentUser = Depends(),
db: DBSession = Depends()
):
session = chat_service.get_session_by_id(session_id, current_user.tenant_id, current_user.id, db)
history = chat_service.get_chat_history(session_id, current_user.tenant_id, db)
return ResponseModel.success(data={
"session_info": session,
"chat_history": history
})
@chat_router.delete("/session/{session_id}", summary="删除会话", description="删除对话会话,软删除")
async def delete_session(
session_id: str,
current_user: CurrentUser = Depends(),
db: DBSession = Depends()
):
session = chat_service.get_session_by_id(session_id, current_user.tenant_id, current_user.id, db)
session.is_deleted = True
db.commit()
return ResponseModel.success(message="会话删除成功")
@chat_router.post("/sync", summary="同步问答接口", description="非流式问答,等待完整答案生成后返回", response_model=ResponseModel[ChatResponse])
async def chat_sync(
data: ChatRequest,
current_user: CurrentUser = Depends(),
db: DBSession = Depends()
):
start_time = time.time()
# 1. 获取会话信息
session = chat_service.get_session_by_id(data.session_id, current_user.tenant_id, current_user.id, db)
# 2. 获取对话历史
history_messages = chat_service.get_chat_history(data.session_id, current_user.tenant_id, db)
# 3. 格式化对话历史
formatted_history = prompt_manager.format_chat_history([
{"role": msg.role, "content": msg.content} for msg in history_messages
])
# 4. 解析关联的文档ID
document_ids = session.document_ids.split(",") if session.document_ids else None
# 5. 调用RAG问答服务
chat_result = await run_in_threadpool(
chat_service.chat,
user_query=data.query,
tenant_id=str(current_user.tenant_id),
document_ids=document_ids,
chat_history=formatted_history,
enable_citation=True
)
# 6. 计算耗时
latency = int((time.time() - start_time) * 1000)
# 7. 保存用户提问和助手回答到数据库
chat_service.save_chat_message(
session_id=data.session_id,
role="user",
content=data.query,
tenant_id=current_user.tenant_id,
user_id=current_user.id,
token_usage=0,
latency=0,
db=db
)
chat_service.save_chat_message(
session_id=data.session_id,
role="assistant",
content=chat_result["answer"],
tenant_id=current_user.tenant_id,
user_id=current_user.id,
token_usage=chat_result["usage"]["total_tokens"],
latency=latency,
db=db
)
# 8. 更新会话最后更新时间
session.updated_at = func.now()
db.commit()
# 9. 返回结果
return ResponseModel.success(data=chat_result, message="问答成功")
@chat_router.post("/stream", summary="流式问答接口", description="SSE流式输出,逐字返回答案,前端用EventSource对接")
async def chat_stream(
data: ChatRequest,
current_user: CurrentUser = Depends(),
db: DBSession = Depends()
):
start_time = time.time()
# 1. 获取会话信息
session = chat_service.get_session_by_id(data.session_id, current_user.tenant_id, current_user.id, db)
# 2. 获取对话历史
history_messages = chat_service.get_chat_history(data.session_id, current_user.tenant_id, db)
formatted_history = prompt_manager.format_chat_history([
{"role": msg.role, "content": msg.content} for msg in history_messages
])
document_ids = session.document_ids.split(",") if session.document_ids else None
# 3. 保存用户提问
chat_service.save_chat_message(
session_id=data.session_id,
role="user",
content=data.query,
tenant_id=current_user.tenant_id,
user_id=current_user.id,
token_usage=0,
latency=0,
db=db
)
# 4. 异步流式生成器
async def event_generator():
full_answer = ""
token_usage = 0
try:
# 先发送开始事件
yield f"data: {json.dumps({'type': 'start', 'message': '开始生成答案'})}\n\n"
# 获取流式生成器
stream_generator = await run_in_threadpool(
chat_service.stream_chat,
user_query=data.query,
tenant_id=str(current_user.tenant_id),
document_ids=document_ids,
chat_history=formatted_history
)
# 逐字返回
for chunk in stream_generator:
if chunk.startswith("[ERROR]"):
yield f"data: {json.dumps({'type': 'error', 'message': chunk})}\n\n"
break
full_answer += chunk
yield f"data: {json.dumps({'type': 'chunk', 'content': chunk})}\n\n"
# 发送结束事件
latency = int((time.time() - start_time) * 1000)
yield f"data: {json.dumps({'type': 'end', 'full_answer': full_answer, 'latency': latency})}\n\n"
# 保存助手回答到数据库
chat_service.save_chat_message(
session_id=data.session_id,
role="assistant",
content=full_answer,
tenant_id=current_user.tenant_id,
user_id=current_user.id,
token_usage=token_usage,
latency=latency,
db=db
)
# 更新会话时间
session.updated_at = func.now()
db.commit()
except Exception as e:
logger.error(f"流式问答失败:{str(e)}", exc_info=True)
yield f"data: {json.dumps({'type': 'error', 'message': f'生成失败:{str(e)}'})}\n\n"
# 返回SSE流式响应
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # 禁用Nginx缓冲,保证流式输出正常
}
)
对应的Pydantic模型,新建app/schemas/chat.py:
from pydantic import BaseModel, Field
from datetime import datetime
from typing import List, Optional, Dict, Any
class ChatRequest(BaseModel):
session_id: str = Field(..., description="会话ID")
query: str = Field(..., description="用户提问", min_length=1, max_length=2000)
stream: bool = Field(True, description="是否流式输出")
class CitationSource(BaseModel):
citation_number: int
text: str
document_id: str
file_name: str
chunk_id: str
heading: str
class ChatResponse(BaseModel):
answer: str
original_answer: str
has_reference: bool
citation_sources: List[CitationSource]
usage: Dict[str, int]
latency: float
class SessionInfoResponse(BaseModel):
id: int
session_id: str
session_name: str
document_ids: str
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
六、接口测试与文档
FastAPI自带了自动生成的接口文档,项目启动后,直接访问以下地址即可:
- Swagger UI交互式文档:
http://localhost:8000/docs - Redoc文档:
http://localhost:8000/redoc
6.1 项目启动步骤
- 安装依赖:
pip install -r requirements.txt - 复制环境变量文件:
cp .env.example .env,修改对应的配置,尤其是PostgreSQL连接信息、JWT密钥 - 启动PostgreSQL、Redis、Milvus服务(第三篇的Docker Compose已包含)
- 启动项目:
python app/main.py - 访问接口文档:
http://localhost:8000/docs,即可看到所有接口,直接在页面上测试
6.2 核心接口测试流程
- 调用
/api/v1/auth/login接口,用默认的超级管理员账号admin/admin123456登录,获取Token - 点击Swagger UI右上角的
Authorize,输入Bearer 你的Token,完成鉴权 - 调用
/api/v1/document/upload接口,上传一个测试文档,比如PDF、Word - 等待文档后台处理完成,调用
/api/v1/document/list接口,查看文档状态,确认处理完成 - 调用
/api/v1/chat/session/create接口,创建对话会话,关联刚才上传的文档ID - 调用
/api/v1/chat/sync同步接口,或者/api/v1/chat/stream流式接口,输入问题,即可得到RAG问答结果
七、踩坑记录&避坑指南:新手接口开发必踩的8个大坑
坑1:同步代码阻塞async事件循环,服务卡死
踩坑场景:新手直接在async接口里调用文档处理、LLM调用等耗时的同步方法,导致FastAPI的事件循环被阻塞,所有请求都被卡住,服务直接卡死。
避坑方案:用fastapi.concurrency.run_in_threadpool把同步方法放到线程池中运行,避免阻塞事件循环,本篇的流式接口和文档上传接口都已经做了这个处理。
坑2:流式接口跨域/缓冲问题,前端收不到流式数据
踩坑场景:流式接口在本地测试正常,部署到服务器上,前端用Nginx反向代理后,收不到流式数据,必须等完整内容生成后才一次性返回。
避坑方案:流式响应必须添加X-Accel-Buffering: no响应头,禁用Nginx的缓冲,同时设置Cache-Control: no-cache、Connection: keep-alive,本篇的流式接口已经添加了这些头。
坑3:大文件上传超时/内存溢出
踩坑场景:新手上传大文件时,直接把整个文件读入内存,导致内存溢出,或者接口超时。
避坑方案:
- 限制文件大小,本篇的配置里设置了最大10MB,可根据业务调整
- 用异步分块读取文件,不要一次性读入内存
- 大文件用分块上传方案,前端把文件切成小块,后端合并
- 延长接口超时时间,或者用后台异步处理
坑4:全局异常处理不生效
踩坑场景:新手写了全局异常处理器,但异常还是直接抛出,没有被捕获,返回默认的FastAPI错误格式。
避坑方案:
- 异常处理器必须用
app.add_exception_handler注册到FastAPI实例上 - 自定义异常必须继承自
Exception,不能用HTTPException - 依赖项中抛出的异常,必须在全局异常处理器中注册对应的处理函数
坑5:租户隔离没做好,越权访问
踩坑场景:新手只在接口里校验了用户登录,没有校验租户权限,A租户的用户可以通过修改文档ID,访问、删除B租户的文档,出现严重的数据泄露。
避坑方案:
- 所有资源操作,必须校验租户ID,普通用户只能操作自己租户的资源
- 数据库查询时,必须带上
tenant_id过滤条件,不能只靠ID查询 - 本篇的
check_tenant_permission依赖项已经实现了权限校验,所有接口必须调用
坑6:接口没有参数校验,非法参数导致服务崩溃
踩坑场景:新手接口的入参没有做校验,用户传入非法的参数,比如超长的字符串、负数的页码,导致数据库查询报错、服务崩溃。
避坑方案:所有入参必须用Pydantic模型做严格校验,包括类型、长度、范围、格式,非法参数在接口层直接拦截,不会传入业务层。
坑7:敏感信息泄露,日志/返回里明文打印密钥、Token、密码
踩坑场景:新手在日志里打印完整的Token、用户密码、API密钥,或者在接口返回里把用户的密码哈希返回给前端,出现严重的敏感信息泄露。
避坑方案:
- 敏感信息必须脱敏,日志里只打印掩码后的内容
- 用户密码绝对不能返回给前端,哪怕是哈希后的
- API密钥、JWT密钥绝对不能打印到日志里,也不能硬编码在代码里,必须放在环境变量中
坑8:接口没有限流,被恶意刷接口导致服务崩溃
踩坑场景:接口没有限流,被恶意用户用脚本疯狂刷接口,导致LLM账单爆炸、服务资源耗尽、直接崩溃。
避坑方案:用SlowAPI实现接口限流,本篇已经集成了限流器,核心接口必须添加限流规则,比如登录接口限制10/minute,问答接口限制60/minute。
八、下一篇预告
本篇我们完成了生产级RAG系统的标准化后端接口开发,现在我们的系统已经可以对接前端页面、被第三方系统集成,具备了上线交付的基础能力。
下一篇预告:第7篇《用户权限与多租户系统设计,商用必备》,我会手把手带你完成:
- 商用级多租户系统的核心设计,租户隔离的3种实现方案
- RBAC角色权限体系设计,超级管理员、租户管理员、普通用户的权限管控
- 细粒度的资源权限控制,文档、会话的共享与权限分配
- 完整的用户、租户、角色管理接口,与本篇的接口无缝衔接
- 商用场景的租户计费、配额管理方案
结尾互动
本篇是《30天做一个生产级RAG知识库系统》全系列的第六篇,我们完成了商用级标准化接口的全量开发,现在你的RAG系统已经从一个本地demo,变成了一个可交付、可集成、可上线的商用服务。
最后想问一下大家:
- 你在开发后端接口的时候,遇到过最头疼的问题是什么?是流式接口对接、权限管控,还是服务性能问题?
- 你在对接前端的时候,遇到过哪些跨域、数据格式的坑?
欢迎在评论区留言,我会在后续的文章中,针对性地给你解决方案!
如果觉得这个系列对你有帮助,欢迎点赞、收藏、关注,后续的文章会第一时间推送给你,跟着更完,就能上线属于你的商用级RAG系统!
更多推荐
所有评论(0)