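In the previous three tutorials, we systematically covered FastAPI's fundamentals, its core concepts, and advanced features such as WebSocket, database integration, and security and authentication. Now let's dig deeper into more advanced FastAPI features that help you build production-grade applications that are more robust, more efficient, and easier to maintain. This tutorial covers performance optimization strategies, advanced security practices, asynchronous task processing, test-driven development, GraphQL integration, and the use of FastAPI in a microservice architecture.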

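1. Performance Optimization Strategies

FastAPI already performs well out of the box, but in production, sensible tuning can further improve throughput and response time.

1.1 Asynchronous Database Operations

Using an asynchronous database driver is key to high performance. FastAPI's async model only pays off fully when it runs through the entire call chain: a single blocking call inside an async endpoint stalls the event loop for every request.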

from fastapi import FastAPI
import asyncpg

app = FastAPI()

@app.on_event("startup")
async def startup():
    # Create the database connection pool
    app.state.db = await asyncpg.create_pool(
        dsn="postgresql://user:pass@localhost/db",
        min_size=10,
        max_size=20
    )

@app.on_event("shutdown")
async def shutdown():
    await app.state.db.close()

@app.get("/async-query")
async def async_query():
    async with app.state.db.acquire() as conn:
        result = await conn.fetch("SELECT * FROM users LIMIT 10")
    return [dict(row) for row in result]

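Performance gain: async database queries can cut waiting time by as much as 70%, and under high concurrency async mode can be 3-5x faster than sync mode. The actual gain depends on how I/O-bound the workload is, so benchmark your own application.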
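
To see why overlapping waits matter, here is a minimal sketch that uses only the standard library. `fake_query` is a hypothetical stand-in for a database call, with `asyncio.sleep` simulating network latency; it illustrates the principle and is not a benchmark of asyncpg.

```python
import asyncio
import time


async def fake_query(delay: float) -> str:
    """Hypothetical stand-in for an awaited database call."""
    await asyncio.sleep(delay)  # simulates waiting on the network
    return "row"


async def run_sequential(n: int, delay: float) -> float:
    """Await each query in turn; total time is roughly n * delay."""
    start = time.perf_counter()
    for _ in range(n):
        await fake_query(delay)
    return time.perf_counter() - start


async def run_concurrent(n: int, delay: float) -> float:
    """Start all queries together; total time is roughly one delay."""
    start = time.perf_counter()
    await asyncio.gather(*(fake_query(delay) for _ in range(n)))
    return time.perf_counter() - start


sequential_time = asyncio.run(run_sequential(5, 0.05))
concurrent_time = asyncio.run(run_concurrent(5, 0.05))
print(f"sequential: {sequential_time:.2f}s, concurrent: {concurrent_time:.2f}s")
```

The concurrent run finishes sooner because the five waits overlap instead of queuing one after another, which is the same effect an async driver gives a busy API.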

1.2 Caching Strategies

Well-placed caching can significantly reduce database load and response time.

In-memory caching with cachetools
from cachetools import TTLCache
from fastapi import FastAPI

app = FastAPI()

cache = TTLCache(maxsize=100, ttl=300)  # At most 100 entries, each valid for 5 minutes

@app.get("/cached-items/{item_id}")
async def get_cached_item(item_id: int):
    cache_key = f"item_{item_id}"

    # A single get() avoids a KeyError if the entry expires between check and read
    cached = cache.get(cache_key)
    if cached is not None:
        return {"data": cached, "source": "cache"}

    # Simulate fetching the item from the database
    item = {"id": item_id, "name": f"Item {item_id}", "price": 100.0}
    cache[cache_key] = item

    return {"data": item, "source": "database"}
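
To make the expiry behaviour concrete, the following is a toy TTL cache written with the standard library only. It is a sketch of the idea, not how cachetools is implemented: `SimpleTTLCache` and its injectable `clock` parameter are names invented for this illustration, and unlike `TTLCache` it does not bound the number of entries.

```python
import time
from typing import Any, Callable, Dict, Optional, Tuple


class SimpleTTLCache:
    """Toy TTL cache; entries expire ttl seconds after being stored."""

    def __init__(self, ttl: float, clock: Callable[[], float] = time.monotonic):
        self.ttl = ttl
        self.clock = clock  # injectable so expiry can be shown without sleeping
        self._store: Dict[str, Tuple[float, Any]] = {}

    def set(self, key: str, value: Any) -> None:
        # Remember the value together with the moment it stops being valid
        self._store[key] = (self.clock() + self.ttl, value)

    def get(self, key: str) -> Optional[Any]:
        entry = self._store.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if self.clock() >= expires_at:
            del self._store[key]  # lazily evict the expired entry
            return None
        return value


fake_now = [0.0]  # mutable fake clock for the demonstration
demo_cache = SimpleTTLCache(ttl=300, clock=lambda: fake_now[0])
demo_cache.set("item_1", {"id": 1})
hit = demo_cache.get("item_1")   # still fresh
fake_now[0] = 301.0              # advance past the 5-minute TTL
miss = demo_cache.get("item_1")  # expired, so None
```

Because this cache lives in one process, every worker process keeps its own copy, which is the main reason to move to a shared store once you run more than one worker.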
Integrating Redis as a distributed cache

For distributed deployments, Redis is the better choice:

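# The standalone aioredis package is no longer maintained; the same API now
# ships with redis-py (pip install redis) as redis.asyncio.
from redis import asyncio as aioredis
from fastapi import FastAPI

app = FastAPI()

@app.on_event("startup")
async def startup():
    # from_url() builds the client; the connection is opened on first use
    app.state.redis = aioredis.from_url("redis://localhost")

@app.get("/redis-cached/{key}")
async def get_redis_cached(key: str):
    # Try the cache first
    cached = await app.state.redis.get(key)
    if cached is not None:
        return {"data": cached.decode(), "source": "redis"}

    # Simulate a computation or database query
    value = f"Computed value for {key}"
    await app.state.redis.setex(key, 300, value)  # Expires after 5 minutes

    return {"data": value, "source": "computed"}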

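1.3 Gzip Compression

Enabling Gzip compression reduces the amount of data sent over the network, which is especially useful for APIs that return large text responses.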

from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware

app = FastAPI()

# Add the Gzip middleware; only responses of at least 1000 bytes (about 1 KB) are compressed
app.add_middleware(GZipMiddleware, minimum_size=1000)

@app.get("/large-response")
async def get_large_response():
    # Returns a large payload, which is compressed automatically
    return {"data": "x" * 10000}
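
The effect of the size threshold can be reproduced with the standard library. This sketch only approximates the size rule: `maybe_compress` is a hypothetical helper, and the real middleware also checks that the client sent `Accept-Encoding: gzip` before compressing.

```python
import gzip
import json

MINIMUM_SIZE = 1000  # same threshold as the middleware example


def maybe_compress(body: bytes, minimum_size: int = MINIMUM_SIZE) -> bytes:
    """Compress only bodies at or above the threshold; return small ones unchanged."""
    if len(body) < minimum_size:
        return body
    return gzip.compress(body)


large_body = json.dumps({"data": "x" * 10000}).encode()
small_body = json.dumps({"data": "ok"}).encode()

compressed_large = maybe_compress(large_body)
untouched_small = maybe_compress(small_body)
print(len(large_body), "->", len(compressed_large), "bytes")
```

Highly repetitive text like this shrinks dramatically, while tiny bodies are left alone because the Gzip header overhead would outweigh any saving.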

1.4 Rate Limiting

To protect the API from excessive calls or abuse, use slowapi to enforce rate limits.

from fastapi import FastAPI, Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

@app.get("/public")
@limiter.limit("100/minute")  # At most 100 requests per minute
async def public_endpoint(request: Request):
    return {"message": "This is a public endpoint with rate limiting"}

@app.get("/premium")
@limiter.limit("1000/hour")  # At most 1000 requests per hour
async def premium_endpoint(request: Request):
    return {"message": "Premium users have higher limits"}
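
For intuition about what a limit such as "100/minute" means, here is a minimal fixed-window counter. It is a conceptual sketch, not slowapi's implementation: `FixedWindowLimiter` and its `clock` parameter are invented for this example, the IP address is a documentation placeholder, and old windows are never cleaned up.

```python
import time
from collections import defaultdict
from typing import Callable, Dict, Tuple


class FixedWindowLimiter:
    """Allow at most `limit` requests per client in each fixed time window."""

    def __init__(self, limit: int, window_seconds: int,
                 clock: Callable[[], float] = time.time):
        self.limit = limit
        self.window_seconds = window_seconds
        self.clock = clock
        # (client key, window index) -> number of requests seen in that window
        self._counts: Dict[Tuple[str, int], int] = defaultdict(int)

    def allow(self, key: str) -> bool:
        window = int(self.clock() // self.window_seconds)
        bucket = (key, window)
        if self._counts[bucket] >= self.limit:
            return False  # limit reached for this window
        self._counts[bucket] += 1
        return True


fake_now = [0.0]  # mutable fake clock so no real waiting is needed
demo_limiter = FixedWindowLimiter(limit=3, window_seconds=60, clock=lambda: fake_now[0])
first_window = [demo_limiter.allow("203.0.113.7") for _ in range(4)]
fake_now[0] = 61.0  # move into the next one-minute window
next_window = demo_limiter.allow("203.0.113.7")
```

The fourth call in the first window is rejected, and the counter starts again once the window rolls over. With several worker processes the counts must live in shared storage such as Redis, otherwise each process enforces its own separate limit.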

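1.5 Connection Pool Tuning

Sizing the database connection pool correctly is critical for performance. A pool that is too small makes requests wait for a connection; one that is too large wastes resources.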

from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker

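# Tune the pool parameters to match your server's capacity
engine = create_async_engine(
    "postgresql+asyncpg://user:pass@localhost/db",
    pool_size=20,           # Connections kept open in the pool
    max_overflow=10,        # Extra connections allowed beyond pool_size
    pool_pre_ping=True,     # Test each connection on checkout to avoid stale ones
    pool_recycle=3600,      # Recycle connections older than this many seconds
)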

AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession)
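
One sizing check worth doing by hand: each worker process creates its own engine and therefore its own pool, so the database sees the per-process ceiling multiplied by the number of workers. The helper below is a hypothetical back-of-the-envelope calculation, and the worker count of 4 is an assumption for the example.

```python
def max_db_connections(workers: int, pool_size: int, max_overflow: int) -> int:
    """Upper bound on connections opened by all worker processes combined."""
    return workers * (pool_size + max_overflow)


# Settings from the engine above, with an assumed 4 worker processes
total = max_db_connections(workers=4, pool_size=20, max_overflow=10)
print(total)
```

With these numbers the application could open up to 120 connections, which already exceeds PostgreSQL's default `max_connections` of 100, so either the pool or the server limit would need adjusting.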

1.6 Profiling and Monitoring

Use profiling and timing tools to locate bottlenecks.

import time
from fastapi import FastAPI, Request

app = FastAPI()

@app.middleware("http")
async def monitor_performance(request: Request, call_next):
    start_time = time.perf_counter()  # Monotonic clock, suited to measuring durations

    response = await call_next(request)

    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = str(process_time)

    # Log slow requests
    if process_time > 1.0:
        print(f"Slow request: {request.url.path} took {process_time:.2f}s")
    
    return response

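2. Advanced Security Practices

2.1 Multi-Factor Authentication

On top of basic JWT authentication, you can add multi-factor authentication (MFA) for stronger security.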

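import pyotp
import qrcode
from io import BytesIO
from fastapi import Depends, Form, HTTPException
from fastapi.responses import Response

# `app` and `get_current_user` are assumed to be defined as in the earlier tutorials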

class MFAService:
    def __init__(self):
        self.user_secrets = {}  # In a real application, store these in a database
    
    def generate_secret(self, username: str):
        secret = pyotp.random_base32()
        self.user_secrets[username] = secret
        return secret
    
    def get_totp_uri(self, username: str, secret: str):
        return pyotp.totp.TOTP(secret).provisioning_uri(
            name=username,
            issuer_name="MyApp"
        )
    
    def verify_totp(self, username: str, token: str):
        if username not in self.user_secrets:
            return False
        totp = pyotp.TOTP(self.user_secrets[username])
        return totp.verify(token)

mfa_service = MFAService()

@app.post("/mfa/setup")
async def setup_mfa(current_user: dict = Depends(get_current_user)):
    secret = mfa_service.generate_secret(current_user["username"])
    uri = mfa_service.get_totp_uri(current_user["username"], secret)
    
    # Generate the QR code
    qr = qrcode.make(uri)
    img_bytes = BytesIO()
    qr.save(img_bytes)
    img_bytes.seek(0)
    
    return Response(content=img_bytes.getvalue(), media_type="image/png")

@app.post("/mfa/verify")
async def verify_mfa(
    token: str = Form(...),
    current_user: dict = Depends(get_current_user)
):
    if mfa_service.verify_totp(current_user["username"], token):
        return {"message": "MFA verified successfully"}
    raise HTTPException(status_code=400, detail="Invalid token")
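
To show what `totp.verify` is checking, here is a standard-library sketch of the TOTP calculation from RFC 6238 (HMAC-SHA1, 30-second steps, 6 digits). The function name `totp` is chosen for this example; in practice pyotp does this for you and can also tolerate clock drift.

```python
import base64
import hashlib
import hmac
import struct


def totp(secret_b32: str, at: float, step: int = 30, digits: int = 6) -> str:
    """Compute a TOTP code (RFC 6238, HMAC-SHA1) for the Unix time `at`."""
    key = base64.b32decode(secret_b32, casefold=True)
    counter = int(at // step)  # number of elapsed time steps
    digest = hmac.new(key, struct.pack(">Q", counter), hashlib.sha1).digest()
    offset = digest[-1] & 0x0F  # dynamic truncation, as defined in RFC 4226
    code = struct.unpack(">I", digest[offset:offset + 4])[0] & 0x7FFFFFFF
    return str(code % 10 ** digits).zfill(digits)


# Shared secret from the RFC test vectors, base32-encoded as authenticator apps expect
demo_secret = base64.b32encode(b"12345678901234567890").decode()
code_at_59 = totp(demo_secret, at=59)
print(code_at_59)  # 287082, matching the RFC test vector for this time step
```

The server and the authenticator app derive the same code independently from the shared secret and the current time, so nothing secret travels over the network during verification.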

2.2 API Key Management

Provide API key authentication for third-party developers.

from fastapi import Depends, HTTPException
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
import secrets
from datetime import datetime, timedelta, timezone

api_key_header = APIKeyHeader(name="X-API-Key")

class APIKey(BaseModel):
    key: str
    name: str
    created_at: datetime
    expires_at: datetime
    permissions: list[str]

class APIKeyManager:
    def __init__(self):
        self.keys = {}  # In a real application, store these in a database

    def create_key(self, name: str, permissions: list[str], days_valid: int = 30):
        key = secrets.token_urlsafe(32)
        # Timezone-aware timestamp; datetime.utcnow() is deprecated since Python 3.12
        now = datetime.now(timezone.utc)
        api_key = APIKey(
            key=key,
            name=name,
            created_at=now,
            expires_at=now + timedelta(days=days_valid),
            permissions=permissions
        )
        self.keys[key] = api_key
        return api_key

    def verify_key(self, key: str):
        api_key = self.keys.get(key)
        if api_key is None:
            return None  # Unknown key
        if api_key.expires_at < datetime.now(timezone.utc):
            return None  # Expired key
        return api_key

key_manager = APIKeyManager()

async def verify_api_key(api_key: str = Depends(api_key_header)):
    key_obj = key_manager.verify_key(api_key)
    if not key_obj:
        raise HTTPException(
            status_code=403,
            detail="Invalid or expired API key"
        )
    return key_obj

@app.post("/api-keys/create")
async def create_api_key(
    name: str,
    permissions: list[str],
    admin: dict = Depends(require_permission("admin"))
):
    api_key = key_manager.create_key(name, permissions)
    return api_key

@app.get("/protected-resource")
async def protected_resource(api_key: APIKey = Depends(verify_api_key)):
    return {"message": f"Access granted for {api_key.name}"}
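
The manager above keeps raw keys in memory. When keys move to a database, one common hardening step is to store only a digest of each key, so that a leaked table does not reveal usable credentials. The sketch below illustrates that idea with the standard library; `HashedKeyStore` and `hash_key` are hypothetical names, not part of FastAPI.

```python
import hashlib
import secrets
from typing import Dict, Optional


def hash_key(raw_key: str) -> str:
    """SHA-256 digest of the key; a fast hash is acceptable for high-entropy random keys."""
    return hashlib.sha256(raw_key.encode()).hexdigest()


class HashedKeyStore:
    """Stores only digests; the raw key is returned once at creation time."""

    def __init__(self) -> None:
        self._names_by_digest: Dict[str, str] = {}

    def issue(self, name: str) -> str:
        raw_key = secrets.token_urlsafe(32)
        self._names_by_digest[hash_key(raw_key)] = name
        return raw_key  # shown to the caller once, never stored

    def lookup(self, presented_key: str) -> Optional[str]:
        # Hash the presented key and look the digest up
        return self._names_by_digest.get(hash_key(presented_key))


demo_store = HashedKeyStore()
demo_key = demo_store.issue("partner-app")
found = demo_store.lookup(demo_key)
not_found = demo_store.lookup("not-a-real-key")
```

The trade-off is that a lost key cannot be recovered, only replaced, which is usually the behaviour you want for credentials.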

2.3 Stronger Password Hashing

Use stronger password hashing algorithms and policies.

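from passlib.context import CryptContext
# The bcrypt and argon2-cffi packages must be installed, but passlib imports them itself

# Configure several hashing schemes so that stored hashes can be upgraded
pwd_context = CryptContext(
    schemes=["bcrypt", "argon2"],
    default="bcrypt",
    bcrypt__rounds=12,  # Higher work factor
    argon2__time_cost=2,
    argon2__memory_cost=102400,
    argon2__parallelism=8,
)

def verify_and_upgrade_password(plain_password: str, hashed_password: str):
    """Verify the password and return a new hash if the stored one is outdated.

    new_hash is None when the stored hash does not need to be replaced.
    """
    is_valid, new_hash = pwd_context.verify_and_update(plain_password, hashed_password)
    return is_valid, new_hash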
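
The verify-then-upgrade pattern can be shown without passlib. The following sketch uses PBKDF2 from the standard library purely for illustration; the storage format, the function names, and the iteration counts are assumptions of this example, and a real application should keep using a maintained library with bcrypt or Argon2.

```python
import hashlib
import hmac
import os
from typing import Optional, Tuple

CURRENT_ITERATIONS = 100_000  # illustrative work factor for this sketch only


def hash_password(password: str, iterations: int = CURRENT_ITERATIONS) -> str:
    """Return 'iterations$salt$digest' so the parameters travel with the hash."""
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, iterations)
    return f"{iterations}${salt.hex()}${digest.hex()}"


def verify_and_update(password: str, stored: str) -> Tuple[bool, Optional[str]]:
    """Verify a password; also return a fresh hash if the stored one is too weak."""
    iterations_text, salt_hex, digest_hex = stored.split("$")
    iterations = int(iterations_text)
    candidate = hashlib.pbkdf2_hmac(
        "sha256", password.encode(), bytes.fromhex(salt_hex), iterations
    )
    if not hmac.compare_digest(candidate.hex(), digest_hex):
        return False, None
    if iterations < CURRENT_ITERATIONS:
        return True, hash_password(password)  # upgrade on successful login
    return True, None


legacy_hash = hash_password("s3cret", iterations=10_000)  # hash made under an old policy
ok, upgraded_hash = verify_and_update("s3cret", legacy_hash)
```

The upgrade can only happen at login, because that is the one moment the server sees the plain-text password. The caller is responsible for saving `upgraded_hash` back to the user record.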

2.4 CSRF Protection

Traditional, non-API web applications that render HTML forms need CSRF protection.

from fastapi import FastAPI, Form, HTTPException, Request
from fastapi.templating import Jinja2Templates
from itsdangerous import BadSignature, URLSafeTimedSerializer
import secrets

app = FastAPI()
templates = Jinja2Templates(directory="templates")

# CSRF protection settings
SECRET_KEY = "your-secret-key"  # Placeholder; load a real secret from configuration
serializer = URLSafeTimedSerializer(SECRET_KEY)

def generate_csrf_token():
    """Generate a random CSRF token."""
    return secrets.token_urlsafe(32)

def validate_csrf_token(token: str):
    """Check the token's signature and age."""
    try:
        serializer.loads(token, max_age=3600)  # Expires after 1 hour
        return True
    except BadSignature:  # Also covers SignatureExpired
        return False

@app.get("/form")
async def show_form(request: Request):
    # Generate a CSRF token, sign it, and embed it in the rendered form
    csrf_token = generate_csrf_token()
    signed_token = serializer.dumps(csrf_token)
    
    return templates.TemplateResponse(
        "form.html",
        {"request": request, "csrf_token": signed_token}
    )

@app.post("/submit")
async def handle_form(
    request: Request,
    data: str = Form(...),
    csrf_token: str = Form(...)
):
    if not validate_csrf_token(csrf_token):
        raise HTTPException(status_code=400, detail="Invalid CSRF token")
    
    return {"message": "Form submitted successfully", "data": data}
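
A signed token proves that the server issued it, but not that it was issued to this particular user. One way to tighten that is to bind the token to the user's session. The sketch below shows the idea with the standard library only; the function names, the session identifiers, and passing `now` explicitly are assumptions made so the example is self-contained.

```python
import hashlib
import hmac

CSRF_SECRET = b"replace-with-a-real-secret"  # placeholder value for the sketch


def make_csrf_token(session_id: str, now: float) -> str:
    """Return 'issued_at.signature', where the signature covers session and time."""
    issued_at = str(int(now))
    message = f"{session_id}:{issued_at}".encode()
    signature = hmac.new(CSRF_SECRET, message, hashlib.sha256).hexdigest()
    return f"{issued_at}.{signature}"


def check_csrf_token(token: str, session_id: str, now: float, max_age: int = 3600) -> bool:
    """Accept the token only for the same session and within max_age seconds."""
    try:
        issued_at, signature = token.split(".")
        age = now - int(issued_at)
    except ValueError:
        return False  # malformed token
    message = f"{session_id}:{issued_at}".encode()
    expected = hmac.new(CSRF_SECRET, message, hashlib.sha256).hexdigest()
    same = hmac.compare_digest(signature.encode(), expected.encode())
    return same and 0 <= age <= max_age


token = make_csrf_token("session-abc", now=1_000)
valid = check_csrf_token(token, "session-abc", now=1_500)
wrong_session = check_csrf_token(token, "session-xyz", now=1_500)
expired = check_csrf_token(token, "session-abc", now=1_000 + 3_601)
```

A token copied from one user's page is rejected when submitted under another session, which is the property a CSRF defence needs.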

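3. Asynchronous Task Processing

3.1 Advanced Use of BackgroundTasks

Beyond simple fire-and-forget jobs, background tasks can be combined with database operations and error handling.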

import asyncio
import uuid

from fastapi import BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession

class TaskManager:
    def __init__(self):
        self.task_status = {}

    async def long_running_task(self, task_id: str, params: dict):
        """Simulate a long-running task."""
        try:
            self.task_status[task_id] = "running"

            # Simulate multi-step processing
            for step in range(5):
                await asyncio.sleep(2)  # Simulate a slow operation
                self.task_status[task_id] = f"step_{step}"

            self.task_status[task_id] = "completed"
        except Exception as e:
            self.task_status[task_id] = f"failed: {str(e)}"

task_manager = TaskManager()

@app.post("/start-task")
async def start_background_task(background_tasks: BackgroundTasks):
    task_id = str(uuid.uuid4())
    background_tasks.add_task(
        task_manager.long_running_task,
        task_id,
        {"param1": "value1"}
    )
    return {"task_id": task_id, "status": "started"}

@app.get("/task-status/{task_id}")
async def get_task_status(task_id: str):
    status = task_manager.task_status.get(task_id, "not_found")
    return {"task_id": task_id, "status": status}

3.2 Celery Integration

For more complex, distributed task processing, integrate Celery.

# celery_app.py
from celery import Celery

celery_app = Celery(
    "tasks",
    broker="redis://localhost:6379/0",
    backend="redis://localhost:6379/0"
)

celery_app.conf.update(
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    timezone="UTC",
    enable_utc=True,
    task_track_started=True,
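    task_time_limit=30 * 60,  # Hard limit: 30 minutes
    task_soft_time_limit=25 * 60,  # Soft limit: 25 minutes
)

@celery_app.task(bind=True, max_retries=3)
def process_data(self, data_id: int):
    """Celery task that processes a piece of data."""
    try:
        # Run the time-consuming operation
        # (perform_data_processing is a placeholder for your own function)
        result = perform_data_processing(data_id)
        return {"status": "success", "result": result}
    except Exception as e:
        # Retry in 60 seconds; self.retry() raises, so re-raise it explicitly
        raise self.retry(exc=e, countdown=60)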

Calling the Celery task from FastAPI:

# main.py
from celery.result import AsyncResult
from .celery_app import celery_app, process_data

@app.post("/process/{data_id}")
async def start_processing(data_id: int):
    # Queue the Celery task without waiting for its result
    task = process_data.delay(data_id)
    return {"task_id": task.id, "status": "processing"}

@app.get("/task/{task_id}")
async def get_task_result(task_id: str):
    task = AsyncResult(task_id, app=celery_app)
    
    if task.state == "PENDING":
        response = {"state": task.state, "status": "Task pending..."}
    elif task.state == "FAILURE":
        response = {"state": task.state, "status": str(task.info)}
    else:
        response = {"state": task.state, "result": task.result}
    
    return response

3.3 Scheduled Tasks

Use APScheduler or Celery Beat to run tasks on a schedule.

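from contextlib import asynccontextmanager
from datetime import datetime, timezone

from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.interval import IntervalTrigger
from fastapi import FastAPI

scheduler = AsyncIOScheduler()

async def periodic_cleanup():
    """Periodically remove expired data."""
    print(f"Running cleanup at {datetime.now(timezone.utc)}")
    # Run the cleanup logic (cleanup_expired_records is a placeholder for your own coroutine)
    await cleanup_expired_records()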

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Start the scheduler on application startup
    scheduler.add_job(
        periodic_cleanup,
        trigger=IntervalTrigger(hours=24),
        id="daily_cleanup",
        replace_existing=True
    )
    scheduler.start()

    yield

    # Stop the scheduler on application shutdown
    scheduler.shutdown()

app = FastAPI(lifespan=lifespan)

4. Test-Driven Development

4.1 Unit Testing Best Practices

Use pytest to write thorough tests.

# test_main.py
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from .main import app
from .database import Base, get_db

# Use an in-memory database for tests
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

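def override_get_db():
    """Replacement for the get_db dependency that uses the test database."""
    # Create the session before the try block so `db` always exists in finally
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()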

app.dependency_overrides[get_db] = override_get_db

@pytest.fixture
def client():
    """Create a test client."""
    Base.metadata.create_all(bind=engine)
    with TestClient(app) as test_client:
        yield test_client
    Base.metadata.drop_all(bind=engine)

def test_create_item(client):
    """Test creating an item."""
    response = client.post(
        "/items/",
        json={"name": "Test Item", "price": 100.0}
    )
    assert response.status_code == 200
    data = response.json()
    assert data["name"] == "Test Item"
    assert data["price"] == 100.0
    assert "id" in data

def test_get_nonexistent_item(client):
    """Test fetching an item that does not exist."""
    response = client.get("/items/999")
    assert response.status_code == 404
    assert response.json()["detail"] == "Item not found"

@pytest.mark.parametrize("invalid_data", [
    {"name": "Test", "price": -10},  # Price must not be negative
    {"name": "", "price": 100},      # Name must not be empty
    {"price": 100},                   # Name is missing
])
def test_invalid_item_creation(client, invalid_data):
    """Test creating items with invalid data."""
    response = client.post("/items/", json=invalid_data)
    assert response.status_code == 422  # Validation error

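4.2 Async Tests

Testing async endpoints from async test functions needs an async HTTP client and the pytest-asyncio plugin.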

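import pytest
from httpx import ASGITransport, AsyncClient

from .main import app

@pytest.mark.asyncio
async def test_async_endpoint():
    # Recent httpx versions no longer accept AsyncClient(app=...); route requests
    # to the ASGI app through an explicit transport instead
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        response = await ac.get("/async-data")
    assert response.status_code == 200
    assert response.json() == {"data": "async response"}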

4.3 Test Coverage

Use pytest-cov to measure test coverage.

pytest --cov=app --cov-report=html tests/

5. GraphQL Integration

FastAPI integrates with GraphQL libraries such as Strawberry or Graphene to offer more flexible querying.

5.1 Implementing GraphQL with Strawberry

import strawberry
from fastapi import FastAPI
from strawberry.fastapi import GraphQLRouter
from typing import List, Optional

# Define the GraphQL type
@strawberry.type
class Book:
    id: int
    title: str
    author: str
    year: int

# Sample data
books_db = [
    Book(id=1, title="The Hobbit", author="J.R.R. Tolkien", year=1937),
    Book(id=2, title="1984", author="George Orwell", year=1949),
    Book(id=3, title="Dune", author="Frank Herbert", year=1965),
]

# Define the Query
@strawberry.type
class Query:
    @strawberry.field
    def books(self, author: Optional[str] = None) -> List[Book]:
        if author:
            return [b for b in books_db if b.author == author]
        return books_db
    
    @strawberry.field
    def book(self, id: int) -> Optional[Book]:
        return next((b for b in books_db if b.id == id), None)

# Define the Mutation
@strawberry.type
class Mutation:
    @strawberry.mutation
    def add_book(self, title: str, author: str, year: int) -> Book:
        # default=0 keeps this working even when books_db is empty
        new_id = max((b.id for b in books_db), default=0) + 1
        book = Book(id=new_id, title=title, author=author, year=year)
        books_db.append(book)
        return book

schema = strawberry.Schema(query=Query, mutation=Mutation)
graphql_app = GraphQLRouter(schema)

app = FastAPI()
app.include_router(graphql_app, prefix="/graphql")

The /graphql endpoint can now execute GraphQL queries:

# Query example
{
  books(author: "J.R.R. Tolkien") {
    title
    year
  }
}

# Mutation example
mutation {
  addBook(title: "Foundation", author: "Isaac Asimov", year: 1951) {
    id
    title
  }
}
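
Outside an interactive GraphQL page, clients normally send these operations as a JSON body in an HTTP POST, with the operation text under `query` and any parameters under `variables`. The sketch below only builds such a body with the standard library; `build_graphql_request` and the operation name `BooksByAuthor` are hypothetical names for this example.

```python
import json

# Same query as above, with the author passed as a variable instead of inlined
GRAPHQL_QUERY = """
query BooksByAuthor($author: String) {
  books(author: $author) {
    title
    year
  }
}
"""


def build_graphql_request(query: str, variables: dict) -> bytes:
    """Serialize a GraphQL operation into the JSON body for an HTTP POST."""
    return json.dumps({"query": query, "variables": variables}).encode("utf-8")


request_body = build_graphql_request(GRAPHQL_QUERY, {"author": "J.R.R. Tolkien"})
print(request_body.decode("utf-8"))
```

Passing values through `variables` rather than string formatting avoids quoting bugs and keeps the query text constant across requests.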

6. FastAPI in a Microservice Architecture

FastAPI is well suited to building microservices.

6.1 Service Discovery and Registration

Use Consul or etcd for service discovery.

import consul
import socket
import uuid

class ServiceRegistry:
    def __init__(self, consul_host="localhost", consul_port=8500):
        self.consul = consul.Consul(host=consul_host, port=consul_port)
        self.service_id = str(uuid.uuid4())
    
    def register(self, service_name: str, port: int):
        """Register the service with Consul."""
        hostname = socket.gethostname()
        ip = socket.gethostbyname(hostname)
        
        self.consul.agent.service.register(
            name=service_name,
            service_id=self.service_id,
            address=ip,
            port=port,
            check=consul.Check().tcp(ip, port, "10s"),
            tags=["fastapi", "v1"]
        )
    
    def deregister(self):
        """Deregister the service."""
        self.consul.agent.service.deregister(self.service_id)
    
    def discover(self, service_name: str):
        """Discover healthy instances of a service."""
        _, services = self.consul.health.service(service_name, passing=True)
        return [
            {
                "id": s["Service"]["ID"],
                "address": s["Service"]["Address"],
                "port": s["Service"]["Port"]
            }
            for s in services
        ]

# Register during the application lifecycle
registry = ServiceRegistry()

@app.on_event("startup")
async def startup():
    registry.register("user-service", 8000)

@app.on_event("shutdown")
async def shutdown():
    registry.deregister()
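
Once `discover` returns several healthy instances, the caller still has to choose one for each request. A simple option is client-side round-robin, sketched here with the standard library. The `round_robin` helper and the instance addresses are made up for the example; the dictionaries follow the shape returned by `discover` above.

```python
import itertools
from typing import Dict, Iterator, List


def round_robin(instances: List[Dict[str, object]]) -> Iterator[str]:
    """Return an endless iterator over the instances' base URLs, in rotation."""
    if not instances:
        raise ValueError("no healthy instances available")
    urls = [f"http://{inst['address']}:{inst['port']}" for inst in instances]
    return itertools.cycle(urls)


# Hypothetical result of registry.discover("user-service")
discovered = [
    {"id": "a1", "address": "10.0.0.5", "port": 8000},
    {"id": "b2", "address": "10.0.0.6", "port": 8000},
]
rotation = round_robin(discovered)
picked = [next(rotation) for _ in range(3)]
print(picked)
```

The rotation should be rebuilt whenever the instance list is refreshed, since instances that fail their health check drop out of the discovery result.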

6.2 Inter-Service Communication

Use httpx for asynchronous communication between services.

import asyncio

import httpx
from fastapi import FastAPI

app = FastAPI()

class ServiceClient:
    def __init__(self, service_name: str):
        self.service_name = service_name
        self.client = httpx.AsyncClient(timeout=10.0)

    async def get_user(self, user_id: int):
        # In a real application, resolve the address through service discovery
        response = await self.client.get(
            f"http://user-service:8001/users/{user_id}"
        )
        response.raise_for_status()
        return response.json()

    async def get_user_orders(self, user_id: int):
        # Assumed endpoint: the order service's host, port, and path are illustrative
        response = await self.client.get(
            f"http://order-service:8002/users/{user_id}/orders"
        )
        response.raise_for_status()
        return response.json()

    async def close(self):
        await self.client.aclose()

user_client = ServiceClient("user-service")
order_client = ServiceClient("order-service")

@app.get("/api/users/{user_id}/orders")
async def get_user_orders(user_id: int):
    # Call both microservices concurrently
    user_task = user_client.get_user(user_id)
    orders_task = order_client.get_user_orders(user_id)

    user, orders = await asyncio.gather(user_task, orders_task)
    
    return {
        "user": user,
        "orders": orders
    }
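
With a plain `asyncio.gather`, one failing service makes the whole endpoint fail. If partial results are acceptable, `return_exceptions=True` lets the caller decide per service. The sketch below simulates this with the standard library; `fetch_user`, `fetch_orders`, and the simulated outage are stand-ins for real HTTP calls, and treating the user record as mandatory is an assumption of this example.

```python
import asyncio
from typing import Any, Dict, List


async def fetch_user(user_id: int) -> Dict[str, Any]:
    """Stand-in for a successful call to the user service."""
    await asyncio.sleep(0.01)
    return {"id": user_id, "name": "demo"}


async def fetch_orders(user_id: int) -> List[Dict[str, Any]]:
    """Stand-in for a call to the order service that fails."""
    await asyncio.sleep(0.01)
    raise ConnectionError("order-service unavailable")  # simulated outage


async def aggregate(user_id: int) -> Dict[str, Any]:
    # Exceptions are returned as values instead of cancelling the whole gather
    user, orders = await asyncio.gather(
        fetch_user(user_id), fetch_orders(user_id), return_exceptions=True
    )
    if isinstance(user, Exception):
        raise user  # the user record is treated as mandatory here
    orders_failed = isinstance(orders, Exception)
    return {
        "user": user,
        "orders": [] if orders_failed else orders,
        "orders_available": not orders_failed,
    }


result = asyncio.run(aggregate(42))
print(result)
```

The response still contains the user, and the `orders_available` flag tells the client that the empty order list reflects an outage rather than a customer with no orders.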

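7. Summary

In this tutorial we took a close look at a range of advanced FastAPI features:

  • Performance optimization: asynchronous database operations, multi-level caching, Gzip compression, rate limiting, connection pool tuning, and performance monitoring
  • Advanced security practices: multi-factor authentication, API key management, stronger password hashing, and CSRF protection
  • Asynchronous task processing: advanced use of BackgroundTasks, Celery integration, and scheduled tasks
  • Test-driven development: unit testing best practices, async tests, and test coverage
  • GraphQL integration: flexible query interfaces built with Strawberry
  • Microservice architecture: service discovery and registration, and asynchronous inter-service communication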