深度学习的数学原理(十八)—— 视觉Transformer(ViT)
ViT通过将图像分块为序列并引入Transformer架构,突破了CNN的局部感受野限制。其核心设计包括:1)图像分块嵌入,将2D图像转为1D序列;2)类别嵌入实现全局信息聚合;3)多头自注意力机制建立像素间全局关联。与CNN相比,ViT的自注意力权重动态适应不同区域,计算复杂度为O(N²)。实验表明,ViT在CIFAR-10上展现出优于ResNet的性能,验证了Transformer在视觉任务中
之前的文章中,我们明确了CNN的核心局限性:受局部连接+滑动窗口约束,即使结合FPN多尺度融合,也无法高效捕捉全局信息。视觉Transformer(Vision Transformer, ViT)的出现彻底打破了这一桎梏——它将图像转化为序列,通过自注意力机制直接建立全局像素间的关联,无需依赖多层卷积堆叠扩大感受野。
本文及下篇文章将先推导ViT适配视觉任务的核心设计(图像分块、位置编码),再完整拆解Transformer的核心组件(多头自注意力、前馈网络、层归一化),最后实现简易ViT并对比其与ResNet在CIFAR-10上的表现。ViT仅根据已有资料来进行Transformer尝试,从实践说明Transformer相较于CNN的优势
本篇暂时不会做复杂数学推导,目的仅为了展示Transformer效果,后续会详细拆解每一部分和其数学原理。
一、ViT的提出背景:从CNN到Transformer的视觉任务适配
CNN处理图像的核心是空间结构优先,而Transformer的设计初衷是处理序列数据(如文本)。ViT的核心创新是将2D图像转化为1D序列,让Transformer能直接适配视觉任务,其逻辑围绕如何保留图像的空间信息展开。
1.1 图像分块嵌入的数学原理
对于尺寸为H×W×CH \times W \times CH×W×C的图像(如CIFAR-10的32×32×332 \times 32 \times 332×32×3),ViT首先将其划分为固定大小的非重叠块(Patch),这一过程的数学定义为:
(1)分块操作的数学表达
设分块尺寸为P×PP \times PP×P,则图像可划分为N=H×WP2N = \frac{H \times W}{P^2}N=P2H×W个块,每个块的尺寸为P×P×CP \times P \times CP×P×C。以CIFAR-10为例,若P=4P=4P=4,则:
N=32×324×4=64 N = \frac{32 \times 32}{4 \times 4} = 64 N=4×432×32=64
每个块的尺寸为4×4×3=484 \times 4 \times 3 = 484×4×3=48维。
分块操作可视为一种「硬编码的卷积」:用步长为PPP、尺寸为P×PP \times PP×P的卷积核对图像进行无重叠滑动,数学上与卷积的区别是:
- 卷积:参数可学习,输出通道数可自定义;
- 图像分块:无参数,输出维度固定为P2×CP^2 \times CP2×C。
(2)块嵌入(Patch Embedding)
将每个P×P×CP \times P \times CP×P×C的块展平为1D向量(维度为Dpatch=P2×CD_{patch} = P^2 \times CDpatch=P2×C),再通过线性层映射到模型的隐藏维度DDD,数学表达式为:
zi0=E⋅xi+be(i=1,2,...,N) \mathbf{z}_i^0 = \mathbf{E} \cdot \mathbf{x}_i + \mathbf{b}_e \quad (i=1,2,...,N) zi0=E⋅xi+be(i=1,2,...,N)
其中:
- xi∈RDpatch\mathbf{x}_i \in \mathbb{R}^{D_{patch}}xi∈RDpatch:第iii个展平后的块向量;
- E∈RD×Dpatch\mathbf{E} \in \mathbb{R}^{D \times D_{patch}}E∈RD×Dpatch:嵌入矩阵(线性层权重);
- be∈RD\mathbf{b}_e \in \mathbb{R}^Dbe∈RD:嵌入偏置;
- zi0∈RD\mathbf{z}_i^0 \in \mathbb{R}^Dzi0∈RD:第iii个块的嵌入向量(Transformer的输入序列元素)。
(3)分块尺寸的选择逻辑
分块尺寸PPP的选择需平衡局部信息保留与序列长度,数学上需满足:
- PPP过小:NNN过大(序列过长),自注意力的计算复杂度呈O(N2)O(N^2)O(N2)增长;
- PPP过大:每个块的局部信息不足,丢失细粒度特征;
- 经验公式:P=H×WNtargetP = \sqrt{\frac{H \times W}{N_{target}}}P=NtargetH×W,其中NtargetN_{target}Ntarget为目标序列长度(通常取16/32/64)。
以CIFAR-10(32×32)为例,若目标序列长度Ntarget=64N_{target}=64Ntarget=64,则P=4P=4P=4;若Ntarget=16N_{target}=16Ntarget=16,则P=8P=8P=8。
1.2 类别嵌入与序列构建
为让Transformer能完成分类任务,ViT在序列开头添加一个可学习的类别嵌入(Class Token)z00∈RD\mathbf{z}_0^0 \in \mathbb{R}^Dz00∈RD,最终输入序列为:
[z00;z10;z20;...;zN0]∈R(N+1)×D [\mathbf{z}_0^0; \mathbf{z}_1^0; \mathbf{z}_2^0; ...; \mathbf{z}_N^0] \in \mathbb{R}^{(N+1) \times D} [z00;z10;z20;...;zN0]∈R(N+1)×D
其中「;」表示拼接操作。分类时仅取类别嵌入的最终输出作为特征,数学上等价于让Transformer聚焦全局信息的汇总结果。
Class Token 本质就是对所有 patch 做自适应参数的加权平均这种设计其实是沿用了Transformer中的[CLS] token
- NLP 中 [CLS] token:汇总整句话的语义信息,用于文本分类;
- ViT 中 Class Token:汇总整张图的视觉信息,用于图像分类;
二、Transformer核心组件①:多头自注意力(MSA)的数学推导
自注意力(Self-Attention)是Transformer的核心,其数学本质是通过计算序列元素间的关联权重,实现全局信息的加权融合,而多头自注意力(Multi-Head Self-Attention, MSA)则进一步提升了特征的表达能力。
2.1 基础自注意力的完整公式
自注意力的输入是序列Z∈R(N+1)×D\mathbf{Z} \in \mathbb{R}^{(N+1) \times D}Z∈R(N+1)×D,输出是与输入维度相同的融合特征,其计算分为三步:
(1)查询/键/值(Q/K/V)映射
将输入序列分别映射到查询(Query)、键(Key)、值(Value)空间,数学表达式为:
Q=Z⋅WQ,K=Z⋅WK,V=Z⋅WV \mathbf{Q} = \mathbf{Z} \cdot \mathbf{W}_Q, \quad \mathbf{K} = \mathbf{Z} \cdot \mathbf{W}_K, \quad \mathbf{V} = \mathbf{Z} \cdot \mathbf{W}_V Q=Z⋅WQ,K=Z⋅WK,V=Z⋅WV
其中:
- WQ,WK,WV∈RD×D\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V \in \mathbb{R}^{D \times D}WQ,WK,WV∈RD×D:可学习的映射矩阵;
- Q,K,V∈R(N+1)×D\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{(N+1) \times D}Q,K,V∈R(N+1)×D:查询、键、值矩阵。
(2)注意力得分与权重计算
计算每个查询与所有键的相似度(得分),并归一化为权重,数学表达式为:
Attention(Q,K,V)=Softmax(QKTdk)V \text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Softmax}\left( \frac{\mathbf{Q} \mathbf{K}^T}{\sqrt{d_k}} \right) \mathbf{V} Attention(Q,K,V)=Softmax(dkQKT)V
其中:
- QKT∈R(N+1)×(N+1)\mathbf{Q} \mathbf{K}^T \in \mathbb{R}^{(N+1) \times (N+1)}QKT∈R(N+1)×(N+1):注意力得分矩阵,元素aija_{ij}aij表示第iii个元素对第jjj个元素的关注度;
- dk\sqrt{d_k}dk:缩放因子(dk=D/hd_k = D/hdk=D/h,hhh为头数),用于缓解维度DDD过大导致的得分值爆炸;
- Softmax\text{Softmax}Softmax:行归一化,让每个元素的注意力权重之和为1;
- 最终输出:值矩阵的加权和,实现全局信息融合。
(3)与CNN局部特征提取的对比
| 特征提取方式 | 关联范围 | 权重特性 | 计算复杂度 |
|---|---|---|---|
| CNN卷积 | 局部(卷积核尺寸) | 共享权重(滑动窗口) | O(K2CinCoutHW)O(K^2 C_{in} C_{out} HW)O(K2CinCoutHW) |
| 自注意力 | 全局(整个序列) | 动态权重(逐元素学习) | O(N2D)O(N^2 D)O(N2D) |
核心差异:CNN的权重是空间共享的,而自注意力的权重是动态自适应的——对于图像中的不同区域,自注意力能学习到不同的关联权重,这是其全局特征捕捉能力的核心。
2.2 多头自注意力的数学逻辑
多头自注意力将自注意力拆分为hhh个并行的头,每个头关注不同的特征维度,最终拼接融合,数学流程为:
(1)分拆头
将Q,K,V\mathbf{Q}, \mathbf{K}, \mathbf{V}Q,K,V按维度DDD拆分为hhh个头,每个头的维度为dk=D/hd_k = D/hdk=D/h:
Qi=Q⋅WQi,Ki=K⋅WKi,Vi=V⋅WVi(i=1..h) \mathbf{Q}_i = \mathbf{Q} \cdot \mathbf{W}_{Q_i}, \quad \mathbf{K}_i = \mathbf{K} \cdot \mathbf{W}_{K_i}, \quad \mathbf{V}_i = \mathbf{V} \cdot \mathbf{W}_{V_i} \quad (i=1..h) Qi=Q⋅WQi,Ki=K⋅WKi,Vi=V⋅WVi(i=1..h)
其中WQi,WKi,WVi∈RD×dk\mathbf{W}_{Q_i}, \mathbf{W}_{K_i}, \mathbf{W}_{V_i} \in \mathbb{R}^{D \times d_k}WQi,WKi,WVi∈RD×dk。
(2)多头计算与拼接
对每个头计算自注意力,再将结果拼接并线性映射:
MSA(Z)=Concat(Attention1,...,Attentionh)⋅WO \text{MSA}(\mathbf{Z}) = \text{Concat}(\text{Attention}_1, ..., \text{Attention}_h) \cdot \mathbf{W}_O MSA(Z)=Concat(Attention1,...,Attentionh)⋅WO
其中WO∈RD×D\mathbf{W}_O \in \mathbb{R}^{D \times D}WO∈RD×D为输出映射矩阵,Concat\text{Concat}Concat为拼接操作。
数学意义:多头自注意力让模型能同时捕捉不同类型的全局关联(如颜色关联、纹理关联、形状关联),提升特征的多样性。
三、Transformer核心组件②:FFN与LN的数学设计
除了MSA,Transformer的编码器还包含前馈网络(FFN)和层归一化(Layer Normalization, LN),前者实现特征的非线性变换,后者解决训练不稳定问题。
3.1 层归一化(LN)的数学原理
CNN中常用的Batch Normalization(BN)是批次维度归一化,而Transformer适配序列任务,采用层归一化(LN)——对每个序列元素的特征维度归一化,数学表达式为:
(1)LN的计算公式
LN(z)=γ⋅z−μσ2+ϵ+β \text{LN}(\mathbf{z}) = \gamma \cdot \frac{\mathbf{z} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta LN(z)=γ⋅σ2+ϵz−μ+β
其中:
- z∈RD\mathbf{z} \in \mathbb{R}^Dz∈RD:单个序列元素的特征向量;
- μ=1D∑i=1Dzi\mu = \frac{1}{D} \sum_{i=1}^D z_iμ=D1∑i=1Dzi:特征维度的均值;
- σ2=1D∑i=1D(zi−μ)2\sigma^2 = \frac{1}{D} \sum_{i=1}^D (z_i - \mu)^2σ2=D1∑i=1D(zi−μ)2:特征维度的方差;
- γ,β∈RD\gamma, \beta \in \mathbb{R}^Dγ,β∈RD:可学习的缩放和平移参数;
- ϵ\epsilonϵ:防止除0的小常数。
(2)LN与BN的核心差异
| 归一化方式 | 归一化维度 | 适用场景 | 缺陷 |
|---|---|---|---|
| BN | 批次维度(同一批次的所有样本) | 图像任务(CNN) | 依赖批次大小,小批次效果差 |
| LN | 特征维度(单个样本的特征) | 序列任务(Transformer) | 对特征分布的鲁棒性稍弱 |
适配性解释:Transformer处理的序列长度NNN可能变化(如不同图像的分块数不同),而LN不依赖批次维度,更适合序列数据的动态特性。
3.2 前馈网络(FFN)的数学逻辑
FFN是对每个序列元素的独立非线性变换,数学表达式为:
FFN(z)=max(0,z⋅W1+b1)⋅W2+b2 \text{FFN}(\mathbf{z}) = \max(0, \mathbf{z} \cdot \mathbf{W}_1 + \mathbf{b}_1) \cdot \mathbf{W}_2 + \mathbf{b}_2 FFN(z)=max(0,z⋅W1+b1)⋅W2+b2
其中:
- W1∈RD×Dff\mathbf{W}_1 \in \mathbb{R}^{D \times D_{ff}}W1∈RD×Dff,W2∈RDff×D\mathbf{W}_2 \in \mathbb{R}^{D_{ff} \times D}W2∈RDff×D(通常Dff=4DD_{ff}=4DDff=4D);
- max(0,⋅)\max(0, \cdot)max(0,⋅):ReLU激活函数,实现非线性;
- 核心逻辑:先升维再降维,扩大特征的表达空间,同时保持序列长度不变。
与CNN的对比:CNN的非线性变换是「空间共享」的(卷积+激活),而FFN是「逐元素独立」的,更灵活但参数更多。
3.3 残差连接的复用
与ResNet类似,Transformer的编码器也使用残差连接,核心公式为:
z′=LN(z+MSA(LN(z))) \mathbf{z}' = \text{LN}(\mathbf{z} + \text{MSA}(\text{LN}(\mathbf{z}))) z′=LN(z+MSA(LN(z)))
z′′=LN(z′+FFN(z′)) \mathbf{z}'' = \text{LN}(\mathbf{z}' + \text{FFN}(\mathbf{z}')) z′′=LN(z′+FFN(z′))
这一设计保证了梯度的稳定传递,与ResNet的残差连接逻辑完全一致——这也是深层Transformer能稳定训练的核心原因。
关键点回顾
本文简要说明了Transformer(ViT)的各个组件,下一篇文章会尝试组合其各个组件,并训练一个简单的ViT对比ResNet的效果
- ViT的核心适配逻辑:图像分块(N=H×W/P2N = H×W/P²N=H×W/P2)→ 块嵌入(线性映射)→ 序列构建(添加类别嵌入+位置编码);
- 多头自注意力的核心公式:MSA=Concat(Attention1,...,Attentionh)⋅WO\text{MSA} = \text{Concat}(\text{Attention}_1,...,\text{Attention}_h) \cdot W_OMSA=Concat(Attention1,...,Attentionh)⋅WO,动态权重实现全局特征融合;
- LN与BN的核心差异:LN对特征维度归一化,适配序列任务;BN对批次维度归一化,适配图像任务;
- ViT的优势在于全局特征捕捉能力,但其训练依赖数据量,在小数据集上需结合数据增强/迁移学习。
更多推荐
所有评论(0)