1. SQLModel 核心概念与架构

1.1 设计理念与目标

  • 提供简洁而强大的 API 来定义数据库模型
  • 支持类型提示,提供更好的 IDE 支持和代码自动完成
  • 与 Pydantic完全集成,实现数据验证
  • 与 SQLAlchemy 完全兼容,利用其强大的查询能力
  • 支持异步操作,适用于现代 Python 应用

1.2 与 Pydantic 和 SQLAlchemy 的关系

SQLModel 建立在两个强大的库之上:

  1. Pydantic:提供数据验证、序列化和类型提示
  2. SQLAlchemy:提供数据库操作、查询构建和 ORM 功能

SQLModel 的核心创新在于将这两个库的功能无缝集成,使开发者可以使用单一模型同时进行数据验证和数据库操作。

1.3 核心组件与工作原理

SQLModel 的核心组件包括:

  • Model:基础模型类,继承自 Pydantic 的 BaseModel 和 SQLAlchemy 的 declarative_base
  • Field:字段定义,支持 Pydantic 和 SQLAlchemy 的字段参数
  • Session:数据库会话,用于执行数据库操作
  • select:查询构建器,用于构建 SQL 查询

1.4 异步支持机制

SQLModel 支持异步操作,通过 SQLAlchemy 的异步扩展实现。这使得 SQLModel 可以在 FastAPI 的异步环境中无缝工作,提高应用的性能和并发处理能力。

1.5 SQLModel 配置选项

SQLModel 提供了多种配置选项,可以通过 SQLModelConfig 类进行配置:

from sqlmodel import SQLModel

class User(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    name: str
    email: str
    
    class Config:
        # 表名配置
        table_name = "users"
        # 启用自动生成表名
        use_enum_values = True
        # 验证模式
        validate_assignment = True
        # 禁止从字典创建模型
        from_attributes = True

常用配置选项

  • table_name:指定表名
  • use_enum_values:使用枚举值而不是枚举对象
  • validate_assignment:启用赋值验证
  • from_attributes:允许从属性创建模型
  • arbitrary_types_allowed:允许任意类型
  • json_encoders:自定义 JSON 编码器

1.6 数据库会话管理

数据库会话是 SQLModel 与数据库交互的核心。以下是会话管理的最佳实践:

from sqlmodel import Session, create_engine
from contextlib import contextmanager

# 创建引擎
engine = create_engine("sqlite:///./test.db")

# 会话上下文管理器
@contextmanager
def get_db():
    db = Session(engine)
    try:
        yield db
    finally:
        db.close()

# 使用示例
def create_user(name: str, email: str):
    with get_db() as db:
        user = User(name=name, email=email)
        db.add(user)
        db.commit()
        db.refresh(user)
        return user

1.7 异步示例

from sqlmodel import SQLModel, create_engine, select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

# 创建异步引擎
engine = create_async_engine("sqlite+aiosqlite:///./test.db")

# 创建会话工厂
AsyncSessionLocal = sessionmaker(
    engine, class_=AsyncSession, expire_on_commit=False
)

# 异步依赖
async def get_db():
    async with AsyncSessionLocal() as session:
        yield session

2. 安装与环境配置

2.1 安装 SQLModel 及其依赖

# 安装 SQLModel
pip install sqlmodel

# 安装数据库驱动
pip install psycopg2-binary  # PostgreSQL
pip install pymysql          # MySQL
pip install aiosqlite         # SQLite (异步)

# 安装 FastAPI
pip install fastapi uvicorn

# 安装其他依赖
pip install python-dotenv  # 环境变量管理
pip install alembic        # 数据库迁移

2.2 数据库连接配置

SQLModel 使用 SQLAlchemy 的连接 URL 格式,支持多种数据库:

数据库 连接 URL 格式
PostgreSQL postgresql://user:password@localhost:5432/dbname
MySQL mysql+pymysql://user:password@localhost:3306/dbname
SQLite sqlite:///./dbname.db
SQLite (异步) sqlite+aiosqlite:///./dbname.db

2.3 项目结构最佳实践

my_project/
├── app/
│   ├── __init__.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── base.py          # 基础模型
│   │   ├── user.py          # 用户模型
│   │   └── post.py          # 帖子模型
│   ├── schemas/
│   │   ├── __init__.py
│   │   ├── user.py          # 用户相关 Pydantic 模型
│   │   └── post.py          # 帖子相关 Pydantic 模型
│   ├── api/
│   │   ├── __init__.py
│   │   ├── user.py          # 用户相关 API 路由
│   │   └── post.py          # 帖子相关 API 路由
│   ├── dependencies.py      # 依赖注入
│   └── config.py            # 配置管理
├── main.py                  # FastAPI 应用入口
├── requirements.txt         # 依赖管理
├── alembic.ini              # Alembic 配置
└── migrations/             # 数据库迁移文件

2.4 初始化与配置示例

# app/config.py
import os
from dotenv import load_dotenv

load_dotenv()

DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./test.db")

# app/models/base.py
from sqlmodel import SQLModel

# app/main.py
from fastapi import FastAPI
from sqlmodel import create_engine, SQLModel
from app.config import DATABASE_URL
from app.models import user, post

# 创建引擎
engine = create_engine(DATABASE_URL)

# 创建表
def create_tables():
    SQLModel.metadata.create_all(engine)

app = FastAPI()

@app.on_event("startup")
def startup_event():
    create_tables()

2.5 环境变量管理

使用 python-dotenv 管理环境变量,创建 .env 文件:

# .env
DATABASE_URL="postgresql://user:password@localhost:5432/dbname"

3. 模型定义与数据验证

3.1 基本模型定义语法

from sqlmodel import SQLModel, Field

class User(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    name: str
    email: str = Field(unique=True, index=True)
    age: int | None = None

3.2 字段类型与参数详解

SQLModel 支持多种字段类型,每种类型都有相应的参数:

字段类型 描述 常用参数
int 整数 primary_key, default, nullable
str 字符串 max_length, unique, index
float 浮点数 default, nullable
bool 布尔值 default, nullable
datetime 日期时间 default, nullable
date 日期 default, nullable
time 时间 default, nullable
UUID UUID primary_key, default
JSON JSON 数据 default, nullable

3.3 关系字段

from sqlmodel import SQLModel, Field, Relationship

class User(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    name: str
    email: str
    posts: list["Post"] = Relationship(back_populates="user")

class Post(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    title: str
    content: str
    user_id: int | None = Field(default=None, foreign_key="user.id")
    user: User | None = Relationship(back_populates="posts")

3.4 模型继承与组合

from sqlmodel import SQLModel, Field
from datetime import datetime

class BaseModel(SQLModel):
    id: int | None = Field(default=None, primary_key=True)
    created_at: datetime = Field(default_factory=datetime.utcnow)

class User(BaseModel, table=True):
    name: str
    email: str

class Post(BaseModel, table=True):
    title: str
    content: str
    user_id: int = Field(foreign_key="user.id")

3.5 索引与约束配置

from sqlmodel import SQLModel, Field, Index

class User(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    name: str
    email: str = Field(unique=True, index=True)
    age: int | None = None
    
    class Config:
        indexes = [
            Index("idx_user_name_age", "name", "age"),
        ]

3.6 Pydantic 集成的验证机制

SQLModel 继承自 Pydantic 的 BaseModel,因此支持所有 Pydantic 的验证功能:

from sqlmodel import SQLModel, Field

class UserCreate(SQLModel):
    name: str = Field(..., min_length=2, max_length=50)
    email: str = Field(..., pattern=r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+")
    age: int | None = Field(None, ge=0, le=120)

3.7 字段验证规则与参数

验证参数 描述 适用类型
min_length 最小长度 str
max_length 最大长度 str
ge 大于等于 int, float
le 小于等于 int, float
gt 大于 int, float
lt 小于 int, float
pattern 正则表达式 str
default 默认值 所有类型
nullable 是否可空 所有类型

3.8 自定义验证逻辑

from sqlmodel import SQLModel, Field
from pydantic import field_validator

class User(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    name: str
    email: str
    
    @field_validator('email')
    def email_must_contain_at(cls, v):
        if '@' not in v:
            raise ValueError('邮箱必须包含 @ 符号')
        return v

3.9 验证错误处理

from fastapi import FastAPI, HTTPException
from sqlmodel import SQLModel, Field

app = FastAPI()

class UserCreate(SQLModel):
    name: str = Field(..., min_length=2)
    email: str = Field(..., pattern=r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+")

@app.post("/users")
async def create_user(user: UserCreate):
    try:
        # 处理用户创建
        return user
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e))

3.10 类型注解与类型系统

SQLModel 充分利用了 Python 的类型注解系统,提供了更好的类型提示和 IDE 支持:

from sqlmodel import SQLModel, Field
from typing import List
from datetime import datetime

class User(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    name: str
    email: str
    created_at: datetime = Field(default_factory=datetime.utcnow)
    posts: List["Post"] = Field(default_factory=list, sa_relationship={"back_populates": "user"})

class Post(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    title: str
    content: str
    user_id: int | None = Field(default=None, foreign_key="user.id")
    user: User | None = Field(default=None, sa_relationship={"back_populates": "posts"})

4. 数据库操作与查询

4.1 基础 CRUD 操作

from sqlmodel import SQLModel, Field, create_engine, select
from sqlalchemy.orm import sessionmaker

# 创建引擎和会话
engine = create_engine("sqlite:///./test.db")
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# 依赖注入
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

# 示例:完整的 CRUD 操作
def crud_example():
    db = SessionLocal()
    try:
        # 创建用户
        user = User(name="Alice", email="alice@example.com")
        db.add(user)
        db.commit()
        db.refresh(user)
        print(f"Created user: {user}")
        
        # 读取用户
        read_user = db.exec(select(User).where(User.id == user.id)).first()
        print(f"Read user: {read_user}")
        
        # 更新用户
        read_user.name = "Alice Smith"
        db.add(read_user)
        db.commit()
        db.refresh(read_user)
        print(f"Updated user: {read_user}")
        
        # 删除用户
        db.delete(read_user)
        db.commit()
        print(f"Deleted user")
        
    finally:
        db.close()

4.2 高级查询技巧

from sqlmodel import select

# 过滤查询
def get_active_users(db):
    users = db.exec(select(User).where(User.is_active == True)).all()
    return users

# 排序查询
def get_users_ordered_by_name(db):
    users = db.exec(select(User).order_by(User.name)).all()
    return users

# 分组查询
def get_user_count_by_age(db):
    from sqlalchemy import func
    result = db.exec(
        select(User.age, func.count(User.id))
        .group_by(User.age)
    ).all()
    return result

# 限制和偏移(分页)
def get_users_paginated(db, skip: int = 0, limit: int = 10):
    users = db.exec(select(User).offset(skip).limit(limit)).all()
    return users

# 关系查询与预加载
def get_users_with_posts(db):
    from sqlalchemy.orm import selectinload
    users = db.exec(
        select(User).options(selectinload(User.posts))
    ).all()
    return users

# 嵌套预加载
def get_posts_with_user_and_comments(db):
    from sqlalchemy.orm import selectinload
    posts = db.exec(
        select(Post).options(
            selectinload(Post.user),
            selectinload(Post.comments)
        )
    ).all()
    return posts

# 原生 SQL 执行
def get_users_with_native_sql(db):
    result = db.execute("SELECT * FROM user WHERE age > :age", {"age": 18})
    users = [User(**dict(row)) for row in result]
    return users

4.3 事务管理

def transaction_example():
    db = SessionLocal()
    try:
        # 开始事务
        with db.begin():
            # 创建用户
            user = User(name="Bob", email="bob@example.com")
            db.add(user)
            
            # 创建帖子
            post = Post(title="Hello", content="World", user_id=user.id)
            db.add(post)
            
            # 如果发生异常,事务会自动回滚
            # db.commit() 会在 with 块结束时自动调用
    except Exception as e:
        print(f"Transaction failed: {e}")
        # 事务会自动回滚
    finally:
        db.close()

# 嵌套事务
def nested_transaction_example():
    db = SessionLocal()
    try:
        # 开始外层事务
        with db.begin():
            # 创建用户
            user = User(name="Alice", email="alice@example.com")
            db.add(user)
            db.flush()  # 刷新到数据库,获取 ID
            
            # 创建保存点
            savepoint = db.begin_nested()
            try:
                # 尝试创建帖子
                post = Post(title="Test", content="Content", user_id=user.id)
                db.add(post)
                
                # 模拟错误
                raise ValueError("Simulated error")
                
                savepoint.commit()
            except Exception as e:
                # 回滚到保存点
                savepoint.rollback()
                print(f"Nested transaction failed: {e}")
                # 外层事务继续执行
            
            # 外层事务提交
    except Exception as e:
        print(f"Outer transaction failed: {e}")
    finally:
        db.close()

嵌套事务的使用场景

  • 复杂业务逻辑中需要部分回滚
  • 批量操作中需要对部分操作进行单独处理
  • 测试场景中需要隔离不同操作

注意事项

  • 嵌套事务只在支持保存点的数据库中有效
  • 嵌套事务的性能开销较大,应谨慎使用

4.4 批量操作与性能优化

# 批量获取多个 ID 的用户
def get_users_by_ids(db, user_ids: list[int]):
    """批量获取用户,使用 IN 查询"""
    users = db.exec(select(User).where(User.id.in_(user_ids))).all()
    return users

# 批量插入用户
def bulk_create_users(db, users: list[User]):
    """批量创建用户,使用 add_all 方法"""
    db.add_all(users)
    db.commit()
    for user in users:
        db.refresh(user)
    return users

# 批量更新用户
def bulk_update_users(db, user_ids: list[int], update_data: dict):
    """批量更新用户,使用 update 方法"""
    result = db.exec(
        select(User).where(User.id.in_(user_ids))
    ).all()
    
    for user in result:
        for key, value in update_data.items():
            setattr(user, key, value)
    
    db.commit()
    return result

# 批量操作示例
def batch_operations():
    db = SessionLocal()
    try:
        # 批量创建
        users = [
            User(name=f"User {i}", email=f"user{i}@example.com")
            for i in range(100)
        ]
        db.add_all(users)
        db.commit()
        print(f"Created {len(users)} users")
        
        # 批量更新
        users_to_update = db.exec(select(User).where(User.id < 50)).scalars().all()
        for user in users_to_update:
            user.name = f"Updated {user.name}"
        db.commit()
        print(f"Updated {len(users_to_update)} users")
        
    finally:
        db.close()

性能对比

  • 单次插入:每次插入一条记录,需要多次数据库往返
  • 批量插入:一次插入多条记录,减少数据库往返次数
  • 批量更新:一次更新多条记录,提高更新效率

最佳实践

  • 对于大量数据,使用批量操作
  • 批量操作的大小应根据数据库性能和网络延迟调整
  • 批量操作时应注意事务管理,确保数据一致性

4.5 并发控制策略

from sqlalchemy.exc import StaleDataError

def optimistic_locking():
    db = SessionLocal()
    try:
        # 获取用户
        user = db.exec(select(User).where(User.id == 1)).first()
        original_version = user.version
        
        # 模拟并发更新
        # 假设另一个进程已经更新了这个用户
        
        # 尝试更新
        user.name = "New Name"
        user.version += 1
        
        try:
            db.commit()
            print("Update successful")
        except StaleDataError:
            print("Concurrent update detected, retry needed")
            db.rollback()
            # 重试逻辑
            
    finally:
        db.close()

4.6 分页与限流

from fastapi import Query

@app.get("/users")
def get_users(
    db, 
    skip: int = Query(0, ge=0), 
    limit: int = Query(10, ge=1, le=100)
):
    users = get_users_paginated(db, skip=skip, limit=limit)
    return users

4.7 数据库迁移(Alembic 集成)

# 初始化 Alembic
alembic init alembic

# 修改 alembic.ini 配置数据库 URL
# sqlalchemy.url = sqlite:///./test.db

# 修改 alembic/env.py 导入 SQLModel
# from app.models import SQLModel
# target_metadata = SQLModel.metadata

# 生成迁移
alembic revision --autogenerate -m "Initial migration"

# 执行迁移
alembic upgrade head

# 回滚迁移
alembic downgrade -1

5. FastAPI 集成最佳实践

5.1 依赖注入设计

from fastapi import FastAPI, Depends
from sqlmodel import Session, create_engine, select
from app.models.user import User
from app.schemas.user import UserCreate, UserResponse

app = FastAPI()

# 创建引擎
engine = create_engine("sqlite:///./test.db")

# 会话依赖
def get_db():
    db = Session(engine)
    try:
        yield db
    finally:
        db.close()

# API 路由
@app.post("/users", response_model=UserResponse)
def create_user(user: UserCreate, db: Session = Depends(get_db)):
    db_user = User.model_validate(user)
    db.add(db_user)
    db.commit()
    db.refresh(db_user)
    return db_user

@app.get("/users", response_model=list[UserResponse])
def get_users(db: Session = Depends(get_db)):
    users = db.exec(select(User)).all()
    return users

5.2 API 路由与模型绑定

from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, select
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate, UserResponse
from app.dependencies import get_db

router = APIRouter(prefix="/users", tags=["users"])

@router.post("/", response_model=UserResponse)
def create_user(user: UserCreate, db: Session = Depends(get_db)):
    # 检查邮箱是否已存在
    existing_user = db.exec(select(User).where(User.email == user.email)).first()
    if existing_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    
    db_user = User.model_validate(user)
    db.add(db_user)
    db.commit()
    db.refresh(db_user)
    return db_user

@router.get("/{user_id}", response_model=UserResponse)
def get_user(user_id: int, db: Session = Depends(get_db)):
    user = db.exec(select(User).where(User.id == user_id)).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user

@router.put("/{user_id}", response_model=UserResponse)
def update_user(user_id: int, user: UserUpdate, db: Session = Depends(get_db)):
    db_user = db.exec(select(User).where(User.id == user_id)).first()
    if not db_user:
        raise HTTPException(status_code=404, detail="User not found")
    
    user_data = user.model_dump(exclude_unset=True)
    for key, value in user_data.items():
        setattr(db_user, key, value)
    
    db.add(db_user)
    db.commit()
    db.refresh(db_user)
    return db_user

@router.delete("/{user_id}")
def delete_user(user_id: int, db: Session = Depends(get_db)):
    db_user = db.exec(select(User).where(User.id == user_id)).first()
    if not db_user:
        raise HTTPException(status_code=404, detail="User not found")
    
    db.delete(db_user)
    db.commit()
    return {"message": "User deleted"}

5.3 响应模型配置

from sqlmodel import SQLModel
from typing import Optional
from datetime import datetime

class UserBase(SQLModel):
    name: str
    email: str

class UserCreate(UserBase):
    password: str

class UserUpdate(SQLModel):
    name: Optional[str] = None
    email: Optional[str] = None

class UserResponse(UserBase):
    id: int
    created_at: datetime
    
    class Config:
        from_attributes = True

5.4 安全性考虑

from passlib.context import CryptContext
from sqlmodel import SQLModel, Field

# 密码哈希
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

class User(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    name: str
    email: str = Field(unique=True, index=True)
    password_hash: str
    
    def set_password(self, password: str):
        self.password_hash = pwd_context.hash(password)
    
    def verify_password(self, password: str) -> bool:
        return pwd_context.verify(password, self.password_hash)

# 认证依赖
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from datetime import datetime, timedelta

import os
from dotenv import load_dotenv

load_dotenv()

SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key")  # 从环境变量获取,默认值仅用于开发
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token")

def create_access_token(data: dict, expires_delta: timedelta | None = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        email: str = payload.get("sub")
        if email is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = db.exec(select(User).where(User.email == email)).first()
    if user is None:
        raise credentials_exception
    return user

# 登录路由
@app.post("/token")
def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    user = db.exec(select(User).where(User.email == form_data.username)).first()
    if not user or not user.verify_password(form_data.password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.email}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}

# 受保护的路由
@app.get("/users/me", response_model=UserResponse)
def read_users_me(current_user: User = Depends(get_current_user)):
    return current_user

### OAuth2 完整实现示例

以下是一个完整的 OAuth2 实现示例,包括密码流和令牌验证:

```python
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlmodel import Session, select
from jose import JWTError, jwt
from datetime import datetime, timedelta
from typing import Optional

app = FastAPI()

# 配置
SECRET_KEY = "your-secret-key"  # 实际应用中应从环境变量获取
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# OAuth2 密码流
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token")

# 数据库依赖
def get_db():
    db = Session(engine)
    try:
        yield db
    finally:
        db.close()

# 创建访问令牌
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

# 获取当前用户
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        email: str = payload.get("sub")
        if email is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = db.exec(select(User).where(User.email == email)).first()
    if user is None:
        raise credentials_exception
    return user

# 登录路由
@app.post("/token")
def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    user = db.exec(select(User).where(User.email == form_data.username)).first()
    if not user or not user.verify_password(form_data.password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.email}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}

# 受保护的路由示例
@app.get("/protected")
def protected_route(current_user: User = Depends(get_current_user)):
    return {"message": "Hello from protected route", "user": current_user.name}

OAuth2 实现的最佳实践

  • 使用环境变量存储密钥
  • 设置合理的令牌过期时间
  • 实现令牌刷新机制
  • 添加令牌撤销功能
  • 记录令牌使用日志

6. 常见问题与解决方案

6.1 易错点分析

  • 忘记提交事务
    • 错误:修改数据后没有调用 db.commit()
    • 解决方案:使用 with db.begin() 上下文管理器,或手动调用 db.commit()
  • N+1 查询问题
    • 错误:循环中执行查询
    • 解决方案:使用 selectinload 或 joinedload 预加载关系
  • 类型注解错误
    • 错误:使用 Optional[int] 而不是 int | None
    • 解决方案:使用 Python 3.10+ 的联合类型注解 int | None
  • 关系字段配置错误
    • 错误:Relationship 配置不正确
    • 解决方案:确保 back_populates 参数与关联模型的字段名一致
  • 数据库连接泄漏
    • 错误:没有正确关闭数据库连接
    • 解决方案:使用依赖注入和 try-finally 确保连接关闭

6.2 疑难点解析

  • 异步操作
    • 问题:异步会话和同步会话的区别
    • 解决方案:使用 AsyncSession 和 create_async_engine 进行异步操作
  • 事务隔离级别
    • 问题:不同隔离级别的影响
    • 解决方案:根据业务需求选择合适的隔离级别
  • 复杂查询构建
    • 问题:构建复杂的 SQL 查询
    • 解决方案:使用 SQLAlchemy 的查询构建器,或执行原生 SQL
  • 数据库迁移
    • 问题:迁移失败或数据丢失
    • 解决方案:使用 Alembic 管理迁移,迁移前备份数据库
  • 性能优化
    • 问题:查询速度慢
    • 解决方案:使用索引,优化查询,合理使用缓存

6.3 性能问题定位

  1. 使用 SQL 日志
    engine = create_engine( "sqlite:///./test.db", echo=True # 打印 SQL 语句 )
  2. 使用性能分析工具
import cProfile
def profile_query():
    cProfile.runctx(
       "db.exec(select(User).options(selectinload(User.posts)).all())",
       globals(),
       locals()
   )

3. 监控数据库连接

   from sqlalchemy import event
   
   @event.listens_for(engine, "connect")
   def receive_connect(dbapi_connection, connection_record):
       print(f"Connection established: {connection_record}")
   
   @event.listens_for(engine, "checkout")
   def receive_checkout(dbapi_connection, connection_record, connection_proxy):
       print(f"Connection checked out: {connection_record}")

7. 性能优化与扩展

7.1 缓存策略

from sqlmodel import Session, select

# 内存缓存
# 注意:不应该缓存数据库会话,这里仅作为示例
# 实际应用中应该使用 Redis 等外部缓存
def get_user_by_id(db: Session, user_id: int):
    return db.exec(select(User).where(User.id == user_id)).first()

# Redis 缓存
import redis
import json

redis_client = redis.Redis(host="localhost", port=6379, db=0)

def get_user_from_cache(user_id: int, db: Session):
    # 尝试从缓存获取
    cached_user = redis_client.get(f"user:{user_id}")
    if cached_user:
        return json.loads(cached_user)
    
    # 从数据库获取
    user = db.exec(select(User).where(User.id == user_id)).first()
    if user:
        # 存入缓存
        redis_client.set(f"user:{user_id}", json.dumps(user.model_dump()), ex=3600)
    return user

# 缓存失效
def invalidate_user_cache(user_id: int):
    redis_client.delete(f"user:{user_id}")

7.2 索引设计与优化

from sqlmodel import SQLModel, Field, Index

class User(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    name: str = Field(index=True)
    email: str = Field(unique=True, index=True)
    age: int | None = Field(index=True)
    
    class Config:
        indexes = [
            Index("idx_name_age", "name", "age"),
        ]

class Post(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    title: str = Field(index=True)
    content: str
    user_id: int = Field(foreign_key="user.id", index=True)

7.3 查询优化技巧

from sqlmodel import select
from sqlalchemy import func

# 避免 N+1 查询
def get_users_with_posts_optimized(db):
    from sqlalchemy.orm import selectinload
    users = db.exec(
        select(User).options(selectinload(User.posts))
    ).all()
    return users

# 使用 join 优化查询
def get_posts_with_user(db):
    posts = db.exec(
        select(Post, User)
        .join(User, Post.user_id == User.id)
    ).all()
    return posts

# 聚合查询优化
def get_user_statistics(db):
    result = db.exec(
        select(
            func.count(User.id).label("total_users"),
            func.avg(User.age).label("average_age"),
            func.min(User.age).label("min_age"),
            func.max(User.age).label("max_age")
        )
    ).first()
    return result

7.4 连接池配置

from sqlmodel import create_engine

# 连接池配置
engine = create_engine(
    "postgresql://user:password@localhost:5432/dbname",
    pool_size=20,          # 连接池大小
    max_overflow=10,        # 最大溢出连接数
    pool_recycle=3600,      # 连接回收时间
    pool_pre_ping=True,     # 连接前 ping 数据库
    pool_use_lifo=True      # 使用 LIFO 策略
)

7.5 扩展开发与自定义功能

from sqlmodel import SQLModel, Field
from sqlalchemy import Column, String
from sqlalchemy.dialects.postgresql import JSONB

# 自定义字段
class JSONBField(Field):
    sa_type = JSONB

class User(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    name: str
    preferences: dict | None = JSONBField(default=None)

# 自定义模型基类
from datetime import datetime

class TimestampedModel(SQLModel):
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

class User(TimestampedModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    name: str
    email: str

8. 源码分析与高级特性

8.1 核心源码结构

SQLModel 的核心源码结构如下:

  • sqlmodel/init.py:主要导出
  • sqlmodel/main.py:核心模型类和函数
  • sqlmodel/fields.py:字段定义
  • sqlmodel/relationship.py:关系处理
  • sqlmodel/sql.py:SQL 相关功能
  • sqlmodel/utils.py:工具函数

8.2 继承关系分析

8.3 执行流程图解

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐