《计算机视觉:从入门到精通》技术手册 第11章 Vision Transform er(ViT)及其变体
仅在深层使用自注意力;Vision Transformer(ViT)将自然语言处理中的Transformer架构迁移至计算机视觉领域,通过将图像分割为序列化的图像块(patch)并应用自注意力机制,实现了与卷积神经网络(CNN)相竞争甚至超越的性能。通过采用类似的训练配方(EMA、Stochastic Depth、Mixup、Cutmix、RandAugment、随机擦除)和架构设计(大核卷积、更
目录
第11章 Vision Transform er(ViT)及其变体
11.1.1 图像分块(Patch Embedding)与线性投影
11.1.2 类令牌(CLS Token)与全局平均池化对比
11.1.3 大规模预训练策略:ImageNet-21K, JFT-300M
11.2.1 Swin Transformer:移位窗口与层次化特征
11.2.2 Pyramid Vision Transformer(PVT)
11.2.4 ConvNeXt:CNN对Transformer的反击
11.3.1 MobileViT:轻量级视觉Transformer
11.3.2 EfficientFormer与EdgeViT
第11章 Vision Transform er(ViT)及其变体
11.1 ViT基础架构
Vision Transformer(ViT)将自然语言处理中的Transformer架构迁移至计算机视觉领域,通过将图像分割为序列化的图像块(patch)并应用自注意力机制,实现了与卷积神经网络(CNN)相竞争甚至超越的性能。ViT的核心创新在于摒弃了传统CNN中固有的归纳偏置,完全依赖注意力机制学习空间关系,这一设计范式在充足数据量和计算资源的条件下展现出强大的表征能力。
11.1.1 图像分块(Patch Embedding)与线性投影

在实际实现中,图像分块与线性投影通常通过卷积操作高效完成。使用步长等于卷积核大小的二维卷积层,可以在单次前向传播中同时完成分块和投影操作,其计算效率显著高于显式的循环展平操作。
实例:Patch Embedding的完整实现
以下代码展示了Patch Embedding模块的PyTorch实现,包含高效的卷积式分块、位置编码初始化以及可选的Dropout正则化:
"""
Script: vit_patch_embedding.py
Description: Vision Transformer Patch Embedding模块实现
Usage:
python vit_patch_embedding.py --image_size 224 --patch_size 16 --embed_dim 768
Features:
- 使用Conv2d实现高效图像分块与线性投影
- 支持可学习位置编码与CLS Token
- 包含完整的可视化与验证功能
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from typing import Tuple, Optional
import argparse
class PatchEmbedding(nn.Module):
"""
ViT Patch Embedding模块
将输入图像分割为固定大小的图像块,并通过线性投影映射到嵌入空间。
使用卷积操作实现高效的分块与投影,时间复杂度为O(HW)。
"""
def __init__(
self,
image_size: int = 224,
patch_size: int = 16,
in_channels: int = 3,
embed_dim: int = 768,
dropout: float = 0.0,
add_cls_token: bool = True
):
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size) ** 2
self.embed_dim = embed_dim
self.add_cls_token = add_cls_token
# 使用Conv2d实现高效分块: 核大小=步长=patch_size
# 输入: (B, C, H, W) -> 输出: (B, embed_dim, H/P, W/P)
self.proj = nn.Conv2d(
in_channels=in_channels,
out_channels=embed_dim,
kernel_size=patch_size,
stride=patch_size,
padding=0,
bias=True
)
# 初始化:使用截断正态分布,标准差与patch_size相关
nn.init.trunc_normal_(self.proj.weight, std=0.02)
if self.proj.bias is not None:
nn.init.constant_(self.proj.bias, 0)
# CLS Token: 可学习的全局聚合表示
if add_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
nn.init.trunc_normal_(self.cls_token, std=0.02)
self.num_tokens = self.num_patches + 1
else:
self.num_tokens = self.num_patches
# 位置编码: 可学习参数,捕获空间位置信息
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
nn.init.trunc_normal_(self.pos_embed, std=0.02)
# Dropout正则化
self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()
# 预计算位置编码的插值网格,用于动态分辨率
self.register_buffer('pos_embed_grid', self._create_pos_embed_grid(), persistent=False)
def _create_pos_embed_grid(self) -> torch.Tensor:
"""创建用于2D插值的位置编码网格"""
if not self.add_cls_token:
return torch.zeros(1, 1, 1) # 占位符
# 分离CLS token与patch位置编码
pos_embed_patch = self.pos_embed[:, 1:, :].reshape(
1, self.image_size // self.patch_size, self.image_size // self.patch_size, self.embed_dim
).permute(0, 3, 1, 2) # (1, D, H', W')
return pos_embed_patch
def interpolate_pos_encoding(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
"""
支持动态分辨率的位置编码插值
当输入图像尺寸与预训练尺寸不同时,使用双线性插值调整位置编码。
这是处理可变分辨率输入的关键技术。
"""
npatches = x.shape[1] - 1 if self.add_cls_token else x.shape[1]
N = self.pos_embed.shape[1] - 1 if self.add_cls_token else self.pos_embed.shape[1]
if npatches == N and h == w == self.image_size:
return self.pos_embed
if self.add_cls_token:
class_pos_embed = self.pos_embed[:, 0:1, :]
patch_pos_embed = self.pos_embed[:, 1:, :]
else:
patch_pos_embed = self.pos_embed
# 重塑为2D特征图进行插值
dim = x.shape[-1]
h0 = h // self.patch_size
w0 = w // self.patch_size
patch_pos_embed = patch_pos_embed.reshape(
1, self.image_size // self.patch_size, self.image_size // self.patch_size, dim
).permute(0, 3, 1, 2)
# 双线性插值
patch_pos_embed = F.interpolate(
patch_pos_embed, size=(h0, w0), mode='bilinear', align_corners=False
).permute(0, 2, 3, 1).reshape(1, -1, dim)
if self.add_cls_token:
return torch.cat([class_pos_embed, patch_pos_embed], dim=1)
return patch_pos_embed
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
前向传播
Args:
x: 输入图像张量,形状 (B, C, H, W)
Returns:
x: 嵌入序列,形状 (B, N, D) 其中N为token数量
patch_indices: 可选,返回每个patch的原始空间位置索引
"""
B, C, H, W = x.shape
# 验证输入尺寸可被patch_size整除
if H % self.patch_size != 0 or W % self.patch_size != 0:
raise ValueError(f"Image dimensions ({H},{W}) must be divisible by patch_size {self.patch_size}")
# 卷积分块与投影: (B, C, H, W) -> (B, embed_dim, H/P, W/P)
x = self.proj(x)
# 展平为序列: (B, embed_dim, H/P, W/P) -> (B, embed_dim, N) -> (B, N, embed_dim)
x = x.flatten(2).transpose(1, 2) # (B, N, D)
# 添加CLS Token
if self.add_cls_token:
cls_tokens = self.cls_token.expand(B, -1, -1) # (B, 1, D)
x = torch.cat([cls_tokens, x], dim=1) # (B, N+1, D)
# 添加位置编码(支持动态分辨率)
x = x + self.interpolate_pos_encoding(x, H, W)
x = self.dropout(x)
return x
class PatchEmbeddingVisualizer:
"""Patch Embedding可视化工具"""
@staticmethod
def visualize_patches(image: torch.Tensor, patch_size: int, save_path: Optional[str] = None):
"""
可视化图像分块过程
Args:
image: 输入图像 (C, H, W) 或 (H, W, C)
patch_size: 分块大小
"""
if image.dim() == 3 and image.shape[0] in [1, 3]:
image = image.permute(1, 2, 0).cpu().numpy()
else:
image = image.cpu().numpy()
H, W, C = image.shape
num_patches_h = H // patch_size
num_patches_w = W // patch_size
fig, axes = plt.subplots(num_patches_h, num_patches_w, figsize=(12, 12))
fig.suptitle(f'Image Patches Visualization ({patch_size}x{patch_size})', fontsize=16)
for i in range(num_patches_h):
for j in range(num_patches_w):
patch = image[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size, :]
if C == 1:
axes[i, j].imshow(patch.squeeze(), cmap='gray')
else:
# 反归一化用于显示
patch = (patch - patch.min()) / (patch.max() - patch.min() + 1e-8)
axes[i, j].imshow(patch)
axes[i, j].axis('off')
axes[i, j].set_title(f'({i},{j})', fontsize=8)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.show()
@staticmethod
def visualize_pos_embed(pos_embed: torch.Tensor, patch_size: int, grid_size: int):
"""可视化位置编码的余弦相似度矩阵"""
# 移除CLS token
if pos_embed.shape[1] > grid_size * grid_size:
pos_embed = pos_embed[:, 1:, :]
pos_embed = pos_embed.squeeze(0) # (N, D)
# 计算余弦相似度
sim_matrix = F.cosine_similarity(
pos_embed.unsqueeze(1), pos_embed.unsqueeze(0), dim=-1
).cpu().numpy()
plt.figure(figsize=(10, 8))
plt.imshow(sim_matrix, cmap='viridis', aspect='auto')
plt.colorbar(label='Cosine Similarity')
plt.title('Position Embedding Cosine Similarity Matrix')
plt.xlabel('Patch Index')
plt.ylabel('Patch Index')
plt.tight_layout()
plt.show()
def benchmark_patch_embedding():
"""性能基准测试"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
configs = [
{'image_size': 224, 'patch_size': 16, 'embed_dim': 768}, # ViT-Base
{'image_size': 224, 'patch_size': 32, 'embed_dim': 768}, # 大patch
{'image_size': 384, 'patch_size': 16, 'embed_dim': 768}, # 高分辨率
]
for config in configs:
print(f"\nConfig: {config}")
model = PatchEmbedding(**config).to(device)
x = torch.randn(32, 3, config['image_size'], config['image_size']).to(device)
# 预热
for _ in range(10):
_ = model(x)
# 同步GPU
if device.type == 'cuda':
torch.cuda.synchronize()
import time
start = time.time()
for _ in range(100):
out = model(x)
if device.type == 'cuda':
torch.cuda.synchronize()
elapsed = (time.time() - start) / 100 * 1000 # ms
print(f" Output shape: {out.shape}")
print(f" Average time: {elapsed:.3f} ms")
print(f" Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='ViT Patch Embedding Demo')
parser.add_argument('--image_size', type=int, default=224)
parser.add_argument('--patch_size', type=int, default=16)
parser.add_argument('--embed_dim', type=int, default=768)
parser.add_argument('--visualize', action='store_true', help='启用可视化')
args = parser.parse_args()
# 单元测试
model = PatchEmbedding(
image_size=args.image_size,
patch_size=args.patch_size,
embed_dim=args.embed_dim
)
# 测试前向传播
x = torch.randn(2, 3, args.image_size, args.image_size)
output = model(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of patches: {model.num_patches}")
if args.visualize:
# 创建随机彩色图像进行可视化
dummy_image = torch.randn(3, args.image_size, args.image_size)
PatchEmbeddingVisualizer.visualize_patches(dummy_image, args.patch_size)
PatchEmbeddingVisualizer.visualize_pos_embed(
model.pos_embed, args.patch_size, args.image_size // args.patch_size
)
# 运行性能测试
benchmark_patch_embedding()
11.1.2 类令牌(CLS Token)与全局平均池化对比

其中查询(Query)来自CLS Token,键(Key)和值(Value)来自所有Token,这使得CLS Token能够加权聚合全局信息。
相比之下,全局平均池化(Global Average Pooling,GAP)直接对所有图像块Token的空间维度进行平均,生成固定长度的全局特征向量。GAP的计算成本更低,但缺乏CLS Token的动态聚合能力。
实验研究表明,在充足数据量下,CLS Token与GAP的性能差异不显著;但在中等规模数据集上,CLS Token通常表现更优,因其提供了额外的可学习参数用于优化。然而,GAP在密集预测任务(如语义分割)中具有天然优势,因其保留了空间对应关系。
实例:CLS Token与GAP的对比实现
"""
Script: cls_vs_gap.py
Description: 对比CLS Token与全局平均池化的实现与性能
Usage:
python cls_vs_gap.py --method cls --dataset cifar10 --epochs 10
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from typing import Literal
import argparse
import time
class PoolingClassifier(nn.Module):
"""
支持CLS Token与全局平均池化的分类头
架构特点:
- CLS Token: 使用可学习token聚合全局信息
- GAP: 空间平均池化,计算高效
- 可选:混合策略,结合两者优势
"""
def __init__(
self,
embed_dim: int,
num_classes: int,
pooling_type: Literal['cls', 'gap', 'both'] = 'cls',
dropout: float = 0.1
):
super().__init__()
self.pooling_type = pooling_type
self.embed_dim = embed_dim
# 层归一化(预归一化架构)
self.norm = nn.LayerNorm(embed_dim)
# 分类头
if pooling_type == 'cls':
head_dim = embed_dim
elif pooling_type == 'gap':
head_dim = embed_dim
else: # both
head_dim = embed_dim * 2
self.head = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(head_dim, num_classes)
)
# 初始化
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: 输入特征,形状 (B, N, D)
若为CLS模式,x[:, 0]为CLS Token
Returns:
logits: 分类 logits (B, num_classes)
"""
if self.pooling_type == 'cls':
# 提取CLS Token (第一个token)
cls_token = x[:, 0] # (B, D)
x = self.norm(cls_token)
elif self.pooling_type == 'gap':
# 全局平均池化(排除CLS Token如果存在)
if x.shape[1] > 1:
x = x[:, 1:] if x.shape[1] > 196 else x # 假设196个patch
x = x.mean(dim=1) # (B, D)
x = self.norm(x)
else: # both
cls_token = x[:, 0]
gap = x[:, 1:].mean(dim=1) if x.shape[1] > 1 else x.mean(dim=1)
x = torch.cat([self.norm(cls_token), self.norm(gap)], dim=-1)
return self.head(x)
class SimpleViT(nn.Module):
"""简化ViT用于对比实验"""
def __init__(
self,
image_size: int = 32,
patch_size: int = 4,
in_channels: int = 3,
embed_dim: int = 256,
depth: int = 6,
num_heads: int = 8,
mlp_ratio: float = 4.0,
num_classes: int = 10,
pooling_type: Literal['cls', 'gap'] = 'cls',
dropout: float = 0.1
):
super().__init__()
self.pooling_type = pooling_type
# Patch Embedding
self.patch_embed = nn.Conv2d(
in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
)
num_patches = (image_size // patch_size) ** 2
# CLS Token(仅在需要时)
if pooling_type == 'cls':
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
nn.init.trunc_normal_(self.cls_token, std=0.02)
num_tokens = num_patches + 1
else:
num_tokens = num_patches
# 位置编码
self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
nn.init.trunc_normal_(self.pos_embed, std=0.02)
# Transformer编码器
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=num_heads,
dim_feedforward=int(embed_dim * mlp_ratio),
dropout=dropout,
activation='gelu',
batch_first=True,
norm_first=True # 预归一化
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
# 分类头
self.classifier = PoolingClassifier(embed_dim, num_classes, pooling_type, dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B = x.shape[0]
# Patch embedding
x = self.patch_embed(x).flatten(2).transpose(1, 2) # (B, N, D)
# 添加CLS Token
if self.pooling_type == 'cls':
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat([cls_tokens, x], dim=1)
# 添加位置编码
x = x + self.pos_embed
# Transformer编码
x = self.transformer(x)
# 分类
return self.classifier(x)
class Trainer:
"""训练与评估框架"""
def __init__(self, model, device, lr=1e-3):
self.model = model.to(device)
self.device = device
self.criterion = nn.CrossEntropyLoss()
self.optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=200)
def train_epoch(self, dataloader):
self.model.train()
total_loss, correct, total = 0, 0, 0
for batch_idx, (data, target) in enumerate(dataloader):
data, target = data.to(self.device), target.to(self.device)
self.optimizer.zero_grad()
output = self.model(data)
loss = self.criterion(output, target)
loss.backward()
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
total_loss += loss.item()
_, predicted = output.max(1)
total += target.size(0)
correct += predicted.eq(target).sum().item()
return total_loss / len(dataloader), 100. * correct / total
def evaluate(self, dataloader):
self.model.eval()
total_loss, correct, total = 0, 0, 0
with torch.no_grad():
for data, target in dataloader:
data, target = data.to(self.device), target.to(self.device)
output = self.model(data)
total_loss += self.criterion(output, target).item()
_, predicted = output.max(1)
total += target.size(0)
correct += predicted.eq(target).sum().item()
return total_loss / len(dataloader), 100. * correct / total
def run_comparison(args):
"""运行CLS vs GAP对比实验"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# 数据准备
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform_train
)
testset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform_test
)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
# 创建模型
model = SimpleViT(
image_size=32,
patch_size=4,
embed_dim=256,
depth=6,
num_heads=8,
num_classes=10,
pooling_type=args.method
)
# 统计参数量
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Method: {args.method.upper()}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
# 训练
trainer = Trainer(model, device, lr=args.lr)
best_acc = 0
for epoch in range(args.epochs):
start_time = time.time()
train_loss, train_acc = trainer.train_epoch(trainloader)
test_loss, test_acc = trainer.evaluate(testloader)
trainer.scheduler.step()
epoch_time = time.time() - start_time
print(f"Epoch {epoch+1}/{args.epochs} | "
f"Train Loss: {train_loss:.3f}, Acc: {train_acc:.2f}% | "
f"Test Loss: {test_loss:.3f}, Acc: {test_acc:.2f}% | "
f"Time: {epoch_time:.1f}s")
if test_acc > best_acc:
best_acc = test_acc
print(f"\nBest Test Accuracy: {best_acc:.2f}%")
return best_acc
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--method', type=str, default='cls', choices=['cls', 'gap'])
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--lr', type=float, default=1e-3)
args = parser.parse_args()
run_comparison(args)
11.1.3 大规模预训练策略:ImageNet-21K, JFT-300M
ViT的性能高度依赖于预训练数据的规模。原始ViT研究表明,当在中等规模数据集(如ImageNet-1K,含128万张图像)上从头训练时,ViT的准确率低于同等计算量的ResNet。然而,当在大型数据集上进行预训练后,ViT展现出卓越的迁移学习能力。
ImageNet-21K是ImageNet-1K的超集,包含约1400万张图像和21,841个类别。在该数据集上预训练的ViT模型,通过微调即可在下游任务上取得优异性能。预训练策略通常采用监督学习,使用交叉熵损失函数,并配合强数据增强(RandAugment、Mixup、Cutmix)和随机深度(Stochastic Depth)正则化。
JFT-300M是一个更大规模的私有数据集,包含约3亿张图像和18,291个类别。在该数据集上预训练的ViT-Huge模型(632M参数)在ImageNet-1K上达到了88.55%的Top-1准确率,超越了当时的所有卷积网络。这一结果验证了Transformer架构在视觉任务中的可扩展性:随着数据量和模型规模的增加,性能持续提升,未出现饱和迹象。
预训练-微调范式中的关键超参数包括:基础学习率、层衰减率(Layer-wise Learning Rate Decay)、微调epoch数以及是否冻结部分层。研究表明,使用较小的层衰减率(如0.75)和较长的微调周期(如10000 steps)对于大型ViT模型至关重要。
实例:分层学习率衰减与微调策略
"""
Script: vit_finetuning.py
Description: ViT模型的分层学习率衰减与高效微调
Usage:
python vit_finetuning.py --pretrained_path vit_base_imagenet21k.pth --dataset custom_data
"""
import torch
import torch.nn as nn
import math
from typing import List, Dict, Optional
from collections import defaultdict
class LayerDecayOptimizer:
"""
分层学习率衰减优化器
为不同层分配不同的学习率,通常浅层使用较小学习率,深层使用较大学习率。
这对于微调预训练的大型Transformer模型至关重要。
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
model: nn.Module,
layer_decay: float = 0.75,
base_lr: float = 1e-3,
num_layers: int = 12,
patch_embed_lr_scale: float = 0.1,
cls_token_lr_scale: float = 1.0
):
self.optimizer = optimizer
self.layer_decay = layer_decay
self.base_lr = base_lr
self.num_layers = num_layers
# 按层分组参数
self.param_groups = self._group_parameters(model)
# 应用分层学习率
self._set_param_groups()
def _group_parameters(self, model: nn.Module) -> Dict[str, List[nn.Parameter]]:
"""将模型参数按层分组"""
groups = defaultdict(list)
# Patch Embedding层
if hasattr(model, 'patch_embed'):
groups['patch_embed'] = list(model.patch_embed.parameters())
# CLS Token和位置编码
if hasattr(model, 'cls_token'):
groups['cls_token'] = [model.cls_token]
if hasattr(model, 'pos_embed'):
groups['pos_embed'] = [model.pos_embed]
# Transformer层(假设命名为blocks或encoder.layers)
blocks = None
if hasattr(model, 'blocks'):
blocks = model.blocks
elif hasattr(model, 'encoder') and hasattr(model.encoder, 'layers'):
blocks = model.encoder.layers
if blocks is not None:
for i, block in enumerate(blocks):
groups[f'layer_{i}'] = list(block.parameters())
# 分类头(最高学习率)
if hasattr(model, 'head'):
groups['head'] = list(model.head.parameters())
elif hasattr(model, 'classifier'):
groups['head'] = list(model.classifier.parameters())
return groups
def _set_param_groups(self):
"""设置分层学习率"""
param_groups = []
for name, params in self.param_groups.items():
if not params:
continue
# 计算该层的学习率缩放
if name == 'patch_embed':
scale = self.patch_embed_lr_scale
elif name == 'head':
scale = 1.0 # 分类头使用基础学习率
elif name == 'cls_token' or name == 'pos_embed':
scale = self.cls_token_lr_scale
elif name.startswith('layer_'):
# 层越深,学习率越大(或越小,取决于策略)
layer_id = int(name.split('_')[1])
# 方案1:从浅层到深层递增
scale = self.layer_decay ** (self.num_layers - layer_id - 1)
# 方案2:从浅层到深层递减(更保守)
# scale = self.layer_decay ** layer_id
else:
scale = 1.0
lr = self.base_lr * scale
param_groups.append({
'params': params,
'lr': lr,
'name': name,
'lr_scale': scale
})
print(f"Layer: {name:15s} | LR: {lr:.2e} | Scale: {scale:.3f}")
# 更新优化器的参数组
self.optimizer.param_groups = param_groups
def step(self, closure=None):
self.optimizer.step(closure)
def zero_grad(self, set_to_none=False):
self.optimizer.zero_grad(set_to_none)
def state_dict(self):
return self.optimizer.state_dict()
def load_state_dict(self, state_dict):
self.optimizer.load_state_dict(state_dict)
class WarmupCosineScheduler:
"""
带预热的余弦退火学习率调度器
训练初期使用线性预热避免训练不稳定,随后使用余弦退火收敛至最小学习率。
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_epochs: int,
total_epochs: int,
min_lr: float = 1e-6,
warmup_lr_init: float = 1e-6
):
self.optimizer = optimizer
self.warmup_epochs = warmup_epochs
self.total_epochs = total_epochs
self.min_lr = min_lr
self.warmup_lr_init = warmup_lr_init
self.base_lrs = [group['lr'] for group in optimizer.param_groups]
def step(self, epoch: int):
if epoch < self.warmup_epochs:
# 线性预热
alpha = epoch / self.warmup_epochs
factor = self.warmup_lr_init + alpha * (1 - self.warmup_lr_init)
else:
# 余弦退火
progress = (epoch - self.warmup_epochs) / (self.total_epochs - self.warmup_epochs)
factor = self.min_lr + 0.5 * (1 - self.min_lr) * (1 + math.cos(math.pi * progress))
for i, group in enumerate(self.optimizer.param_groups):
group['lr'] = self.base_lrs[i] * factor
return self.optimizer.param_groups[0]['lr']
class MixupAugmentation:
"""
Mixup与CutMix数据增强
Mixup: 将两张图像按一定比例线性插值
CutMix: 将一张图像的局部区域替换为另一张图像的对应区域
"""
def __init__(self, mixup_alpha: float = 0.8, cutmix_alpha: float = 1.0, prob: float = 0.5):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.prob = prob
def __call__(self, x: torch.Tensor, y: torch.Tensor) -> tuple:
if torch.rand(1).item() > self.prob:
return x, y
if torch.rand(1).item() < 0.5:
return self._mixup(x, y)
else:
return self._cutmix(x, y)
def _mixup(self, x: torch.Tensor, y: torch.Tensor):
"""Mixup增强"""
lam = torch.distributions.Beta(self.mixup_alpha, self.mixup_alpha).sample().item()
batch_size = x.size(0)
index = torch.randperm(batch_size).to(x.device)
mixed_x = lam * x + (1 - lam) * x[index]
y_a, y_b = y, y[index]
return mixed_x, (y_a, y_b, lam)
def _cutmix(self, x: torch.Tensor, y: torch.Tensor):
"""CutMix增强"""
lam = torch.distributions.Beta(self.cutmix_alpha, self.cutmix_alpha).sample().item()
batch_size = x.size(0)
index = torch.randperm(batch_size).to(x.device)
# 随机生成裁剪区域
H, W = x.shape[2], x.shape[3]
cut_ratio = math.sqrt(1 - lam)
cut_h, cut_w = int(H * cut_ratio), int(W * cut_ratio)
cx, cy = torch.randint(H, (1,)).item(), torch.randint(W, (1,)).item()
bbx1 = max(0, cx - cut_h // 2)
bby1 = max(0, cy - cut_w // 2)
bbx2 = min(H, cx + cut_h // 2)
bby2 = min(W, cy + cut_w // 2)
x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (H * W))
return x, (y, y[index], lam)
class FineTuningPipeline:
"""完整的微调流程"""
def __init__(
self,
model: nn.Module,
num_classes: int,
pretrained_path: Optional[str] = None,
freeze_patch_embed: bool = True,
layer_decay: float = 0.75
):
self.model = model
self.num_classes = num_classes
# 加载预训练权重
if pretrained_path:
self._load_pretrained(pretrained_path)
# 调整分类头
self._adapt_head(num_classes)
# 冻结策略
if freeze_patch_embed:
self._freeze_patch_embed()
self.layer_decay = layer_decay
def _load_pretrained(self, path: str):
"""加载预训练权重,处理尺寸不匹配"""
checkpoint = torch.load(path, map_location='cpu')
state_dict = checkpoint.get('model', checkpoint)
# 过滤掉分类头(尺寸不匹配)
model_state = self.model.state_dict()
filtered_state = {}
for k, v in state_dict.items():
if k in model_state and v.shape == model_state[k].shape:
filtered_state[k] = v
else:
print(f"Skipping {k}: shape mismatch or not in model")
self.model.load_state_dict(filtered_state, strict=False)
print(f"Loaded {len(filtered_state)}/{len(state_dict)} parameters from checkpoint")
def _adapt_head(self, num_classes: int):
"""调整分类头以适应新的类别数"""
if hasattr(self.model, 'head'):
in_features = self.model.head.in_features
self.model.head = nn.Linear(in_features, num_classes)
elif hasattr(self.model, 'classifier'):
in_features = self.model.classifier.in_features
self.model.classifier = nn.Linear(in_features, num_classes)
# 初始化新的分类头
nn.init.trunc_normal_(self.model.head.weight, std=0.02)
nn.init.constant_(self.model.head.bias, 0)
def _freeze_patch_embed(self):
"""冻结Patch Embedding层(通常预训练充分)"""
if hasattr(self.model, 'patch_embed'):
for param in self.model.patch_embed.parameters():
param.requires_grad = False
print("Frozen patch embedding parameters")
def get_optimizer(self, base_lr: float = 1e-3, weight_decay: float = 0.05):
"""获取配置好的优化器"""
base_optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=base_lr,
weight_decay=weight_decay,
betas=(0.9, 0.999)
)
return LayerDecayOptimizer(
base_optimizer,
self.model,
layer_decay=self.layer_decay,
base_lr=base_lr
)
# 使用示例
if __name__ == "__main__":
# 假设有一个预训练的ViT模型
from torchvision.models import vit_b_16
model = vit_b_16(pretrained=False)
pipeline = FineTuningPipeline(
model=model,
num_classes=100, # 新的分类任务
pretrained_path='vit_base_imagenet21k.pth',
layer_decay=0.75
)
optimizer = pipeline.get_optimizer(base_lr=1e-3)
scheduler = WarmupCosineScheduler(optimizer.optimizer, warmup_epochs=5, total_epochs=100)
print("Fine-tuning pipeline initialized successfully")
11.1.4 与CNN的对比:归纳偏置与数据效率
卷积神经网络(CNN)通过局部连接、权重共享和平移不变性等归纳偏置,在中小规模数据集上表现出色。这些偏置使CNN能够高效地从有限样本中学习局部特征层次结构。相比之下,ViT缺乏这些显式的归纳偏置,必须完全从数据中学习空间关系,这导致其在数据量不足时容易过拟合。
归纳偏置的差异体现在多个层面。CNN的局部感受野假设邻近像素具有强相关性,这一假设在图像数据中普遍成立。ViT的自注意力机制虽然理论上可以建模任意距离的关系,但需要大量数据才能学习到有效的局部聚焦模式。此外,CNN的层次化结构通过池化操作逐步降低空间分辨率,自然地构建了多尺度特征表示;而原始ViT保持固定的空间分辨率,缺乏显式的多尺度处理能力。
数据效率方面,研究表明ViT在ImageNet-1K规模(约130万张图像)的数据集上,其样本效率显著低于ResNet。然而,当数据量扩大至ImageNet-21K(1400万张)或JFT-300M(3亿张)时,ViT不仅超越了CNN,还展现出更好的规模扩展性。这一现象表明,ViT的容量上限更高,能够充分利用大规模数据中的模式。
计算效率上,ViT的自注意力机制具有 O(N2) 的复杂度,其中 N 为图像块数量,这使得处理高分辨率图像时计算成本急剧上升。CNN的卷积操作虽然参数量较大,但其计算复杂度与图像像素数成线性关系,在特定硬件上具有更好的内存局部性。
实例:CNN与ViT特征可视化对比
"""
Script: cnn_vs_vit_comparison.py
Description: CNN与ViT的特征图、注意力图可视化对比
Usage:
python cnn_vs_vit_comparison.py --model_type vit --layer 6 --image_path sample.jpg
"""
import torch
import torch.nn as nn
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import cv2
from typing import Tuple, List
import argparse
class FeatureExtractor:
"""通用特征提取器,支持CNN和ViT"""
def __init__(self, model: nn.Module, model_type: str):
self.model = model
self.model_type = model_type.lower()
self.features = {}
self.attentions = {}
self._register_hooks()
def _register_hooks(self):
"""注册前向传播钩子以捕获中间特征"""
if self.model_type == 'cnn':
# 为CNN的每个卷积层注册钩子
def get_hook(name):
def hook(module, input, output):
self.features[name] = output.detach()
return hook
for name, module in self.model.named_modules():
if isinstance(module, nn.Conv2d):
module.register_forward_hook(get_hook(f'conv_{name}'))
elif self.model_type == 'vit':
# 为ViT的每个Transformer块注册钩子
def get_attention_hook(layer_idx):
def hook(module, input, output):
# 捕获自注意力权重
if hasattr(module, 'self_attention'):
# 标准ViT结构
pass
# 存储特征
self.features[f'block_{layer_idx}'] = output.detach()
return hook
# 注册注意力捕获(针对timm实现的ViT)
def get_attn_map_hook(layer_idx):
def hook(module, input, output):
# 假设output包含注意力图
if isinstance(output, tuple) and len(output) > 1:
self.attentions[f'layer_{layer_idx}'] = output[1].detach()
self.features[f'layer_{layer_idx}'] = output[0] if isinstance(output, tuple) else output
return hook
if hasattr(self.model, 'blocks'):
for idx, block in enumerate(self.model.blocks):
block.register_forward_hook(get_attn_map_hook(idx))
def extract(self, x: torch.Tensor) -> dict:
"""提取特征"""
self.model.eval()
with torch.no_grad():
_ = self.model(x)
return self.features
def get_attention_maps(self) -> dict:
"""获取注意力图(仅ViT)"""
return self.attentions
class AttentionVisualizer:
"""ViT注意力图可视化工具"""
@staticmethod
def rollout_attention(attentions: List[torch.Tensor], discard_ratio: float = 0.9) -> torch.Tensor:
"""
注意力展开(Attention Rollout)
通过递归乘法传播注意力权重,可视化输入图像中对分类决策的贡献区域。
公式: A_rollout = A_l * A_{l-1} * ... * A_1
"""
result = torch.eye(attentions[0].size(-1)).to(attentions[0].device)
for attention in attentions:
# 平均多头注意力
attention_heads_fused = attention.mean(axis=1)
# 添加残差连接并归一化
I = torch.eye(attention_heads_fused.size(-1)).to(attention_heads_fused.device)
a = (attention_heads_fused + I) / 2.0
a = a / a.sum(dim=-1, keepdim=True)
result = torch.matmul(a, result)
# 提取CLS token对其他token的注意力
mask = result[0, 0, 1:] # 排除CLS token自身
return mask
@staticmethod
def visualize_attention(
image: np.ndarray,
attention_map: torch.Tensor,
patch_size: int = 16,
save_path: Optional[str] = None
):
"""
将注意力图叠加到原图上
Args:
image: 原始图像 (H, W, 3)
attention_map: 注意力权重 (num_patches,)
patch_size: 图像块大小
"""
H, W, _ = image.shape
nh = H // patch_size
nw = W // patch_size
# 重塑为2D
attention_map = attention_map.reshape(nh, nw).cpu().numpy()
# 上采样到原图尺寸
attention_map = cv2.resize(attention_map, (W, H))
# 归一化
attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
# 应用颜色映射
heatmap = plt.cm.jet(attention_map)[:, :, :3] # RGB
# 叠加
overlay = image * 0.5 + heatmap * 255 * 0.5
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(image.astype(np.uint8))
axes[0].set_title('Original Image')
axes[0].axis('off')
axes[1].imshow(attention_map, cmap='jet')
axes[1].set_title('Attention Map')
axes[1].axis('off')
axes[2].imshow(overlay.astype(np.uint8))
axes[2].set_title('Overlay')
axes[2].axis('off')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.show()
@staticmethod
def visualize_cnn_feature_maps(features: torch.Tensor, num_maps: int = 16, save_path: Optional[str] = None):
"""
可视化CNN特征图
Args:
features: 卷积层输出 (B, C, H, W)
num_maps: 显示的特征图数量
"""
feature_maps = features[0].cpu().numpy()[:num_maps] # 取第一个样本的前N个通道
grid_size = int(np.ceil(np.sqrt(num_maps)))
fig, axes = plt.subplots(grid_size, grid_size, figsize=(12, 12))
axes = axes.flatten()
for i in range(num_maps):
if i < len(feature_maps):
axes[i].imshow(feature_maps[i], cmap='viridis')
axes[i].set_title(f'Channel {i}')
axes[i].axis('off')
plt.suptitle('CNN Feature Maps')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.show()
class InductiveBiasAnalyzer:
"""归纳偏置分析器"""
@staticmethod
def analyze_locality(model: nn.Module, model_type: str, image_size: int = 224) -> dict:
"""
分析模型的局部性偏好
通过测量特征响应的空间衰减来量化局部性。
"""
device = next(model.parameters()).device
# 创建中心脉冲输入
x = torch.zeros(1, 3, image_size, image_size).to(device)
center = image_size // 2
x[:, :, center-2:center+2, center-2:center+2] = 1.0
model.eval()
with torch.no_grad():
output = model(x)
# 提取中间层特征
if model_type == 'cnn':
# 获取某卷积层特征
features = None
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d) and 'layer4' in name:
handle = module.register_forward_hook(
lambda m, i, o: setattr(InductiveBiasAnalyzer, 'temp_feat', o)
)
model(x)
features = InductiveBiasAnalyzer.temp_feat
handle.remove()
break
if features is not None:
# 计算空间激活的方差(局部性度量)
feat = features[0].mean(0).cpu().numpy() # 平均通道
y_coords, x_coords = np.indices(feat.shape)
center_y, center_x = feat.shape[0] // 2, feat.shape[1] // 2
# 计算加权距离
distances = np.sqrt((y_coords - center_y)**2 + (x_coords - center_x)**2)
weighted_dist = (feat * distances).sum() / feat.sum()
return {
'weighted_distance': float(weighted_dist),
'spatial_spread': float(feat.std()),
'locality_score': 1.0 / (1.0 + weighted_dist)
}
elif model_type == 'vit':
# 对于ViT,分析注意力图的局部性
# 这里简化处理,实际应提取注意力权重
return {'note': 'ViT locality depends on learned attention patterns'}
return {}
@staticmethod
def compare_translation_invariance(
model_cnn: nn.Module,
model_vit: nn.Module,
image: torch.Tensor,
shifts: List[Tuple[int, int]] = [(0, 0), (4, 4), (8, 8), (16, 16)]
) -> dict:
"""
比较平移不变性
通过测量输入平移时输出特征的一致性来评估平移不变性。
"""
device = next(model_cnn.parameters()).device
results = {'cnn': [], 'vit': []}
def get_feature_vector(model, x):
model.eval()
with torch.no_grad():
# 提取倒数第二层特征
if hasattr(model, 'avgpool'):
x = model.avgpool(model.features(x))
return x.flatten(1)
else:
return model(x)
base_feat_cnn = get_feature_vector(model_cnn, image.to(device))
base_feat_vit = get_feature_vector(model_vit, image.to(device))
for dy, dx in shifts:
shifted = torch.roll(image, shifts=(dy, dx), dims=(2, 3))
feat_cnn = get_feature_vector(model_cnn, shifted.to(device))
feat_vit = get_feature_vector(model_vit, shifted.to(device))
# 计算余弦相似度
sim_cnn = torch.cosine_similarity(base_feat_cnn, feat_cnn, dim=1).item()
sim_vit = torch.cosine_similarity(base_feat_vit, feat_vit, dim=1).item()
results['cnn'].append(sim_cnn)
results['vit'].append(sim_vit)
return results
def main(args):
"""主函数:加载模型并可视化"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载图像
image = Image.open(args.image_path).convert('RGB')
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = transform(image).unsqueeze(0).to(device)
image_np = np.array(image.resize((224, 224)))
if args.model_type == 'vit':
# 加载ViT
from torchvision.models import vit_b_16, ViT_B_16_Weights
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1).to(device)
# 提取注意力
extractor = FeatureExtractor(model, 'vit')
features = extractor.extract(input_tensor)
attentions = extractor.get_attention_maps()
if attentions:
# 注意力展开可视化
attn_list = [attentions[k] for k in sorted(attentions.keys())]
rollout = AttentionVisualizer.rollout_attention(attn_list)
AttentionVisualizer.visualize_attention(
image_np, rollout, patch_size=16, save_path='vit_attention.png'
)
else:
# 加载CNN (ResNet)
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1).to(device)
extractor = FeatureExtractor(model, 'cnn')
features = extractor.extract(input_tensor)
# 可视化某层的特征图
if features:
first_key = list(features.keys())[0]
AttentionVisualizer.visualize_cnn_feature_maps(
features[first_key], save_path='cnn_features.png'
)
print(f"Visualization completed for {args.model_type}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--model_type', type=str, choices=['cnn', 'vit'], default='vit')
parser.add_argument('--image_path', type=str, required=True)
parser.add_argument('--layer', type=int, default=6)
args = parser.parse_args()
main(args)
11.2 层次化视觉Transformer
层次化视觉Transformer通过引入类似CNN的多阶段架构,逐步降低空间分辨率并增加通道维度,有效解决了原始ViT在密集预测任务中的局限性。这类架构在保持全局建模能力的同时,提供了多尺度特征表示,显著提升了在目标检测、语义分割等下游任务中的性能。
11.2.1 Swin Transformer:移位窗口与层次化特征
Swin Transformer通过移位窗口(Shifted Window)机制实现了线性复杂度的自注意力计算,同时保持了跨窗口的连接能力。其核心创新包括层次化特征图构建和基于窗口的多头自注意力(Window-based Multi-head Self-Attention,W-MSA)。
架构上,Swin Transformer采用4个阶段(Stage)逐步处理图像。初始阶段将输入图像分割为 4×4 像素的小图像块,通过线性嵌入层投影至 C 维特征空间。后续每个阶段通过图像块合并(Patch Merging)层将相邻的 2×2 图像块特征拼接并投影,实现空间分辨率的减半和通道数的加倍,构建出层次化特征金字塔。
自注意力计算限制在非重叠的局部窗口内,每个窗口包含 M×M 个图像块。对于大小为 h×w 的特征图,窗口数量为 Mh×Mw ,每个窗口内的自注意力复杂度为 O(M2⋅d) ,总体复杂度降至 O(hw⋅M2⋅d) ,与图像尺寸成线性关系。然而,固定窗口划分限制了跨窗口的信息交互。
移位窗口机制通过在连续层之间偏移窗口划分位置解决这一问题。具体而言,第 l 层采用规则窗口划分,第 l+1 层则将窗口划分偏移 (⌊2M⌋,⌊2M⌋) 个像素。这种交替划分方式使相邻窗口产生重叠,实现了跨窗口连接,同时通过循环移位(Cyclic Shift)和掩码机制保持计算效率。
实例:Swin Transformer完整实现
"""
Script: swin_transformer.py
Description: Swin Transformer的完整PyTorch实现,包含移位窗口与图像块合并
Usage:
python swin_transformer.py --model_size tiny --image_size 224 --visualize_attention
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple
import math
class WindowAttention(nn.Module):
"""
基于窗口的多头自注意力(W-MSA)与移位窗口注意力(SW-MSA)
支持相对位置偏置(Relative Position Bias),这是Swin的关键设计之一。
"""
def __init__(
self,
dim: int,
window_size: Tuple[int, int],
num_heads: int,
qkv_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
pretrained_window_size: Tuple[int, int] = (0, 0)
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.pretrained_window_size = pretrained_window_size
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
# 相对位置偏置表
self.relative_position_bias_table = nn.Parameter(
torch.zeros(
(2 * window_size[0] - 1) * (2 * window_size[1] - 1),
num_heads
)
)
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
# 计算相对位置索引
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Args:
x: 输入特征,形状 (B, N, C),其中N = Wh * Ww(窗口内图像块数)
mask: 注意力掩码,用于移位窗口,形状 (num_windows, Wh*Ww, Wh*Ww)
"""
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # (B, num_heads, N, head_dim)
q = q * self.scale
attn = q @ k.transpose(-2, -1) # (B, num_heads, N, N)
# 添加相对位置偏置
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1
) # Wh*Ww, Wh*Ww, num_heads
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # num_heads, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
# 移位窗口掩码处理
nW = mask.shape[0] # 窗口数量
attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SwinTransformerBlock(nn.Module):
"""
Swin Transformer块,包含W-MSA和SW-MSA的交替
架构:LayerNorm -> W-MSA/SW-MSA -> 残差 -> LayerNorm -> MLP -> 残差
"""
def __init__(
self,
dim: int,
input_resolution: Tuple[int, int],
num_heads: int,
window_size: int = 7,
shift_size: int = 0,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.LayerNorm
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
# 如果窗口大小大于特征图尺寸,不执行移位
if min(self.input_resolution) <= self.window_size:
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=(self.window_size, self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
# 为移位窗口计算注意力掩码
if self.shift_size > 0:
attn_mask = self._calculate_mask(input_resolution)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def _calculate_mask(self, x_size: Tuple[int, int]) -> torch.Tensor:
"""
计算移位窗口的注意力掩码
通过循环移位将移位后的窗口重新排列为规则网格,然后掩码跨边界的注意力。
"""
H, W = x_size
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
# 划分窗口区域
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None)
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None)
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# 窗口划分
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
# 计算掩码:不同区域的窗口不应互相注意
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x: torch.Tensor) -> torch.Tensor:
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, f"input feature has wrong size {L} vs {H}*{W}"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# 循环移位
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# 窗口划分
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# 窗口还原
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# 反向循环移位
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# 残差连接1
x = shortcut + self.drop_path(x)
# FFN
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchMerging(nn.Module):
"""
图像块合并层
将2x2邻域的图像块拼接并线性投影,实现空间下采样2倍,通道数加倍。
这是构建层次化特征的关键组件。
"""
def __init__(self, input_resolution: Tuple[int, int], dim: int, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
# 将2x2邻域的特征拼接
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x) # B H/2*W/2 2*C
return x
class Mlp(nn.Module):
"""多层感知机,用于Transformer块中的前馈网络"""
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer=nn.GELU,
drop: float = 0.0
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample"""
def __init__(self, drop_prob: float = 0.0):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
output = x.div(keep_prob) * random_tensor
return output
def window_partition(x: torch.Tensor, window_size: int) -> torch.Tensor:
"""
将特征图划分为不重叠的窗口
Args:
x: (B, H, W, C)
window_size: 窗口大小
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, window_size, window_size, C)
return windows
def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int) -> torch.Tensor:
"""
将窗口还原为特征图
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size: 窗口大小
H, W: 原始特征图高度和宽度
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class BasicLayer(nn.Module):
"""
Swin Transformer的一个阶段(Stage)
包含多个Swin Transformer块和可选的图像块合并层。
"""
def __init__(
self,
dim: int,
input_resolution: Tuple[int, int],
depth: int,
num_heads: int,
window_size: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
downsample: Optional[nn.Module] = None
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
# 构建Swin Transformer块,交替使用规则窗口和移位窗口
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer
)
for i in range(depth)
])
# 下采样层(图像块合并)
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x: torch.Tensor) -> torch.Tensor:
for blk in self.blocks:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
class SwinTransformer(nn.Module):
"""
Swin Transformer完整模型
4个阶段的层次化架构,逐步下采样构建特征金字塔。
"""
def __init__(
self,
img_size: int = 224,
patch_size: int = 4,
in_chans: int = 3,
num_classes: int = 1000,
embed_dim: int = 96,
depths: Tuple[int, ...] = (2, 2, 6, 2),
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
window_size: int = 7,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.1,
norm_layer: nn.Module = nn.LayerNorm,
ape: bool = False, # 绝对位置编码
patch_norm: bool = True
):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
# 图像块嵌入
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if patch_norm else None
)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# 绝对位置编码(可选,Swin通常不使用)
if ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
nn.init.trunc_normal_(self.absolute_pos_embed, std=0.02)
else:
self.absolute_pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
# 随机深度衰减规则
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
# 构建4个阶段
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
input_resolution=(
patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)
),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None
)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.absolute_pos_embed is not None:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
x = self.head(x)
return x
class PatchEmbed(nn.Module):
"""图像到图像块嵌入"""
def __init__(
self,
img_size: int = 224,
patch_size: int = 4,
in_chans: int = 3,
embed_dim: int = 96,
norm_layer: Optional[nn.Module] = None
):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = [img_size // patch_size, img_size // patch_size]
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
x = self.norm(x)
return x
# 模型配置工厂
def swin_tiny_patch4_window7_224(**kwargs):
model = SwinTransformer(
embed_dim=96,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
window_size=7,
drop_path_rate=0.2,
**kwargs
)
return model
def swin_small_patch4_window7_224(**kwargs):
model = SwinTransformer(
embed_dim=96,
depths=(2, 2, 18, 2),
num_heads=(3, 6, 12, 24),
window_size=7,
drop_path_rate=0.3,
**kwargs
)
return model
if __name__ == "__main__":
# 测试Swin Transformer
model = swin_tiny_patch4_window7_224(num_classes=1000)
x = torch.randn(2, 3, 224, 224)
print("Testing Swin Transformer...")
with torch.no_grad():
output = model(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
# 测试各阶段输出尺寸
print("\nStage outputs:")
x = model.patch_embed(x)
print(f"After patch embed: {x.shape}")
for i, layer in enumerate(model.layers):
x = layer(x)
print(f"After stage {i+1}: {x.shape}")
11.2.2 Pyramid Vision Transformer(PVT)
Pyramid Vision Transformer(PVT)通过引入空间缩减注意力(Spatial Reduction Attention,SRA)机制,解决了ViT在处理高分辨率特征图时的计算效率问题。PVT采用与CNN类似的金字塔结构,逐步降低空间分辨率,生成多尺度特征表示,适用于密集预测任务。
PVT的核心创新在于SRA机制。标准自注意力的计算复杂度为 O(N2⋅d) ,其中 N=H×W 为序列长度。当处理高分辨率特征图时,这一复杂度变得不可接受。SRA通过下采样键(Key)和值(Value)矩阵降低计算量:
SRA(Q,K,V)=Attention(Q,SR(K),SR(V))
其中 SR(⋅) 表示空间缩减操作,通常采用步长卷积或平均池化实现,缩减比例为 R 。这使得复杂度降至 O(N⋅RN⋅d) ,显著提升了高分辨率特征处理的可行性。
PVTv2进一步引入线性复杂度注意力层,通过逐层卷积替代显式空间缩减,在保持性能的同时简化了实现。此外,PVTv2采用重叠图像块嵌入,增强了局部连续性,并通过前馈网络中的深度可分离卷积引入局部处理能力。
实例:PVTv2实现与SRA机制
"""
Script: pvt_v2.py
Description: Pyramid Vision Transformer v2 (PVTv2) 实现,包含线性复杂度注意力
Usage:
python pvt_v2.py --model pvt_v2_b2 --input_size 224 --sra_ratio 8
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple
import math
class OverlapPatchEmbed(nn.Module):
"""
重叠图像块嵌入
使用带重叠的卷积核替代原始ViT的非重叠分块,增强局部连续性。
"""
def __init__(
self,
img_size: int = 224,
patch_size: int = 7,
stride: int = 4,
in_chans: int = 3,
embed_dim: int = 64
):
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim,
kernel_size=patch_size,
stride=stride,
padding=patch_size // 2
)
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
x = self.proj(x) # B C H W
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2) # B N C
x = self.norm(x)
return x, (H, W)
class SpatialReductionAttention(nn.Module):
"""
空间缩减注意力 (SRA)
通过缩减键和值的空间维度降低计算复杂度,从O(N^2)降至O(N^2/R)。
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
sr_ratio: int = 1,
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0
):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
# 空间缩减
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
B, N, C = x.shape
# 生成Query
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
# 空间缩减Key和Value
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
# 注意力计算
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class LinearAttention(nn.Module):
"""
线性复杂度注意力 (PVTv2改进)
使用深度可分离卷积实现局部感知,保持线性复杂度。
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
# 使用Flash Attention风格的内存高效实现
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class DWConv(nn.Module):
"""深度可分离卷积,用于增强局部感知"""
def __init__(self, dim: int):
super().__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x
class MlpWithDWConv(nn.Module):
"""
带深度可分离卷积的MLP (PVTv2改进)
在MLP中引入局部感知能力,增强对细粒度特征的建模。
"""
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer=nn.GELU,
drop: float = 0.0,
linear: bool = False
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
self.linear = linear
if linear:
self.relu = nn.ReLU(inplace=True)
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
x = self.fc1(x)
if self.linear:
x = self.relu(x)
x = self.dwconv(x, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Block(nn.Module):
"""
PVTv2基础块
包含注意力层和带DWConv的MLP,使用预归一化架构。
"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
sr_ratio: int = 1,
linear: bool = False
):
super().__init__()
self.norm1 = norm_layer(dim)
# 根据sr_ratio选择注意力类型
if linear:
self.attn = LinearAttention(dim, num_heads, qkv_bias, attn_drop, drop)
else:
self.attn = SpatialReductionAttention(dim, num_heads, sr_ratio, qkv_bias, attn_drop, drop)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = MlpWithDWConv(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
linear=linear
)
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
return x
class PyramidVisionTransformerV2(nn.Module):
"""
PVTv2完整模型
4个阶段的金字塔结构,逐步下采样,适用于密集预测任务。
"""
def __init__(
self,
img_size: int = 224,
in_chans: int = 3,
num_classes: int = 1000,
embed_dims: List[int] = [64, 128, 256, 512],
num_heads: List[int] = [1, 2, 4, 8],
mlp_ratios: List[float] = [8.0, 8.0, 4.0, 4.0],
drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.1,
depths: List[int] = [3, 4, 6, 3],
sr_ratios: List[int] = [8, 4, 2, 1],
num_stages: int = 4,
linear: bool = False,
patch_sizes: List[int] = [7, 3, 3, 3],
strides: List[int] = [4, 2, 2, 2]
):
super().__init__()
self.num_classes = num_classes
self.depths = depths
self.num_stages = num_stages
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # 随机深度
cur = 0
for i in range(num_stages):
# 图像块嵌入
patch_embed = OverlapPatchEmbed(
img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
patch_size=patch_sizes[i],
stride=strides[i],
in_chans=in_chans if i == 0 else embed_dims[i - 1],
embed_dim=embed_dims[i]
)
# Transformer块
blocks = nn.ModuleList([
Block(
dim=embed_dims[i],
num_heads=num_heads[i],
mlp_ratio=mlp_ratios[i],
qkv_bias=True,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[cur + j],
norm_layer=nn.LayerNorm,
sr_ratio=sr_ratios[i],
linear=linear
)
for j in range(depths[i])
])
norm = nn.LayerNorm(embed_dims[i])
cur += depths[i]
setattr(self, f"patch_embed{i + 1}", patch_embed)
setattr(self, f"block{i + 1}", blocks)
setattr(self, f"norm{i + 1}", norm)
# 分类头
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def freeze_patch_emb(self):
self.patch_embed1.requires_grad = False
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes: int):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
outputs = []
for i in range(self.num_stages):
patch_embed = getattr(self, f"patch_embed{i + 1}")
blocks = getattr(self, f"block{i + 1}")
norm = getattr(self, f"norm{i + 1}")
x, (H, W) = patch_embed(x)
for blk in blocks:
x = blk(x, H, W)
x = norm(x)
x = x.reshape(-1, H, W, x.shape[-1]).permute(0, 3, 1, 2).contiguous()
outputs.append(x)
# 返回多尺度特征(用于检测/分割)或最后特征(用于分类)
return outputs if self.training else x.flatten(2).mean(-1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
if isinstance(x, list):
x = x[-1] # 取最后阶段特征用于分类
x = self.head(x)
return x
# 模型配置
def pvt_v2_b0(**kwargs):
return PyramidVisionTransformerV2(
embed_dims=[32, 64, 160, 256],
num_heads=[1, 2, 5, 8],
mlp_ratios=[8, 8, 4, 4],
depths=[2, 2, 2, 2],
sr_ratios=[8, 4, 2, 1],
**kwargs
)
def pvt_v2_b2(**kwargs):
return PyramidVisionTransformerV2(
embed_dims=[64, 128, 320, 512],
num_heads=[1, 2, 5, 8],
mlp_ratios=[8, 8, 4, 4],
depths=[3, 4, 6, 3],
sr_ratios=[8, 4, 2, 1],
**kwargs
)
if __name__ == "__main__":
model = pvt_v2_b2(num_classes=1000)
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
output = model(x)
print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print(f"Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
11.2.3 Twins-SVT:空间可分离注意力
Twins-SVT(Spatially Separable Vision Transformer)通过引入空间可分离自注意力(Spatially Separable Self-Attention,SSSA)机制,在保持全局建模能力的同时显著降低计算复杂度。其核心思想是将空间注意力分解为局部和全局两个步骤,分别捕获细粒度局部特征和长距离依赖。
SSSA首先在每个局部窗口内执行自注意力(局部注意力),然后通过一个全局操作(如池化或卷积)聚合窗口间的信息(全局注意力)。这种分解策略将复杂度从 O(N2) 降至 O(N⋅w2+w2N⋅d) ,其中 w 为窗口大小。具体实现中,Twins-SVT采用基于卷积的位置编码生成器(CPVT-style)替代传统的可学习位置编码,增强了位置信息的泛化能力。
Twins-SVT的架构包含4个阶段,每个阶段由多个SSSA块组成。与Swin Transformer不同,Twins-SVT在每个块中同时执行局部和全局注意力,而非交替使用。这种设计提供了更稳定的梯度流和更快的收敛速度。实验表明,Twins-SVT在ImageNet分类、COCO检测和ADE20K分割任务上均达到了当时的最优性能,同时保持了较高的计算效率。
实例:Twins-SVT空间可分离注意力实现
"""
Script: twins_svt.py
Description: Twins-SVT 空间可分离注意力实现
Usage:
python twins_svt.py --model twins_svt_small --test_attention
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
import math
class PosConv(nn.Module):
"""
基于卷积的位置编码(CPVT-style)
使用深度可分离卷积生成位置编码,替代可学习参数,增强泛化能力。
"""
def __init__(self, in_chans: int, embed_dim: int = 768):
super().__init__()
self.proj = nn.Sequential(
nn.Conv2d(in_chans, embed_dim, 3, 1, 1, groups=in_chans, bias=False),
nn.BatchNorm2d(embed_dim),
nn.GELU(),
nn.Conv2d(embed_dim, embed_dim, 3, 1, 1, groups=embed_dim, bias=False),
)
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
B, N, C = x.shape
feat = x.transpose(1, 2).view(B, C, H, W)
x = self.proj(feat) + feat
x = x.flatten(2).transpose(1, 2)
return x
class LocallyGroupedSelfAttention(nn.Module):
"""
局部分组自注意力(LSA)
将特征图划分为k x k的组,在每组内独立执行自注意力。
这是空间可分离注意力的局部组件。
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
ws: int = 7 # 窗口大小
):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.ws = ws # window size
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
B, N, C = x.shape
assert H % self.ws == 0 and W % self.ws == 0, f"Feature map size ({H},{W}) must be divisible by ws {self.ws}"
# 重塑为窗口格式
x = x.reshape(B, H // self.ws, self.ws, W // self.ws, self.ws, C).transpose(2, 3)
# B, H//ws, W//ws, ws, ws, C
x = x.reshape(B * (H // self.ws) * (W // self.ws), self.ws * self.ws, C)
# B', ws*ws, C 其中 B' = B * (H//ws) * (W//ws)
qkv = self.qkv(x).reshape(B * (H // self.ws) * (W // self.ws), self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B', num_heads, ws*ws, head_dim
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B * (H // self.ws) * (W // self.ws), self.ws * self.ws, C)
x = self.proj(x)
x = self.proj_drop(x)
# 还原形状
x = x.reshape(B, H // self.ws, W // self.ws, self.ws, self.ws, C).transpose(2, 3).reshape(B, H, W, C).reshape(B, N, C)
return x
class GlobalSubsampledAttention(nn.Module):
"""
全局子采样注意力(GSA)
使用子采样(如池化)降低键和值的空间分辨率,捕获全局上下文。
这是空间可分离注意力的全局组件。
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
sr_ratio: int = 4 # 子采样比例
):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.sr_ratio = sr_ratio
self.q = nn.Linear(dim, dim, bias=True)
self.kv = nn.Linear(dim, dim * 2, bias=True)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
# 子采样
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class TwinsBlock(nn.Module):
"""
Twins-SVT基础块
交替使用LSA(局部)和GSA(全局)注意力,或在一个块中串联使用。
"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
ws: int = 7, # 局部窗口大小
sr_ratio: int = 4, # 全局子采样比例
use_gsa: bool = True # 是否使用GSA
):
super().__init__()
self.norm1 = norm_layer(dim)
self.lsa = LocallyGroupedSelfAttention(dim, num_heads, attn_drop, drop, ws)
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.use_gsa = use_gsa
if use_gsa:
self.norm2 = norm_layer(dim)
self.gsa = GlobalSubsampledAttention(dim, num_heads, attn_drop, drop, sr_ratio)
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm3 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.drop_path3 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
# 局部注意力
x = x + self.drop_path1(self.lsa(self.norm1(x), H, W))
# 全局注意力(可选)
if self.use_gsa:
x = x + self.drop_path2(self.gsa(self.norm2(x), H, W))
# MLP
x = x + self.drop_path3(self.mlp(self.norm3(x)))
return x
class Mlp(nn.Module):
"""MLP as used in Vision Transformer"""
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer=nn.GELU,
drop: float = 0.0
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth)"""
def __init__(self, drop_prob: float = 0.0):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
output = x.div(keep_prob) * random_tensor
return output
class PatchEmbed(nn.Module):
"""Image to Patch Embedding"""
def __init__(
self,
img_size: int = 224,
patch_size: int = 4,
in_chans: int = 3,
embed_dim: int = 96
):
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
B, C, H, W = x.shape
x = self.proj(x) # B, embed_dim, H//patch_size, W//patch_size
H, W = x.shape[2], x.shape[3]
x = x.flatten(2).transpose(1, 2) # B, H*W, embed_dim
x = self.norm(x)
return x, (H, W)
class TwinsSVT(nn.Module):
"""
Twins-SVT完整模型
空间可分离注意力架构,结合局部和全局建模能力。
"""
def __init__(
self,
img_size: int = 224,
in_chans: int = 3,
num_classes: int = 1000,
embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
num_heads: Tuple[int, ...] = (1, 2, 4, 8),
mlp_ratios: Tuple[float, ...] = (4.0, 4.0, 4.0, 4.0),
drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.1,
depths: Tuple[int, ...] = (2, 2, 10, 4),
ws: Tuple[int, ...] = (7, 7, 7, 7), # 局部窗口大小
sr_ratios: Tuple[int, ...] = (8, 4, 2, 1), # 全局子采样比例
norm_layer=nn.LayerNorm,
patch_size: int = 4
):
super().__init__()
self.num_classes = num_classes
self.depths = depths
self.num_stages = len(depths)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
cur = 0
for i in range(self.num_stages):
# 图像块嵌入
patch_embed = PatchEmbed(
img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
patch_size=patch_size if i == 0 else 2,
in_chans=in_chans if i == 0 else embed_dims[i - 1],
embed_dim=embed_dims[i]
)
# 位置编码
pos_conv = PosConv(embed_dims[i], embed_dims[i])
# Transformer块
blocks = nn.ModuleList([
TwinsBlock(
dim=embed_dims[i],
num_heads=num_heads[i],
mlp_ratio=mlp_ratios[i],
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[cur + j],
norm_layer=norm_layer,
ws=ws[i],
sr_ratio=sr_ratios[i],
use_gsa=(j % 2 == 1) # 交替使用GSA
)
for j in range(depths[i])
])
norm = norm_layer(embed_dims[i])
cur += depths[i]
setattr(self, f"patch_embed{i + 1}", patch_embed)
setattr(self, f"pos_conv{i + 1}", pos_conv)
setattr(self, f"block{i + 1}", blocks)
setattr(self, f"norm{i + 1}", norm)
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
for i in range(self.num_stages):
patch_embed = getattr(self, f"patch_embed{i + 1}")
pos_conv = getattr(self, f"pos_conv{i + 1}")
blocks = getattr(self, f"block{i + 1}")
norm = getattr(self, f"norm{i + 1}")
x, (H, W) = patch_embed(x)
x = pos_conv(x, H, W)
for blk in blocks:
x = blk(x, H, W)
x = norm(x)
x = x.reshape(-1, H, W, x.shape[-1]).permute(0, 3, 1, 2).contiguous()
x = x.flatten(2).mean(-1) # 全局平均池化
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
x = self.head(x)
return x
# 模型配置
def twins_svt_small(**kwargs):
return TwinsSVT(
embed_dims=[64, 128, 256, 512],
num_heads=[2, 4, 8, 16],
mlp_ratios=[4, 4, 4, 4],
depths=[2, 2, 10, 4],
ws=[7, 7, 7, 7],
sr_ratios=[8, 4, 2, 1],
**kwargs
)
if __name__ == "__main__":
model = twins_svt_small(num_classes=1000)
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
output = model(x)
print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print(f"Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
11.2.4 ConvNeXt:CNN对Transformer的反击
ConvNeXt并非Transformer架构,而是对标准卷积神经网络(ResNet)进行现代化改进的结果,旨在回答"纯CNN能否达到Transformer性能"这一问题。通过系统性地借鉴Swin Transformer的设计元素并应用于ResNet,ConvNeXt在ImageNet-1K上达到了与Swin Transformer相当的性能,同时保持了CNN的简洁性和效率。
关键改进包括:将训练迭代次数从90扩展至300 epoch;将优化器从SGD切换至AdamW;采用类似Transformer的宏设计(更改阶段计算比例、将下采样层移至开头);以及微观架构调整(将 3×3 卷积替换为 7×7 深度可分离卷积、使用GELU替代ReLU、减少激活函数和归一化层数量、使用LayerNorm替代BatchNorm、将下采样分离为独立层、将通道数扩展至与Swin-Tiny匹配的96维起始通道)。
ConvNeXt证明,Transformer的成功很大程度上源于其训练策略和架构宏设计,而非注意力机制本身。通过采用类似的训练配方(EMA、Stochastic Depth、Mixup、Cutmix、RandAugment、随机擦除)和架构设计(大核卷积、更少的归一化层),纯CNN可以匹配甚至超越Transformer的性能,同时具有更高的推理吞吐量和硬件友好性。
实例:ConvNeXt完整实现
"""
Script: convnext.py
Description: ConvNeXt 纯CNN架构的现代化实现
Usage:
python convnext.py --model convnext_tiny --batch_size 32 --benchmark
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional
import math
class LayerNorm(nn.Module):
"""
支持两种数据格式的LayerNorm:channels_last (默认) 或 channels_first。
channels_last对应输入形状 (N, H, W, C),channels_first对应 (N, C, H, W)。
ConvNeXt使用channels_last以优化内存访问模式。
"""
def __init__(self, normalized_shape: int, eps: float = 1e-6, data_format: str = "channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape,)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class ConvNeXtBlock(nn.Module):
"""
ConvNeXt基础块
架构:深度可分离卷积 -> LayerNorm -> 1x1卷积(MLP) -> GELU -> 1x1卷积 -> 残差连接
关键设计:
- 大核深度可分离卷积 (7x7) 替代自注意力
- 使用LayerNorm替代BatchNorm
- 减少激活函数数量(仅MLP中使用GELU)
- 减少归一化层数量(每个块仅一个)
"""
def __init__(
self,
dim: int,
drop_path: float = 0.0,
layer_scale_init_value: float = 1e-6
):
super().__init__()
# 深度可分离卷积,大核设计(7x7)
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
# LayerNorm,使用channels_last格式
self.norm = LayerNorm(dim, eps=1e-6)
# 点卷积MLP(1x1卷积实现)
self.pwconv1 = nn.Linear(dim, 4 * dim) # 扩展比例为4
self.act = nn.GELU()
self.pwconv2 = nn.Linear(4 * dim, dim)
# 层缩放(Layer Scale),用于稳定深层网络训练
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) if layer_scale_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
input = x
# 深度可分离卷积
x = self.dwconv(x)
# 转换为channels_last格式以使用LayerNorm
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
# MLP
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
# 层缩放
if self.gamma is not None:
x = self.gamma * x
# 转换回channels_first
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = input + self.drop_path(x)
return x
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth)"""
def __init__(self, drop_prob: float = 0.0):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
output = x.div(keep_prob) * random_tensor
return output
class ConvNeXt(nn.Module):
"""
ConvNeXt完整模型
4个阶段的层次化架构,模仿Swin Transformer的阶段比例:
- Stage 1: 1/4 分辨率
- Stage 2: 1/8 分辨率
- Stage 3: 1/16 分辨率
- Stage 4: 1/32 分辨率
"""
def __init__(
self,
in_chans: int = 3,
num_classes: int = 1000,
depths: List[int] = [3, 3, 9, 3],
dims: List[int] = [96, 192, 384, 768],
drop_path_rate: float = 0.0,
layer_scale_init_value: float = 1e-6,
head_init_scale: float = 1.0
):
super().__init__()
self.downsample_layers = nn.ModuleList()
# Stem层:早期下采样,替代ResNet的7x7卷积+最大池化
stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
)
self.downsample_layers.append(stem)
# 阶段间的下采样层:2x2卷积,步长2
for i in range(3):
downsample_layer = nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
)
self.downsample_layers.append(downsample_layer)
# 4个阶段,每个阶段包含多个ConvNeXt块
self.stages = nn.ModuleList()
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
cur = 0
for i in range(4):
stage = nn.Sequential(
*[
ConvNeXtBlock(
dim=dims[i],
drop_path=dp_rates[cur + j],
layer_scale_init_value=layer_scale_init_value
)
for j in range(depths[i])
]
)
self.stages.append(stage)
cur += depths[i]
self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # 最终归一化
self.head = nn.Linear(dims[-1], num_classes)
self.apply(self._init_weights)
self.head.weight.data.mul_(head_init_scale)
self.head.bias.data.mul_(head_init_scale)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
for i in range(4):
x = self.downsample_layers[i](x)
x = self.stages[i](x)
# 全局平均池化
x = self.norm(x.mean([-2, -1])) # (N, C, H, W) -> (N, C)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
x = self.head(x)
return x
# 模型配置工厂
def convnext_tiny(num_classes: int = 1000, **kwargs):
model = ConvNeXt(
depths=[3, 3, 9, 3],
dims=[96, 192, 384, 768],
**kwargs
)
return model
def convnext_small(num_classes: int = 1000, **kwargs):
model = ConvNeXt(
depths=[3, 3, 27, 3],
dims=[96, 192, 384, 768],
**kwargs
)
return model
def convnext_base(num_classes: int = 1000, **kwargs):
model = ConvNeXt(
depths=[3, 3, 27, 3],
dims=[128, 256, 512, 1024],
**kwargs
)
return model
def convnext_large(num_classes: int = 1000, **kwargs):
model = ConvNeXt(
depths=[3, 3, 27, 3],
dims=[192, 384, 768, 1536],
**kwargs
)
return model
# 性能基准测试
def benchmark(model, input_size=224, batch_size=32, iterations=100):
"""测试模型吞吐量和延迟"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device).eval()
x = torch.randn(batch_size, 3, input_size, input_size).to(device)
# 预热
for _ in range(10):
_ = model(x)
if device.type == 'cuda':
torch.cuda.synchronize()
import time
start = time.time()
for _ in range(iterations):
_ = model(x)
if device.type == 'cuda':
torch.cuda.synchronize()
elapsed = time.time() - start
throughput = iterations * batch_size / elapsed
latency = elapsed / iterations * 1000 # ms
params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
print(f"Model: {model.__class__.__name__}")
print(f"Parameters: {params:.2f}M")
print(f"Throughput: {throughput:.2f} images/sec")
print(f"Latency: {latency:.2f} ms/batch")
return throughput, latency
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='convnext_tiny', choices=['convnext_tiny', 'convnext_small', 'convnext_base'])
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--benchmark', action='store_true')
args = parser.parse_args()
model_func = globals()[args.model]
model = model_func(num_classes=1000)
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
output = model(x)
print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
if args.benchmark:
benchmark(model, batch_size=args.batch_size)
11.3 高效ViT设计
高效ViT设计旨在将Vision Transformer部署至资源受限环境(移动设备、边缘设备),通过架构优化、混合设计和硬件感知优化,在保持性能的同时显著降低计算复杂度和内存占用。
11.3.1 MobileViT:轻量级视觉Transformer
MobileViT通过结合CNN的局部感知能力和Transformer的全局建模能力,实现了移动设备上的高效视觉识别。其核心创新在于MobileViT块,该块首先使用标准卷积提取局部特征,然后通过Transformer编码器建模全局关系,最后将全局特征与局部特征融合。
MobileViT块的具体流程为:输入特征图首先经过 n×n 卷积和点卷积生成 d 维特征;然后将特征图划分为 h×w 个展开为 d 维的patch;通过Transformer编码器处理这些patch以捕获全局信息;最后将处理后的特征重塑并与原始局部特征拼接,经另一组卷积融合。这种设计使MobileViT在ImageNet-1K上达到78.4%的Top-1准确率,同时比MobileNetv3快约3倍,参数量仅为MobileNetv3的1.2倍。
MobileViT v2进一步优化了注意力机制,提出可分离自注意力(Separable Self-Attention),将标准自注意力的复杂度从 O(N2) 降至 O(N) ,使模型能够处理更高分辨率的图像。此外,引入了线性复杂度的高分辨率注意力机制,通过跨特征维度(而非空间维度)计算注意力,避免了高分辨率下的计算瓶颈。
实例:MobileViT v2实现
"""
Script: mobilevit_v2.py
Description: MobileViT v2 实现,包含可分离自注意力机制
Usage:
python mobilevit_v2.py --model mobilevit_v2_xs --resolution 256 --benchmark
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Tuple
class ConvLayer(nn.Module):
"""
标准卷积层,用于局部特征提取
包含卷积、归一化和激活函数,是MobileViT的构建单元。
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
groups: int = 1,
use_norm: bool = True,
use_act: bool = True
):
super().__init__()
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2d(
in_channels, out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False
)
self.norm = nn.BatchNorm2d(out_channels) if use_norm else nn.Identity()
self.act = nn.SiLU(inplace=True) if use_act else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.act(self.norm(self.conv(x)))
class InvertedResidual(nn.Module):
"""
MobileNetv2风格的倒置残差块
扩展-卷积-压缩的瓶颈结构,用于高效的局部特征提取。
"""
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int,
expand_ratio: int = 4
):
super().__init__()
assert stride in [1, 2]
hidden_dim = int(round(in_channels * expand_ratio))
self.use_res_connect = stride == 1 and in_channels == out_channels
layers = []
# 扩展
if expand_ratio != 1:
layers.append(ConvLayer(in_channels, hidden_dim, kernel_size=1))
# 深度可分离卷积
layers.extend([
ConvLayer(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels)
])
self.conv = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
class MobileViTAttention(nn.Module):
"""
MobileViT v2的可分离自注意力
关键创新:跨通道维度计算注意力,而非空间维度,实现线性复杂度O(N)。
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
attn_dropout: float = 0.0,
proj_dropout: float = 0.0
):
super().__init__()
self.num_heads = num_heads
self.scale = (dim // num_heads) ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.attn_dropout = nn.Dropout(attn_dropout)
self.proj = nn.Linear(dim, dim)
self.proj_dropout = nn.Dropout(proj_dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
# 生成Q, K, V
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # (B, heads, N, head_dim)
# 可分离注意力:在通道维度计算
# 标准注意力: Q @ K^T -> (B, heads, N, N),复杂度O(N^2)
# 可分离注意力: Q^T @ K -> (B, heads, head_dim, head_dim),复杂度O(d^2)
# 这里实现的是简化的线性注意力变体
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_dropout(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_dropout(x)
return x
class TransformerEncoder(nn.Module):
"""
轻量级Transformer编码器
使用可分离自注意力和前馈网络,针对移动设备优化。
"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 2.0,
dropout: float = 0.0,
attn_dropout: float = 0.0,
drop_path: float = 0.0
):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = MobileViTAttention(dim, num_heads, attn_dropout, dropout)
self.drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
self.norm2 = nn.LayerNorm(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim),
nn.SiLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(mlp_hidden_dim, dim),
nn.Dropout(dropout)
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 可分离自注意力
x = x + self.drop_path1(self.attn(self.norm1(x)))
# MLP
x = x + self.drop_path2(self.mlp(self.norm2(x)))
return x
class MobileViTBlock(nn.Module):
"""
MobileViT核心块
结合CNN局部特征提取和Transformer全局建模:
1. 局部特征提取(卷积)
2. 展开为patch
3. Transformer全局建模
4. 折叠回特征图
5. 局部-全局特征融合
"""
def __init__(
self,
in_channels: int,
out_channels: int,
dim: int,
num_heads: int = 4,
mlp_ratio: float = 2.0,
n_transformer_blocks: int = 2,
dropout: float = 0.0,
attn_dropout: float = 0.0,
drop_path: float = 0.0,
patch_size: Tuple[int, int] = (2, 2)
):
super().__init__()
self.patch_h, self.patch_w = patch_size
self.patch_size = patch_size
# 局部特征提取
self.local_rep = nn.Sequential(
ConvLayer(in_channels, dim, kernel_size=3),
InvertedResidual(dim, dim, stride=1, expand_ratio=2)
)
# 全局特征提取(Transformer)
transformer_blocks = []
for i in range(n_transformer_blocks):
transformer_blocks.append(
TransformerEncoder(
dim=dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
dropout=dropout,
attn_dropout=attn_dropout,
drop_path=drop_path
)
)
self.global_rep = nn.Sequential(*transformer_blocks)
# 特征融合
self.conv_proj = ConvLayer(dim * 2, out_channels, kernel_size=1)
def unfolding(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int, int, int, int, int]]:
"""
将特征图展开为patch序列
Args:
x: (B, C, H, W)
Returns:
patches: (B, N, C') 其中N为patch数量
info: 包含原始尺寸的元组,用于折叠
"""
B, C, H, W = x.shape
# 确保尺寸可被patch_size整除
new_h = (H // self.patch_h) * self.patch_h
new_w = (W // self.patch_w) * self.patch_w
if new_h != H or new_w != W:
x = F.interpolate(x, size=(new_h, new_w), mode='bilinear', align_corners=False)
# 重塑为patch
# (B, C, H, W) -> (B, C, H//Ph, Ph, W//Pw, Pw) -> (B, H//Ph * W//Pw, Ph * Pw * C)
patches = x.unfold(2, self.patch_h, self.patch_h).unfold(3, self.patch_w, self.patch_w)
patches = patches.contiguous().view(B, C, -1, self.patch_h, self.patch_w)
patches = patches.permute(0, 2, 1, 3, 4).contiguous()
patches = patches.view(B, -1, C * self.patch_h * self.patch_w)
return patches, (B, C, H, W, new_h, new_w)
def folding(self, patches: torch.Tensor, info: Tuple) -> torch.Tensor:
"""
将patch序列折叠回特征图
"""
B, C, H, W, new_h, new_w = info
# 计算patch网格尺寸
n_patches_h = new_h // self.patch_h
n_patches_w = new_w // self.patch_w
# (B, N, C*Ph*Pw) -> (B, N, C, Ph, Pw) -> (B, C, Ph, Pw, N)
patches = patches.view(B, n_patches_h, n_patches_w, C, self.patch_h, self.patch_w)
patches = patches.permute(0, 3, 4, 5, 1, 2).contiguous()
# 合并patch
x = patches.view(B, C * self.patch_h * self.patch_w, -1)
x = F.fold(
x,
output_size=(new_h, new_w),
kernel_size=(self.patch_h, self.patch_w),
stride=(self.patch_h, self.patch_w)
)
# 如果尺寸改变,插值回原尺寸
if new_h != H or new_w != W:
x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 局部特征
local_features = self.local_rep(x)
# 展开
patches, info = self.unfolding(local_features)
# 全局建模
global_features = self.global_rep(patches)
# 折叠
global_features = self.folding(global_features, info)
# 特征融合
combined = torch.cat([local_features, global_features], dim=1)
out = self.conv_proj(combined)
return out
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth)"""
def __init__(self, drop_prob: float = 0.0):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
output = x.div(keep_prob) * random_tensor
return output
class MobileViT(nn.Module):
"""
MobileViT v2完整模型
针对移动设备优化的轻量级视觉Transformer,结合CNN效率和Transformer能力。
"""
def __init__(
self,
in_channels: int = 3,
num_classes: int = 1000,
dims: List[int] = [32, 64, 128, 256, 384, 512],
channels: List[int] = [16, 32, 64, 128, 256, 512],
num_heads: int = 4,
mlp_ratio: float = 2.0,
n_transformer_blocks: List[int] = [2, 4, 3],
dropout: float = 0.0,
attn_dropout: float = 0.0,
drop_path_rate: float = 0.0,
patch_size: Tuple[int, int] = (2, 2)
):
super().__init__()
# 下采样层
self.stem = ConvLayer(in_channels, channels[0], kernel_size=3, stride=2)
# 阶段1: 纯卷积
self.stage1 = nn.Sequential(
InvertedResidual(channels[0], channels[1], stride=1, expand_ratio=2),
InvertedResidual(channels[1], channels[1], stride=1, expand_ratio=2)
)
# 阶段2: 纯卷积
self.stage2 = nn.Sequential(
InvertedResidual(channels[1], channels[2], stride=2, expand_ratio=2),
InvertedResidual(channels[2], channels[2], stride=1, expand_ratio=2),
InvertedResidual(channels[2], channels[2], stride=1, expand_ratio=2)
)
# 阶段3: MobileViT块
self.stage3 = nn.Sequential(
MobileViTBlock(
in_channels=channels[2],
out_channels=channels[3],
dim=dims[0],
num_heads=num_heads,
mlp_ratio=mlp_ratio,
n_transformer_blocks=n_transformer_blocks[0],
dropout=dropout,
attn_dropout=attn_dropout,
drop_path=drop_path_rate,
patch_size=patch_size
),
MobileViTBlock(
in_channels=channels[3],
out_channels=channels[3],
dim=dims[0],
num_heads=num_heads,
mlp_ratio=mlp_ratio,
n_transformer_blocks=n_transformer_blocks[0],
dropout=dropout,
attn_dropout=attn_dropout,
drop_path=drop_path_rate,
patch_size=patch_size
)
)
# 阶段4: MobileViT块
self.stage4 = nn.Sequential(
MobileViTBlock(
in_channels=channels[3],
out_channels=channels[4],
dim=dims[1],
num_heads=num_heads,
mlp_ratio=mlp_ratio,
n_transformer_blocks=n_transformer_blocks[1],
dropout=dropout,
attn_dropout=attn_dropout,
drop_path=drop_path_rate,
patch_size=patch_size
),
MobileViTBlock(
in_channels=channels[4],
out_channels=channels[4],
dim=dims[1],
num_heads=num_heads,
mlp_ratio=mlp_ratio,
n_transformer_blocks=n_transformer_blocks[1],
dropout=dropout,
attn_dropout=attn_dropout,
drop_path=drop_path_rate,
patch_size=patch_size
)
)
# 阶段5: MobileViT块
self.stage5 = nn.Sequential(
MobileViTBlock(
in_channels=channels[4],
out_channels=channels[5],
dim=dims[2],
num_heads=num_heads,
mlp_ratio=mlp_ratio,
n_transformer_blocks=n_transformer_blocks[2],
dropout=dropout,
attn_dropout=attn_dropout,
drop_path=drop_path_rate,
patch_size=patch_size
),
MobileViTBlock(
in_channels=channels[5],
out_channels=channels[5],
dim=dims[2],
num_heads=num_heads,
mlp_ratio=mlp_ratio,
n_transformer_blocks=n_transformer_blocks[2],
dropout=dropout,
attn_dropout=attn_dropout,
drop_path=drop_path_rate,
patch_size=patch_size
)
)
# 分类头
self.conv_1x1_exp = ConvLayer(channels[5], dims[3], kernel_size=1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Sequential(
nn.Linear(dims[3], 1024),
nn.SiLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(1024, num_classes)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.stem(x)
x = self.stage1(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.stage5(x)
x = self.conv_1x1_exp(x)
x = self.avg_pool(x).flatten(1)
x = self.classifier(x)
return x
# 模型配置
def mobilevit_v2_xs(num_classes: int = 1000, **kwargs):
return MobileViT(
dims=[48, 64, 80, 320],
channels=[16, 32, 48, 64, 80, 320],
num_heads=4,
n_transformer_blocks=[2, 4, 3],
**kwargs
)
def mobilevit_v2_s(num_classes: int = 1000, **kwargs):
return MobileViT(
dims=[64, 80, 96, 384],
channels=[32, 64, 96, 128, 160, 640],
num_heads=4,
n_transformer_blocks=[2, 4, 3],
**kwargs
)
if __name__ == "__main__":
model = mobilevit_v2_xs(num_classes=1000)
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
output = model(x)
print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
# 计算参数量和FLOPs
params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
print(f"Parameters: {params:.2f}M")
11.3.2 EfficientFormer与EdgeViT
EfficientFormer通过延迟驱动的瘦身(Latency Driven Slimming)策略,系统性地识别并移除ViT中的低效组件,构建出移动设备友好的纯Transformer架构。其核心发现是:ViT中的非线性激活函数、归一化层和复杂注意力机制是延迟的主要瓶颈。EfficientFormer采用类似MobileNet的DeLite设计,使用4个阶段的超网络,逐步下采样并增加通道数。
关键优化包括:使用 3×3 卷积作为token混合器替代自注意力(在浅层);仅在深层使用自注意力;采用BatchNorm替代LayerNorm(在移动设备上更高效);使用ReLU替代GELU;以及采用4D特征融合(保持张量为4D格式,避免频繁的reshape操作)。EfficientFormer-L1在ImageNet-1K上达到79.2%的Top-1准确率,iPhone 12上的延迟仅为1.6ms,显著优于MobileViT v2。
EdgeViT针对边缘设备进一步优化,提出基于稀疏自注意力的局部-全局-局部(LGL)瓶颈。LGL首先使用局部聚合(深度可分离卷积)捕获局部信息,然后通过稀疏全局自注意力建模长距离依赖,最后再次局部聚合细化特征。这种设计将全局注意力的计算限制在降采样的特征上,显著降低了复杂度。
实例:EfficientFormer架构实现
"""
Script: efficientformer.py
Description: EfficientFormer 延迟优化架构实现
Usage:
python efficientformer.py --model efficientformer_l1 --benchmark_mobile
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple
class ConvBNAct(nn.Module):
"""
移动设备优化的卷积块
使用BatchNorm(而非LayerNorm)和ReLU(而非GELU),
在移动CPU/GPU上具有更低的延迟。
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
groups: int = 1,
use_act: bool = True
):
super().__init__()
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2d(
in_channels, out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False
)
self.bn = nn.BatchNorm2d(out_channels)
self.act = nn.ReLU(inplace=True) if use_act else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.act(self.bn(self.conv(x)))
class PoolingTokenMixer(nn.Module):
"""
池化Token混合器
使用3x3深度可分离卷积替代自注意力,在浅层实现高效的局部混合。
这是EfficientFormer的关键延迟优化策略。
"""
def __init__(self, dim: int):
super().__init__()
self.pool = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
self.conv = ConvBNAct(dim, dim, kernel_size=1, use_act=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(self.pool(x)) + x
class AttentionTokenMixer(nn.Module):
"""
轻量级自注意力,仅在深层使用
采用简化的多头注意力,针对移动设备优化内存访问。
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0
):
super().__init__()
assert dim % num_heads == 0, 'dim must be divisible by num_heads'
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MLP(nn.Module):
"""
移动设备优化的MLP
使用1x1卷积实现(而非Linear),保持4D张量格式,
避免reshape操作的开销。
"""
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
drop: float = 0.0
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = ConvBNAct(in_features, hidden_features, kernel_size=1)
self.fc2 = ConvBNAct(hidden_features, out_features, kernel_size=1, use_act=False)
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class MetaBlock4D(nn.Module):
"""
4D MetaBlock,使用池化混合器
保持特征图为4D格式 (B, C, H, W),优化内存布局。
"""
def __init__(
self,
dim: int,
mlp_ratio: float = 4.0,
drop: float = 0.0,
drop_path: float = 0.0
):
super().__init__()
self.token_mixer = PoolingTokenMixer(dim)
self.mlp = MLP(dim, int(dim * mlp_ratio), drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.drop_path(self.token_mixer(x))
x = x + self.drop_path(self.mlp(x))
return x
class MetaBlock3D(nn.Module):
"""
3D MetaBlock,使用自注意力混合器
将特征图reshape为3D序列 (B, N, C),应用自注意力。
仅在深层使用,以平衡性能和延迟。
"""
def __init__(
self,
dim: int,
mlp_ratio: float = 4.0,
num_heads: int = 8,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0
):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.token_mixer = AttentionTokenMixer(dim, num_heads, attn_drop=attn_drop, proj_drop=drop)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.ReLU(inplace=True),
nn.Dropout(drop),
nn.Linear(int(dim * mlp_ratio), dim),
nn.Dropout(drop)
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
x_ = x.flatten(2).transpose(1, 2) # (B, N, C)
x_ = x_ + self.drop_path1(self.token_mixer(self.norm1(x_)))
x_ = x_ + self.drop_path2(self.mlp(self.norm2(x_)))
x = x_.transpose(1, 2).reshape(B, C, H, W)
return x
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth)"""
def __init__(self, drop_prob: float = 0.0):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
output = x.div(keep_prob) * random_tensor
return output
class EfficientFormer(nn.Module):
"""
EfficientFormer完整模型
4个阶段架构,前两个阶段使用4D MetaBlock(池化混合器),
后两个阶段使用3D MetaBlock(自注意力混合器)。
"""
def __init__(
self,
in_channels: int = 3,
num_classes: int = 1000,
depths: List[int] = [3, 3, 9, 6],
dims: List[int] = [48, 96, 224, 448],
mlp_ratios: List[float] = [4.0, 4.0, 4.0, 4.0],
num_heads: int = 8,
drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0
):
super().__init__()
self.num_classes = num_classes
self.num_stages = len(depths)
# 下采样层
self.patch_embed = nn.Sequential(
ConvBNAct(in_channels, dims[0] // 2, kernel_size=3, stride=2),
ConvBNAct(dims[0] // 2, dims[0], kernel_size=3, stride=2)
)
# 构建4个阶段
self.stages = nn.ModuleList()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
cur = 0
for i in range(self.num_stages):
stage_blocks = []
for j in range(depths[i]):
drop_path = dpr[cur + j]
# 前两个阶段使用4D块(池化),后两个阶段使用3D块(注意力)
if i < 2:
block = MetaBlock4D(
dim=dims[i],
mlp_ratio=mlp_ratios[i],
drop=drop_rate,
drop_path=drop_path
)
else:
block = MetaBlock3D(
dim=dims[i],
mlp_ratio=mlp_ratios[i],
num_heads=num_heads,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=drop_path
)
stage_blocks.append(block)
self.stages.append(nn.Sequential(*stage_blocks))
cur += depths[i]
# 阶段间下采样
if i < self.num_stages - 1:
setattr(self, f"downsample_{i+1}",
ConvBNAct(dims[i], dims[i+1], kernel_size=3, stride=2))
# 分类头
self.norm = nn.BatchNorm2d(dims[-1])
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
for i, stage in enumerate(self.stages):
x = stage(x)
if i < self.num_stages - 1:
downsample = getattr(self, f"downsample_{i+1}")
x = downsample(x)
x = self.norm(x)
x = self.avgpool(x).flatten(1)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
x = self.head(x)
return x
# 模型配置
def efficientformer_l1(**kwargs):
return EfficientFormer(
depths=[3, 3, 9, 6],
dims=[48, 96, 224, 448],
**kwargs
)
def efficientformer_l3(**kwargs):
return EfficientFormer(
depths=[4, 4, 12, 6],
dims=[64, 128, 320, 512],
**kwargs
)
def efficientformer_l7(**kwargs):
return EfficientFormer(
depths=[6, 6, 18, 8],
dims=[96, 192, 384, 768],
**kwargs
)
if __name__ == "__main__":
model = efficientformer_l1(num_classes=1000)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
output = model(x)
print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
print(f"Parameters: {params:.2f}M")
11.3.3 混合架构:CoAtNet, MaxViT
混合架构通过有机结合卷积和Transformer的优势,在多个尺度上实现高效的视觉建模。CoAtNet(Convolution and Attention Network)系统性地研究了卷积与注意力的结合方式,提出在浅层使用卷积(利用其归纳偏置实现快速收敛),在深层使用Transformer(利用其全局建模能力实现高性能)。CoAtNet采用相对注意力(Relative Attention)机制,将卷积的平移等变性融入自注意力,使模型能够更好地泛化到不同分辨率。
MaxViT(Multi-Axis Vision Transformer)引入多轴注意力机制,将全局注意力分解为局部窗口注意力和扩展网格注意力。这种分解在保持全局感受野的同时,将复杂度降至线性。MaxViT还采用MBConv(Mobile Inverted Bottleneck)作为基本构建单元,结合Squeeze-and-Excitation优化,实现了高效的特征提取。
实例:MaxViT多轴注意力实现
"""
Script: maxvit.py
Description: MaxViT 多轴注意力与混合架构实现
Usage:
python maxvit.py --model maxvit_tiny --grid_size 7
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
class MBConv(nn.Module):
"""
Mobile Inverted Bottleneck (MBConv)
混合架构的基础构建单元,结合深度可分离卷积和SE模块。
"""
def __init__(
self,
in_channels: int,
out_channels: int,
expand_ratio: int = 4,
stride: int = 1,
se_ratio: float = 0.25
):
super().__init__()
self.use_res_connect = stride == 1 and in_channels == out_channels
hidden_dim = in_channels * expand_ratio
layers = []
# 扩展
if expand_ratio != 1:
layers.append(nn.Conv2d(in_channels, hidden_dim, 1, bias=False))
layers.append(nn.BatchNorm2d(hidden_dim))
layers.append(nn.GELU())
# 深度可分离卷积
layers.extend([
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.GELU()
])
# Squeeze-and-Excitation
if se_ratio > 0:
se_channels = max(1, int(hidden_dim * se_ratio))
layers.append(SEModule(hidden_dim, se_channels))
# 投影
layers.extend([
nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels)
])
self.conv = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_res_connect:
return x + self.conv(x)
return self.conv(x)
class SEModule(nn.Module):
"""Squeeze-and-Excitation模块"""
def __init__(self, channels: int, reduction_channels: int):
super().__init__()
self.fc1 = nn.Conv2d(channels, reduction_channels, 1)
self.fc2 = nn.Conv2d(reduction_channels, channels, 1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
scale = x.mean((2, 3), keepdim=True)
scale = F.gelu(self.fc1(scale))
scale = torch.sigmoid(self.fc2(scale))
return x * scale
class RelativeAttention(nn.Module):
"""
相对位置注意力
引入相对位置偏置,增强平移等变性。
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
window_size: Tuple[int, int] = (7, 7),
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.window_size = window_size
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
# 相对位置偏置表
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
)
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
# 计算位置索引
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += window_size[0] - 1
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
self.register_buffer("relative_position_index", relative_position_index)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = q @ k.transpose(-2, -1)
# 添加相对位置偏置
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(N, N, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
attn = attn + relative_position_bias.unsqueeze(0)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MaxViTBlock(nn.Module):
"""
MaxViT基础块
包含三个连续操作:
1. MBConv局部特征提取
2. 块注意力(Block Attention,局部窗口)
3. 网格注意力(Grid Attention,全局稀疏)
"""
def __init__(
self,
dim: int,
num_heads: int,
window_size: Tuple[int, int] = (7, 7),
grid_size: Tuple[int, int] = (7, 7),
mlp_ratio: float = 4.0,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0
):
super().__init__()
self.dim = dim
self.window_size = window_size
self.grid_size = grid_size
# 1. MBConv
self.mbconv = MBConv(dim, dim, expand_ratio=4)
# 2. 块注意力(局部)
self.norm1 = nn.LayerNorm(dim)
self.block_attn = RelativeAttention(dim, num_heads, window_size, attn_drop=attn_drop, proj_drop=drop)
self.norm2 = nn.LayerNorm(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp1 = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim),
nn.GELU(),
nn.Dropout(drop),
nn.Linear(mlp_hidden_dim, dim),
nn.Dropout(drop)
)
# 3. 网格注意力(全局稀疏)
self.norm3 = nn.LayerNorm(dim)
self.grid_attn = RelativeAttention(dim, num_heads, grid_size, attn_drop=attn_drop, proj_drop=drop)
self.norm4 = nn.LayerNorm(dim)
self.mlp2 = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim),
nn.GELU(),
nn.Dropout(drop),
nn.Linear(mlp_hidden_dim, dim),
nn.Dropout(drop)
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
self.drop_path3 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
def window_partition(self, x: torch.Tensor, window_size: Tuple[int, int]) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""划分窗口"""
B, H, W, C = x.shape
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, window_size[0] * window_size[1], C)
return windows, (H, W)
def window_reverse(self, windows: torch.Tensor, window_size: Tuple[int, int], H: int, W: int) -> torch.Tensor:
"""还原窗口"""
C = windows.shape[-1]
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
return x
def grid_partition(self, x: torch.Tensor, grid_size: Tuple[int, int]) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
划分网格(全局稀疏注意力)
将特征图划分为稀疏的网格单元,实现全局感受野。
"""
B, H, W, C = x.shape
# 重排为网格
x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C)
grids = x.permute(0, 2, 4, 1, 3, 5).contiguous()
grids = grids.view(-1, grid_size[0] * grid_size[1], C)
return grids, (H, W)
def grid_reverse(self, grids: torch.Tensor, grid_size: Tuple[int, int], H: int, W: int) -> torch.Tensor:
"""还原网格"""
C = grids.shape[-1]
x = grids.view(-1, H // grid_size[0], W // grid_size[1], grid_size[0], grid_size[1], C)
x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, H, W, C)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
# 1. MBConv
x = x + self.drop_path1(self.mbconv(x))
# 转换为序列格式
x = x.permute(0, 2, 3, 1) # (B, H, W, C)
# 2. 块注意力(局部窗口)
shortcut = x
x = self.norm1(x)
x_windows, _ = self.window_partition(x, self.window_size)
attn_windows = self.block_attn(x_windows)
x = self.window_reverse(attn_windows, self.window_size, H, W)
x = shortcut + self.drop_path2(x)
x = x + self.drop_path2(self.mlp1(self.norm2(x)))
# 3. 网格注意力(全局)
shortcut = x
x = self.norm3(x)
x_grids, _ = self.grid_partition(x, self.grid_size)
attn_grids = self.grid_attn(x_grids)
x = self.grid_reverse(attn_grids, self.grid_size, H, W)
x = shortcut + self.drop_path3(x)
x = x + self.drop_path3(self.mlp2(self.norm4(x)))
# 转换回卷积格式
x = x.permute(0, 3, 1, 2) # (B, C, H, W)
return x
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth)"""
def __init__(self, drop_prob: float = 0.0):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
output = x.div(keep_prob) * random_tensor
return output
class MaxViT(nn.Module):
"""
MaxViT完整模型
4个阶段的混合架构,每个阶段包含多个MaxViT块。
"""
def __init__(
self,
in_channels: int = 3,
num_classes: int = 1000,
depths: Tuple[int, ...] = (2, 2, 5, 2),
dims: Tuple[int, ...] = (64, 128, 256, 512),
num_heads: Tuple[int, ...] = (2, 4, 8, 16),
window_size: Tuple[int, int] = (7, 7),
grid_size: Tuple[int, int] = (7, 7),
mlp_ratio: float = 4.0,
drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0
):
super().__init__()
self.num_classes = num_classes
# Stem
self.stem = nn.Sequential(
nn.Conv2d(in_channels, dims[0], 3, 2, 1, bias=False),
nn.BatchNorm2d(dims[0]),
nn.GELU(),
nn.Conv2d(dims[0], dims[0], 3, 1, 1, bias=False),
nn.BatchNorm2d(dims[0])
)
# 4个阶段
self.stages = nn.ModuleList()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
cur = 0
for i in range(len(depths)):
stage_blocks = []
for j in range(depths[i]):
stage_blocks.append(
MaxViTBlock(
dim=dims[i],
num_heads=num_heads[i],
window_size=window_size,
grid_size=grid_size,
mlp_ratio=mlp_ratio,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[cur + j]
)
)
self.stages.append(nn.Sequential(*stage_blocks))
cur += depths[i]
# 下采样
if i < len(depths) - 1:
setattr(self, f"downsample_{i+1}",
nn.Sequential(
nn.LayerNorm(dims[i]),
nn.Conv2d(dims[i], dims[i+1], 3, 2, 1, bias=False)
))
# 分类头
self.norm = nn.LayerNorm(dims[-1])
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.stem(x)
for i, stage in enumerate(self.stages):
x = stage(x)
if i < len(self.stages) - 1:
downsample = getattr(self, f"downsample_{i+1}")
x = downsample(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.avgpool(x).flatten(1)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
x = self.head(x)
return x
# 模型配置
def maxvit_tiny(**kwargs):
return MaxViT(
depths=(2, 2, 5, 2),
dims=(64, 128, 256, 512),
num_heads=(2, 4, 8, 16),
**kwargs
)
if __name__ == "__main__":
model = maxvit_tiny(num_classes=1000)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
output = model(x)
print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
print(f"Parameters: {params:.2f}M")
更多推荐
所有评论(0)