Python之FastAPI 开发框架(第四篇):高级特性深度解析与实战进阶
本文深入探讨FastAPI的高级特性,重点涵盖性能优化和安全实践两大核心领域。在性能优化方面,介绍了异步数据库操作、缓存策略(内存缓存与Redis)、Gzip压缩、速率限制和连接池优化等技术,通过代码示例展示了如何提升应用吞吐量和响应速度。在安全实践部分,重点讲解了多因素认证的实现方法,包括使用pyotp生成TOTP密钥并校验验证码。此外,文章还介绍了异步任务处理、测试驱动开发、GraphQL 集成以及微服务架构中的应用。这些高级特性将帮助开发者构建更高效、更安全的FastAPI生产级应用。
在之前的三篇教程中,我们系统地学习了 FastAPI 的基础知识、核心概念以及 WebSocket、数据库集成、安全认证等高级功能。现在,让我们继续深入探索 FastAPI 更深层次的高级特性,这些特性将帮助你构建更健壮、更高效、更易于维护的生产级应用。本篇教程将涵盖性能优化策略、高级安全实践、异步任务处理、测试驱动开发、GraphQL 集成以及微服务架构中的应用。
1. 性能优化策略
FastAPI 本身已经具备出色的性能,但在生产环境中,合理的优化可以进一步提升应用的吞吐量和响应速度。
1.1 异步数据库操作
使用异步数据库驱动是实现高性能的关键。FastAPI 的异步特性需要贯穿整个调用链才能发挥最大效果。
from contextlib import asynccontextmanager

from fastapi import FastAPI
import asyncpg


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the asyncpg connection pool at startup and close it at shutdown.

    `lifespan` replaces the deprecated `@app.on_event("startup"/"shutdown")`
    hooks. The try/finally guarantees the pool is closed even if the
    application exits with an error.
    """
    # Keep 10 connections open; never open more than 20.
    app.state.db = await asyncpg.create_pool(
        dsn="postgresql://user:pass@localhost/db",
        min_size=10,
        max_size=20,
    )
    try:
        yield
    finally:
        await app.state.db.close()


app = FastAPI(lifespan=lifespan)


@app.get("/async-query")
async def async_query():
    """Return up to 10 users as a list of dicts."""
    async with app.state.db.acquire() as conn:
        rows = await conn.fetch("SELECT * FROM users LIMIT 10")
    # The rows are already fetched, so the connection can go back to the
    # pool before the response is built.
    return [dict(row) for row in rows]
性能提升:异步查询在等待数据库返回结果时不会阻塞事件循环,因此在高并发、I/O 密集的场景下可以显著提高吞吐量。常被引用的数据(如等待时间减少约 70%、比同步模式快 3-5 倍)只是经验值。实际收益取决于查询耗时、连接池大小和并发量,应以针对自身工作负载的基准测试为准。
1.2 缓存策略
合理使用缓存可以显著降低数据库负载和响应时间。
使用 cachetools 进行内存缓存(缓存只存在于当前进程内:多进程或多实例部署时各进程的缓存相互独立,进程重启后缓存也会丢失)
from cachetools import TTLCache
from fastapi import FastAPI, Request

app = FastAPI()

# Holds at most 100 entries; each expires 300 seconds (5 minutes) after it
# is stored. The cache lives only in this process.
cache = TTLCache(maxsize=100, ttl=300)


@app.get("/cached-items/{item_id}")
async def get_cached_item(item_id: int):
    """Return an item, serving it from the in-process TTL cache when possible.

    The response's "source" field is "cache" on a hit and "database" on a miss.
    """
    cache_key = f"item_{item_id}"
    # Single lookup (EAFP). A separate `in` test followed by `cache[key]`
    # can raise KeyError if the entry expires between the two calls.
    try:
        return {"data": cache[cache_key], "source": "cache"}
    except KeyError:
        pass  # miss or expired entry: fall through and reload

    # Simulate loading the item from the database.
    item = {"id": item_id, "name": f"Item {item_id}", "price": 100.0}
    cache[cache_key] = item
    return {"data": item, "source": "database"}
集成 Redis 分布式缓存
对于分布式部署,Redis 是更好的选择:
from contextlib import asynccontextmanager

import aioredis
from fastapi import FastAPI

# NOTE: aioredis has been merged into redis-py as `redis.asyncio`, which
# offers the same from_url()/get()/setex() API; prefer it in new projects.

CACHE_TTL_SECONDS = 300  # cached values expire after 5 minutes
CACHE_KEY_PREFIX = "cache:"  # namespace so clients cannot read arbitrary Redis keys


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the Redis client at startup and close it at shutdown."""
    # In aioredis >= 2.0, from_url() is a plain function that returns a client
    # (connections are opened lazily). Awaiting it raises TypeError.
    app.state.redis = aioredis.from_url("redis://localhost")
    try:
        yield
    finally:
        await app.state.redis.close()


app = FastAPI(lifespan=lifespan)


@app.get("/redis-cached/{key}")
async def get_redis_cached(key: str):
    """Return the value for `key`, computing and caching it in Redis on a miss."""
    # `key` comes straight from the URL, so prefix it. Otherwise any client
    # could read unrelated keys (sessions, tokens, ...) stored in the same Redis.
    cache_key = f"{CACHE_KEY_PREFIX}{key}"

    cached = await app.state.redis.get(cache_key)
    if cached is not None:  # an empty cached value is still a hit
        return {"data": cached.decode(), "source": "redis"}

    # Simulate an expensive computation or database query.
    value = f"Computed value for {key}"
    await app.state.redis.setex(cache_key, CACHE_TTL_SECONDS, value)
    return {"data": value, "source": "computed"}
1.3 Gzip 压缩
启用 Gzip 压缩可以减少网络传输数据量,特别适合文本响应较大的 API。
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware

app = FastAPI()

# Compress only response bodies of at least 1000 bytes, and only for clients
# that send "Accept-Encoding: gzip". Tiny payloads are not worth the CPU cost.
app.add_middleware(GZipMiddleware, minimum_size=1_000)


@app.get("/large-response")
async def get_large_response():
    """Return a ~10 KB JSON body, large enough for the GZip middleware to compress."""
    padding = "x" * 10_000
    return {"data": padding}
1.4 速率限制
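为防止 API 被过度调用或恶意攻击,可以使用 slowapi 实现速率限制。需要注意两点。第一,`get_remote_address` 取的是直接连接方的 IP:如果应用部署在反向代理或负载均衡之后,需要先正确配置转发头(如 Uvicorn 的 `--proxy-headers` 与 `--forwarded-allow-ips`),否则所有请求会共享代理 IP 的同一份额度。第二,默认计数存放在进程内存中,多 worker 部署时应通过 `storage_uri` 配置 Redis 等共享存储。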
from fastapi import FastAPI, Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

# Rate-limit counters are keyed by the client's remote address.
limiter = Limiter(key_func=get_remote_address)

app = FastAPI()
# slowapi looks the limiter up on app.state. The handler turns
# RateLimitExceeded into an HTTP error response.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)


# Decorator order matters: the route decorator must wrap the limited function.
# slowapi also needs the handler to accept a `request: Request` parameter.
@app.get("/public")
@limiter.limit("100/minute")  # at most 100 requests per minute per client
async def public_endpoint(request: Request):
    """Public endpoint with a per-minute rate limit."""
    message = "This is a public endpoint with rate limiting"
    return {"message": message}


@app.get("/premium")
@limiter.limit("1000/hour")  # at most 1000 requests per hour per client
async def premium_endpoint(request: Request):
    """Endpoint with a looser hourly limit.

    Nothing here checks that the caller is a premium user; the route simply
    has a higher limit.
    """
    message = "Premium users have higher limits"
    return {"message": message}
1.5 连接池优化
合理配置数据库连接池大小对性能至关重要。连接池过小会导致请求等待,过大则会消耗过多资源。
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker

# Tune the pool to the server and to the database's max_connections. Each
# process can hold up to pool_size + max_overflow connections.
engine = create_async_engine(
    "postgresql+asyncpg://user:pass@localhost/db",
    pool_size=20,        # connections kept open in the pool
    max_overflow=10,     # extra temporary connections allowed beyond pool_size
    pool_pre_ping=True,  # test each connection on checkout; replaces dead ones
    pool_recycle=3600,   # replace connections older than this many seconds
)

# expire_on_commit=False: by default a commit expires every loaded attribute,
# and the next attribute access triggers an implicit refresh query. AsyncSession
# cannot do that implicit IO (it raises MissingGreenlet), so keep the loaded
# state after commit instead.
AsyncSessionLocal = sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
1.6 性能分析与监控
通过 HTTP 中间件记录每个请求的处理耗时,并把耗时写入 `X-Process-Time` 响应头,以便发现慢请求、定位瓶颈。更细粒度的函数级分析可以再借助 cProfile、py-spy 等性能分析工具。
import logging
import time

from fastapi import FastAPI, Request

logger = logging.getLogger(__name__)

app = FastAPI()

# Requests slower than this (in seconds) are logged as warnings.
SLOW_REQUEST_THRESHOLD_SECONDS = 1.0


@app.middleware("http")
async def monitor_performance(request: Request, call_next):
    """Add an X-Process-Time header and log slow requests.

    The measured time runs until the downstream app returns its response
    object. For streaming responses, that excludes the time spent sending the body.
    """
    # perf_counter is monotonic and high-resolution. time.time() can jump
    # when the wall clock is adjusted (NTP and similar).
    start = time.perf_counter()
    response = await call_next(request)
    elapsed = time.perf_counter() - start

    response.headers["X-Process-Time"] = str(elapsed)
    if elapsed > SLOW_REQUEST_THRESHOLD_SECONDS:
        logger.warning("Slow request: %s took %.2fs", request.url.path, elapsed)
    return response
2. 高级安全实践
2.1 多因素认证
除了基本的 JWT 认证,可以实现多因素认证增强安全性。
from io import BytesIO

import pyotp
import qrcode
from fastapi import Depends, Form, HTTPException
from fastapi.responses import Response

# NOTE: `app` and `get_current_user` are assumed to come from the application
# built in the earlier parts of this series.


class MFAService:
    """Issue and verify per-user TOTP secrets.

    Secrets live in an in-memory dict, so they are lost on restart and are
    not shared between worker processes.
    """

    # Also accept the code from one time step before or after the current one,
    # to tolerate small clock drift between the server and the authenticator app.
    VALID_WINDOW = 1

    def __init__(self):
        # username -> base32 secret. Persist this (ideally encrypted) in a real app.
        self.user_secrets = {}

    def generate_secret(self, username: str) -> str:
        """Create a new random base32 secret for `username`, replacing any existing one."""
        secret = pyotp.random_base32()
        self.user_secrets[username] = secret
        return secret

    def get_totp_uri(self, username: str, secret: str) -> str:
        """Return the otpauth:// provisioning URI that authenticator apps scan."""
        return pyotp.TOTP(secret).provisioning_uri(
            name=username,
            issuer_name="MyApp",
        )

    def verify_totp(self, username: str, token: str) -> bool:
        """Return True if `token` is a currently valid code for `username`.

        Returns False when the user has no MFA secret set up.
        """
        secret = self.user_secrets.get(username)
        if secret is None:
            return False
        return pyotp.TOTP(secret).verify(token, valid_window=self.VALID_WINDOW)


mfa_service = MFAService()


@app.post("/mfa/setup")
async def setup_mfa(current_user: dict = Depends(get_current_user)):
    """Generate a TOTP secret for the current user and return it as a PNG QR code.

    Calling this again replaces the previous secret, so codes from the old
    authenticator entry stop working.
    """
    username = current_user["username"]
    secret = mfa_service.generate_secret(username)
    uri = mfa_service.get_totp_uri(username, secret)

    # Render the provisioning URI as a QR code image.
    qr_image = qrcode.make(uri)
    buffer = BytesIO()
    qr_image.save(buffer)
    # getvalue() returns the whole buffer regardless of position, so no seek is needed.
    return Response(content=buffer.getvalue(), media_type="image/png")


@app.post("/mfa/verify")
async def verify_mfa(
    token: str = Form(...),
    current_user: dict = Depends(get_current_user),
):
    """Check a submitted TOTP code for the current user.

    Raises:
        HTTPException: 400 if the code is invalid or MFA is not set up.
    """
    if mfa_service.verify_totp(current_user["username"], token):
        return {"message": "MFA verified successfully"}
    raise HTTPException(status_code=400, detail="Invalid token")
2.2 API 密钥管理
为第三方开发者提供 API 密钥认证。
from datetime import datetime, timedelta, timezone
import secrets

from fastapi import Depends, HTTPException
from fastapi.security import APIKeyHeader
from pydantic import BaseModel

# NOTE: `app` and `require_permission` are assumed to come from the
# application built in the earlier parts of this series.

# Reads the key from the X-API-Key request header.
api_key_header = APIKeyHeader(name="X-API-Key")


class APIKey(BaseModel):
    """An issued API key with its validity period and granted permissions."""

    key: str
    name: str
    created_at: datetime  # timezone-aware UTC
    expires_at: datetime  # timezone-aware UTC
    permissions: list[str]


class APIKeyManager:
    """Create and validate API keys, stored in memory (use a database in production)."""

    def __init__(self):
        # raw key string -> APIKey
        self.keys = {}

    def create_key(self, name: str, permissions: list[str], days_valid: int = 30) -> APIKey:
        """Issue a new random URL-safe key that is valid for `days_valid` days."""
        key = secrets.token_urlsafe(32)
        # Aware UTC timestamps. datetime.utcnow() is deprecated and returns
        # naive datetimes that are easy to compare incorrectly.
        now = datetime.now(timezone.utc)
        api_key = APIKey(
            key=key,
            name=name,
            created_at=now,
            expires_at=now + timedelta(days=days_valid),
            permissions=permissions,
        )
        self.keys[key] = api_key
        return api_key

    def verify_key(self, key: str):
        """Return the APIKey for `key`, or None if it is unknown or expired.

        Expired keys are removed from storage when they are first detected.
        """
        api_key = self.keys.get(key)
        if api_key is None:
            return None
        if api_key.expires_at < datetime.now(timezone.utc):
            self.keys.pop(key, None)
            return None
        return api_key


key_manager = APIKeyManager()


async def verify_api_key(api_key: str = Depends(api_key_header)):
    """FastAPI dependency that resolves the X-API-Key header to a valid APIKey.

    Raises:
        HTTPException: 403 if the key is unknown or expired.
    """
    key_obj = key_manager.verify_key(api_key)
    if not key_obj:
        raise HTTPException(
            status_code=403,
            detail="Invalid or expired API key",
        )
    return key_obj


@app.post("/api-keys/create")
async def create_api_key(
    name: str,
    permissions: list[str],
    admin: dict = Depends(require_permission("admin")),
):
    """Admin-only: issue a new API key. The raw key appears only in this response."""
    return key_manager.create_key(name, permissions)


@app.get("/protected-resource")
async def protected_resource(api_key: APIKey = Depends(verify_api_key)):
    """Example resource accessible with any valid API key."""
    return {"message": f"Access granted for {api_key.name}"}
2.3 密码哈希策略增强
使用更安全的密码哈希算法和策略。
from typing import Optional

from passlib.context import CryptContext
import bcrypt

# argon2 (requires the argon2-cffi package) is the default for new hashes.
# deprecated="auto" marks every non-default scheme (bcrypt here) as deprecated.
# verify_and_update() can then re-hash existing bcrypt hashes to argon2 on the
# next successful login. Without a deprecated scheme, passlib never migrates
# hashes between algorithms.
pwd_context = CryptContext(
    schemes=["argon2", "bcrypt"],
    default="argon2",
    deprecated="auto",
    bcrypt__rounds=12,            # cost for any bcrypt hashes passlib still produces or checks
    argon2__time_cost=2,
    argon2__memory_cost=102400,   # in KiB (~100 MiB per hash)
    argon2__parallelism=8,
)


def verify_and_upgrade_password(
    plain_password: str, hashed_password: str
) -> tuple[bool, Optional[str]]:
    """Verify a password and re-hash it if its stored hash is outdated.

    Returns:
        (is_valid, new_hash). `new_hash` is None unless the password is valid
        and the stored hash uses a deprecated scheme or outdated settings.
        The caller must persist `new_hash` when it is not None.
    """
    is_valid, new_hash = pwd_context.verify_and_update(plain_password, hashed_password)
    return is_valid, new_hash
2.4 CSRF 保护
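对于基于 Cookie 会话的传统 Web 应用(非纯 API),需要添加 CSRF 保护。关键在于令牌必须与当前用户的会话或 Cookie 绑定:服务端要校验提交的令牌与该用户持有的令牌一致。仅校验令牌的签名和有效期并不能防御 CSRF,因为攻击者可以自己访问页面,拿到一个签名合法的令牌。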
import os
import secrets
from typing import Optional

from fastapi import FastAPI, Form, HTTPException, Request
from fastapi.templating import Jinja2Templates
from itsdangerous import BadData, URLSafeTimedSerializer

app = FastAPI()
templates = Jinja2Templates(directory="templates")

# CSRF protection configuration.
# The signing key must be a long random secret shared by all workers. Set it
# through the environment rather than committing it to source control.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key")
serializer = URLSafeTimedSerializer(SECRET_KEY)

CSRF_COOKIE_NAME = "csrf_token"
CSRF_MAX_AGE_SECONDS = 3600  # tokens expire after 1 hour


def generate_csrf_token():
    """Generate a new random CSRF token."""
    return secrets.token_urlsafe(32)


def validate_csrf_token(token: str, expected: Optional[str] = None) -> bool:
    """Validate a signed CSRF token taken from a form submission.

    Args:
        token: The signed token the client submitted.
        expected: The unsigned token bound to this client (for example from
            its CSRF cookie). When given, the token must also match it. Without
            it, only the signature and age are checked, which does NOT stop
            CSRF on its own.

    Returns:
        True if the token is authentic, not expired and (if `expected` is
        given) matches it.
    """
    try:
        raw = serializer.loads(token, max_age=CSRF_MAX_AGE_SECONDS)
    except BadData:  # bad signature, expired, or malformed payload
        return False
    if expected is None:
        return True
    # Compare bytes in constant time. Comparing str values raises TypeError
    # on non-ASCII input such as a tampered cookie.
    return isinstance(raw, str) and secrets.compare_digest(
        raw.encode(), expected.encode()
    )


@app.get("/form")
async def show_form(request: Request):
    """Render the form with a signed CSRF token and bind the raw token to a cookie."""
    csrf_token = generate_csrf_token()
    response = templates.TemplateResponse(
        "form.html",
        {"request": request, "csrf_token": serializer.dumps(csrf_token)},
    )
    # Double-submit cookie: a cross-site attacker can neither read this cookie
    # nor forge a form token that matches it. Add secure=True when the site is
    # served over HTTPS.
    response.set_cookie(
        CSRF_COOKIE_NAME,
        csrf_token,
        max_age=CSRF_MAX_AGE_SECONDS,
        httponly=True,
        samesite="strict",
    )
    return response


@app.post("/submit")
async def handle_form(
    request: Request,
    data: str = Form(...),
    csrf_token: str = Form(...),
):
    """Accept the form only if its CSRF token matches the client's CSRF cookie.

    Raises:
        HTTPException: 400 if the token is missing, forged, expired or mismatched.
    """
    cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
    if not cookie_token or not validate_csrf_token(csrf_token, cookie_token):
        raise HTTPException(status_code=400, detail="Invalid CSRF token")
    return {"message": "Form submitted successfully", "data": data}
3. 异步任务处理
3.1 BackgroundTasks 的进阶用法
除了简单的后台任务,可以结合数据库操作和错误处理。
from fastapi import BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
import asyncio
import logging
import uuid

# NOTE: `app` is assumed to be the application's FastAPI instance.

logger = logging.getLogger(__name__)


class TaskManager:
    """Run simulated long jobs and track their status in memory.

    Status lives in this process only and is never pruned. Use a shared store
    (Redis, a database) for multi-worker deployments.
    """

    def __init__(self):
        # task_id -> "pending" | "running" | "step_N" | "completed" | "failed: ..."
        self.task_status = {}

    async def long_running_task(
        self,
        task_id: str,
        params: dict,
        steps: int = 5,
        step_delay: float = 2.0,
    ):
        """Simulate a multi-step job, updating its status after each step.

        Args:
            task_id: Key under which the status is recorded.
            params: Job parameters (unused by this simulation).
            steps: Number of simulated steps.
            step_delay: Seconds to sleep per step.
        """
        try:
            self.task_status[task_id] = "running"
            for step in range(steps):
                await asyncio.sleep(step_delay)  # stand-in for real async work
                self.task_status[task_id] = f"step_{step}"
            self.task_status[task_id] = "completed"
        except Exception as e:
            # Background tasks have no caller to report to, so log the traceback too.
            logger.exception("Background task %s failed", task_id)
            self.task_status[task_id] = f"failed: {str(e)}"


task_manager = TaskManager()


@app.post("/start-task")
async def start_background_task(background_tasks: BackgroundTasks):
    """Schedule a long-running task and return its id immediately."""
    task_id = str(uuid.uuid4())
    # BackgroundTasks run only after the response has been sent. Record the
    # task now so an immediate status poll sees "pending" instead of "not_found".
    task_manager.task_status[task_id] = "pending"
    background_tasks.add_task(
        task_manager.long_running_task,
        task_id,
        {"param1": "value1"},
    )
    return {"task_id": task_id, "status": "started"}


@app.get("/task-status/{task_id}")
async def get_task_status(task_id: str):
    """Return the recorded status of a task, or "not_found" for unknown ids."""
    status = task_manager.task_status.get(task_id, "not_found")
    return {"task_id": task_id, "status": status}
3.2 Celery 集成
对于更复杂的分布式任务处理,集成 Celery。
# celery_app.py
from celery import Celery

celery_app = Celery(
    "tasks",
    broker="redis://localhost:6379/0",
    backend="redis://localhost:6379/0",
)

celery_app.conf.update(
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    timezone="UTC",
    enable_utc=True,
    task_track_started=True,        # report STARTED state while the task runs
    task_time_limit=30 * 60,        # hard limit: worker process is killed after 30 minutes
    task_soft_time_limit=25 * 60,   # soft limit: SoftTimeLimitExceeded raised after 25 minutes
)


@celery_app.task(bind=True, max_retries=3)
def process_data(self, data_id: int):
    """Celery task that processes one data item, retrying up to 3 times on error.

    `perform_data_processing` is a placeholder for the application's real work.
    Note that SoftTimeLimitExceeded is also an Exception, so a soft timeout
    triggers a retry as well.
    """
    try:
        result = perform_data_processing(data_id)
    except Exception as exc:
        # retry() always raises (celery.exceptions.Retry, or `exc` itself once
        # max_retries is exhausted). `raise` makes it explicit that execution
        # stops here, as the Celery docs recommend.
        raise self.retry(exc=exc, countdown=60)  # retry after 60 seconds
    return {"status": "success", "result": result}
在 FastAPI 中调用 Celery 任务:
# main.py
from celery.result import AsyncResult
from .celery_app import celery_app, process_data

# NOTE: `app` is assumed to be the application's FastAPI instance.


@app.post("/process/{data_id}")
def start_processing(data_id: int):
    """Enqueue `process_data` for `data_id` and return the Celery task id.

    Declared with plain `def`: delay() publishes through Celery's blocking
    broker client, and FastAPI runs sync endpoints in its threadpool, so the
    event loop is not blocked.
    """
    task = process_data.delay(data_id)
    return {"task_id": task.id, "status": "processing"}


@app.get("/task/{task_id}")
def get_task_result(task_id: str):
    """Report the state of a Celery task (sync for the same reason as above).

    Celery also reports PENDING for task ids it has never seen.
    """
    task = AsyncResult(task_id, app=celery_app)
    # Each `.state` access queries the result backend. Read it once so all
    # branches see the same value.
    state = task.state
    if state == "PENDING":
        return {"state": state, "status": "Task pending..."}
    if state in ("FAILURE", "RETRY"):
        # For these states task.info holds an exception instance, which is not
        # JSON-serializable, so send its message instead.
        return {"state": state, "status": str(task.info)}
    return {"state": state, "result": task.result}
3.3 定时任务
使用 APScheduler 或 Celery Beat 实现定时任务。
from contextlib import asynccontextmanager
from datetime import datetime, timezone
import logging

from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.interval import IntervalTrigger
from fastapi import FastAPI

logger = logging.getLogger(__name__)

scheduler = AsyncIOScheduler()


async def periodic_cleanup():
    """Periodic job that removes expired data.

    `cleanup_expired_records` is a placeholder for the application's own
    cleanup coroutine.
    """
    logger.info("Running cleanup at %s", datetime.now(timezone.utc))
    await cleanup_expired_records()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the scheduler with the daily cleanup job and stop it on shutdown.

    Every worker process runs its own lifespan. With several workers the job
    therefore runs once per worker; use a single scheduler process or Celery
    Beat if it must run exactly once.
    """
    scheduler.add_job(
        periodic_cleanup,
        trigger=IntervalTrigger(hours=24),
        id="daily_cleanup",
        replace_existing=True,
    )
    scheduler.start()
    try:
        yield
    finally:
        scheduler.shutdown()


app = FastAPI(lifespan=lifespan)
4. 测试驱动开发
4.1 单元测试最佳实践
使用 pytest 编写全面的测试。
# test_main.py
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from .main import app
from .database import Base, get_db

# Use an in-memory SQLite database for tests.
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},  # TestClient calls from another thread
    # StaticPool reuses one connection, so every session sees the same
    # in-memory database instead of a fresh empty one.
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def override_get_db():
    """Test replacement for `get_db` that yields a session bound to the test engine."""
    # Create the session before `try`. If creation fails, `finally` would
    # otherwise hit an unbound `db` and hide the real error.
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()


@pytest.fixture
def client():
    """TestClient with fresh tables and the DB dependency overridden for one test."""
    Base.metadata.create_all(bind=engine)
    # Install the override only for the duration of the test, so it does not
    # leak into other test modules that share `app`.
    app.dependency_overrides[get_db] = override_get_db
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        app.dependency_overrides.pop(get_db, None)
        Base.metadata.drop_all(bind=engine)


def test_create_item(client):
    """Creating an item returns it with an id."""
    response = client.post(
        "/items/",
        json={"name": "Test Item", "price": 100.0},
    )
    assert response.status_code == 200
    data = response.json()
    assert data["name"] == "Test Item"
    assert data["price"] == 100.0
    assert "id" in data


def test_get_nonexistent_item(client):
    """Fetching an unknown item returns 404."""
    response = client.get("/items/999")
    assert response.status_code == 404
    assert response.json()["detail"] == "Item not found"


@pytest.mark.parametrize("invalid_data", [
    {"name": "Test", "price": -10},  # negative price
    {"name": "", "price": 100},      # empty name
    {"price": 100},                  # missing name
])
def test_invalid_item_creation(client, invalid_data):
    """Invalid payloads are rejected with a 422 validation error."""
    response = client.post("/items/", json=invalid_data)
    assert response.status_code == 422
4.2 异步测试
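`TestClient` 本身就可以测试 `async def` 端点。当测试函数自身需要是异步的(例如要在测试中 await 其他协程)时,可以使用 `httpx.AsyncClient`,并配合 pytest-asyncio 插件(它提供 `@pytest.mark.asyncio` 标记)。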
import pytest
from httpx import ASGITransport, AsyncClient

from .main import app


@pytest.mark.asyncio  # provided by the pytest-asyncio plugin
async def test_async_endpoint():
    """Call the app in-process through httpx's ASGI transport."""
    # The AsyncClient(app=...) shortcut was removed in httpx 0.28. Pass an
    # explicit ASGITransport instead. This transport does not run the app's
    # lifespan (startup/shutdown) handlers.
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        response = await ac.get("/async-data")
    assert response.status_code == 200
    assert response.json() == {"data": "async response"}
4.3 测试覆盖率
使用 pytest-cov 测量测试覆盖率。
pytest --cov=app --cov-report=html tests/
5. GraphQL 集成
FastAPI 可以与 Strawberry 或 Graphene 等 GraphQL 库集成,提供更灵活的查询能力。
5.1 使用 Strawberry 实现 GraphQL
import strawberry
from fastapi import FastAPI
from strawberry.fastapi import GraphQLRouter
from typing import List, Optional


@strawberry.type
class Book:
    """GraphQL object type describing a book."""

    id: int
    title: str
    author: str
    year: int


# In-memory sample data shared by the resolvers below (reset on every restart).
books_db = [
    Book(id=1, title="The Hobbit", author="J.R.R. Tolkien", year=1937),
    Book(id=2, title="1984", author="George Orwell", year=1949),
    Book(id=3, title="Dune", author="Frank Herbert", year=1965),
]


@strawberry.type
class Query:
    """Root query type."""

    @strawberry.field
    def books(self, author: Optional[str] = None) -> List[Book]:
        """Return all books, or only those whose author matches exactly."""
        if author:
            return [b for b in books_db if b.author == author]
        return books_db

    @strawberry.field
    def book(self, id: int) -> Optional[Book]:  # `id` is the public GraphQL argument name
        """Return the book with the given id, or null if there is none."""
        return next((b for b in books_db if b.id == id), None)


@strawberry.type
class Mutation:
    """Root mutation type."""

    @strawberry.mutation
    def add_book(self, title: str, author: str, year: int) -> Book:
        """Append a new book with the next free id (exposed as `addBook`)."""
        # default=0 lets the first book get id 1 when books_db is empty.
        # max() over an empty sequence raises ValueError.
        new_id = max((b.id for b in books_db), default=0) + 1
        book = Book(id=new_id, title=title, author=author, year=year)
        books_db.append(book)
        return book


schema = strawberry.Schema(query=Query, mutation=Mutation)
graphql_app = GraphQLRouter(schema)

app = FastAPI()
app.include_router(graphql_app, prefix="/graphql")
现在可以通过 /graphql 端点执行 GraphQL 查询:
# Query example: titles and years of all books by one author
{
books(author: "J.R.R. Tolkien") {
title
year
}
}
# Mutation example: add a book (the add_book resolver is exposed as addBook)
mutation {
addBook(title: "Foundation", author: "Isaac Asimov", year: 1951) {
id
title
}
}
6. 微服务架构中的 FastAPI
FastAPI 非常适合构建微服务。
6.1 服务发现与注册
使用 Consul 或 etcd 实现服务发现。
import consul
import socket
import uuid
from typing import Optional


class ServiceRegistry:
    """Register this service instance with Consul and look up other services."""

    def __init__(self, consul_host="localhost", consul_port=8500):
        self.consul = consul.Consul(host=consul_host, port=consul_port)
        # Unique per process, so several instances of one service can coexist.
        self.service_id = str(uuid.uuid4())

    def register(self, service_name: str, port: int, address: Optional[str] = None):
        """Register this instance in Consul with a TCP health check every 10 s.

        Args:
            service_name: Logical service name other services discover.
            port: Port this instance listens on.
            address: IP or hostname other services should connect to. Defaults
                to the address the local hostname resolves to. On some systems
                (for example Debian's 127.0.1.1 entry) that is a loopback
                address, so pass it explicitly in containers or multi-NIC hosts.
        """
        ip = address or socket.gethostbyname(socket.gethostname())
        self.consul.agent.service.register(
            name=service_name,
            service_id=self.service_id,
            address=ip,
            port=port,
            check=consul.Check.tcp(ip, port, "10s"),
            tags=["fastapi", "v1"],
        )

    def deregister(self):
        """Remove this instance's registration from Consul."""
        self.consul.agent.service.deregister(self.service_id)

    def discover(self, service_name: str):
        """Return id/address/port for every instance whose health checks pass."""
        _, services = self.consul.health.service(service_name, passing=True)
        instances = []
        for entry in services:
            service = entry["Service"]
            instances.append({
                "id": service["ID"],
                # Consul leaves Service.Address empty when none was registered.
                # Clients are then expected to use the node's address.
                "address": service["Address"] or entry["Node"]["Address"],
                "port": service["Port"],
            })
        return instances
# Register with Consul when the application starts; deregister when it stops.
registry = ServiceRegistry()


@app.on_event("startup")
async def startup():
    """Register this instance in Consul as user-service on port 8000."""
    registry.register(service_name="user-service", port=8000)


@app.on_event("shutdown")
async def shutdown():
    """Remove this instance's Consul registration."""
    registry.deregister()
6.2 服务间通信
使用 httpx 进行异步服务间通信。
import asyncio
from contextlib import asynccontextmanager
from typing import Optional

import httpx
from fastapi import FastAPI


class ServiceClient:
    """Async HTTP client for calling another microservice by name."""

    def __init__(
        self,
        service_name: str,
        base_url: Optional[str] = None,
        timeout: float = 10.0,
    ):
        """
        Args:
            service_name: Logical name of the target service.
            base_url: Root URL of the service. Defaults to
                http://<service_name>:8001. In a real deployment, resolve it
                through service discovery instead.
            timeout: Per-request timeout in seconds.
        """
        self.service_name = service_name
        self.base_url = base_url or f"http://{service_name}:8001"
        self.client = httpx.AsyncClient(base_url=self.base_url, timeout=timeout)

    async def _get_json(self, path: str):
        """GET `path` on this service and return the decoded JSON body.

        Raises:
            httpx.HTTPStatusError: on a 4xx or 5xx response.
        """
        response = await self.client.get(path)
        response.raise_for_status()
        return response.json()

    async def get_user(self, user_id: int):
        """Fetch a user record from this service."""
        return await self._get_json(f"/users/{user_id}")

    async def get_user_orders(self, user_id: int):
        """Fetch a user's orders from this service.

        The /users/{id}/orders path is an assumption about the order service's
        API; adjust it to match the real service.
        """
        return await self._get_json(f"/users/{user_id}/orders")

    async def close(self):
        """Close the underlying connection pool."""
        await self.client.aclose()


user_client = ServiceClient("user-service")
order_client = ServiceClient("order-service")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Close the shared HTTP clients when the application shuts down."""
    try:
        yield
    finally:
        await asyncio.gather(user_client.close(), order_client.close())


app = FastAPI(lifespan=lifespan)


@app.get("/api/users/{user_id}/orders")
async def get_user_orders(user_id: int):
    """Combine a user record and their orders from two services."""
    # The two calls are independent, so run them concurrently.
    user, orders = await asyncio.gather(
        user_client.get_user(user_id),
        order_client.get_user_orders(user_id),
    )
    return {
        "user": user,
        "orders": orders,
    }
7. 总结
通过本教程,我们深入探讨了 FastAPI 的一系列高级特性:
- 性能优化:异步数据库操作、内存缓存与 Redis 分布式缓存、Gzip 压缩、速率限制、连接池优化和性能监控
- 高级安全实践:多因素认证、API 密钥管理、增强的密码哈希策略和 CSRF 保护
- 异步任务处理:BackgroundTasks 进阶用法、Celery 集成和定时任务
- 测试驱动开发:单元测试最佳实践、异步测试和测试覆盖率
- GraphQL 集成:使用 Strawberry 实现灵活的查询接口
- 微服务架构:服务发现注册和异步服务间通信
更多推荐
所有评论(0)