深度学习:基于自定义 ResNet 的手写数字识别实践(MNIST 数据集)
本次基于自定义 ResNet 的 MNIST 手写数字识别实践,以 “简化、适配、高效” 为核心,既验证了残差网络在简单任务中的有效性,也为初学者提供了 “从模型设计到训练落地” 的完整实战路径。实践表明:残差网络的价值不仅在于 “深层”,更在于 “通过跳跃连接解决训练难题”;针对任务特点的轻量化设计,往往比直接套用复杂预训练模型更具性价比。后续可基于此框架,扩展至更复杂的图像分类任务(如 Fas
目录
一、任务背景与模型选择
手写数字识别是计算机视觉领域的 “入门标杆任务”,其核心是从灰度图像中提取特征并分类(0-9 共 10 类)。常用数据集 MNIST 包含 60000 张训练集和 10000 张测试集,每张图片为 28×28 的灰度图,虽数据规模小、特征简单,但传统 CNN 在深层训练时易出现 “梯度消失” 或 “退化问题”。
本次实践未采用预训练的 ResNet-18(原模型为 3 通道 RGB 输入,适配 MNIST 需额外处理),而是自定义轻量级 ResNet:
- 针对 MNIST 1 通道灰度图设计输入层,避免通道转换冗余;
- 简化残差块结构,减少计算量(适配 CPU/GPU 轻量化训练);
- 保留 ResNet 核心的 “跳跃连接”,解决深层网络训练问题。
该方案的优势在于:无需迁移学习适配,从零训练即可快速收敛,同时让初学者直观理解残差网络的核心逻辑。
二、核心原理:残差块与轻量级 ResNet 设计
ResNet 的核心是残差块(Residual Block),通过 “跳跃连接(Skip Connection)” 让梯度直接回传,避免梯度消失。本次自定义的 ResNet 模型针对 MNIST 数据特点做了 3 点优化:
1. 残差块设计(ResBlock)
传统 ResNet 残差块多为 “3×3 卷积→BN→ReLU” 的组合,本次简化为 “5×3 卷积组合”,在保证特征提取能力的同时减少参数:
- 输入特征图与输出特征图通道数一致(通过
channels_in控制),确保 “跳跃连接” 时可直接元素相加; - 先通过 5×5 卷积扩大感受野(捕捉数字轮廓),再通过 3×3 卷积细化特征,最后与原始输入相加并 ReLU 激活。
残差块前向传播公式:![]()
其中x为原始输入(跳跃连接路径),
为残差路径。
2. 整体网络结构
自定义 ResNet 针对 28×28 灰度图设计,共包含 “特征提取层→残差块→分类层” 三部分,具体结构如下:


三、完整代码实现与逐段解析
本次实践基于 PyTorch 框架,代码包含 “数据加载→模型定义→训练测试→结果输出” 全流程,可直接复制运行(自动下载 MNIST 数据集)。
1. 环境依赖
确保安装以下库(Python 3.8+,PyTorch 1.10+):
pip install torch torchvision matplotlib
2. 完整代码与逐段解析
# -------------------------- 1. 导入必要库 --------------------------
import torch
from torch import nn # 神经网络核心模块
from torch.utils.data import DataLoader # 批量加载数据
from torchvision import datasets # 加载MNIST数据集
from torchvision.transforms import ToTensor # 图像转为Tensor
from matplotlib import pyplot as plt # 可选:可视化数据(本文暂未用)
# -------------------------- 2. 加载并预处理MNIST数据集 --------------------------
# 加载训练集:root为数据保存路径,train=True表示训练集,download=True自动下载
train_data = datasets.MNIST(
root='data', # 数据保存在./data目录下(不存在则自动创建)
train=True,
download=True,
transform=ToTensor(), # 转为Tensor:维度H×W×C→C×H×W,值归一化到0-1
)
# 加载测试集:train=False表示测试集
test_data = datasets.MNIST(
root='data',
train=False,
download=True,
transform=ToTensor(),
)
# 创建数据加载器:按批次加载数据,训练集打乱(shuffle=True)提升泛化能力
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False) # 测试集无需打乱
# 查看数据形状:验证输入格式是否正确
for x, y in test_dataloader:
print(f"输入图像形状 [批次大小, 通道数, 高度, 宽度]: {x.shape}") # 输出:torch.Size([64, 1, 28, 28])
print(f"标签形状: {y.shape},标签数据类型: {y.dtype}") # 输出:torch.Size([64]) torch.int64
break # 仅查看第一个批次
# -------------------------- 3. 配置训练设备(CPU/GPU自动适配) --------------------------
device = (
"cuda" # NVIDIA GPU
if torch.cuda.is_available()
else "mps" # 苹果M系列芯片GPU
if torch.backends.mps.is_available()
else "cpu" # 无GPU则用CPU
)
print(f"使用训练设备: {device}")
# -------------------------- 4. 定义残差块(ResBlock)与完整ResNet模型 --------------------------
# 残差块:ResNet的核心组件,实现跳跃连接
class ResBlock(nn.Module):
def __init__(self, channels_in):
super().__init__() # 继承nn.Module的初始化方法
# 残差路径:两次卷积(5×5→3×3)
self.conv1 = nn.Conv2d(channels_in, 32, 5, padding=2) # 5×5卷积,输出32通道,padding=2保证尺寸不变
self.conv2 = nn.Conv2d(32, channels_in, 3, padding=1) # 3×3卷积,输出通道数与输入一致(适配跳跃连接)
self.relu = nn.ReLU() # 激活函数
def forward(self, x):
# 残差路径计算
out = self.conv1(x)
out = self.conv2(out)
# 跳跃连接:残差路径结果 + 原始输入,再激活
return self.relu(out + x)
# 完整ResNet模型:针对MNIST设计的轻量级版本
class ResNet(nn.Module):
def __init__(self):
super().__init__()
self.relu = nn.ReLU() # 激活函数(复用)
# 初始卷积层:1→64通道,5×5卷积(捕捉大尺度边缘)
self.conv1 = nn.Conv2d(1, 64, 5, 1, 2) # 输入1通道(灰度图),输出64通道,padding=2保证尺寸不变
# 二次卷积层:64→128通道,3×3卷积(细化局部特征)
self.conv2 = nn.Conv2d(64, 128, 3, 1, 1) # padding=1保证尺寸不变
self.maxpool = nn.MaxPool2d(2) # 2×2最大池化(降维减参)
# 残差块:分别对应64通道和128通道的特征图
self.resblock1 = ResBlock(channels_in=64)
self.resblock2 = ResBlock(channels_in=128)
# 全连接层:输入为展平后的特征数(128×7×7=6272),输出10类
self.full_c = nn.Linear(6272, 10)
def forward(self, x):
size = x.shape[0] # 获取批次大小(用于后续展平)
# 第一层:卷积→池化→激活
x = self.maxpool(self.conv1(x)) # Conv1→MaxPool:28×28→14×14
x = self.relu(x)
x = self.resblock1(x) # 残差块1:处理64通道特征图
# 第二层:卷积→池化→激活
x = self.maxpool(self.conv2(x)) # Conv2→MaxPool:14×14→7×7
x = self.relu(x)
x = self.resblock2(x) # 残差块2:处理128通道特征图
# 展平特征图:从[batch, 128, 7, 7]→[batch, 6272]
x = x.view(size, -1) # -1表示自动计算剩余维度
# 全连接层分类
x = self.full_c(x)
return x
# 初始化模型并转移到目标设备
model = ResNet().to(device)
print("\n自定义残差神经网络结构:")
print(model) # 打印模型结构,验证是否正确
# -------------------------- 5. 定义训练与测试函数 --------------------------
def train(dataloader, model, loss_fn, optimizer):
"""训练函数:单轮训练,更新模型参数"""
model.train() # 开启训练模式(启用Dropout/BN更新等)
batch_size_num = 0 # 批次计数器,用于打印日志
for x, y in dataloader:
# 将数据转移到训练设备(CPU/GPU)
x, y = x.to(device), y.to(device)
# 前向传播:计算模型预测值
pred = model(x)
# 计算损失(多分类任务用CrossEntropyLoss)
loss = loss_fn(pred, y)
# 反向传播:更新参数
optimizer.zero_grad() # 梯度清零(避免累积)
loss.backward() # 计算梯度
optimizer.step() # 根据梯度更新参数
# 记录损失值,每100个批次打印一次日志
loss_value = loss.item() # 从Tensor中提取损失值(避免计算图占用内存)
batch_size_num += 1
if batch_size_num % 100 == 0:
print(f"loss: {loss_value:>7f} [batch: {batch_size_num}]")
def test(dataloader, model, loss_fn):
"""测试函数:评估模型在测试集上的性能(无参数更新)"""
size = len(dataloader.dataset) # 测试集总样本数
num_batches = len(dataloader) # 测试集总批次
model.eval() # 开启评估模式(冻结BN/Dropout)
test_loss, correct = 0, 0 # 累计测试损失和正确预测数
# 关闭梯度计算(测试阶段无需反向传播,节省内存和时间)
with torch.no_grad():
for x, y in dataloader:
x, y = x.to(device), y.to(device)
pred = model(x)
# 累计损失(按批次累加)
test_loss += loss_fn(pred, y).item()
# 计算正确预测数:pred.argmax(1)取每行最大值索引(预测类别),与标签比较
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
# 计算平均损失和准确率
test_loss /= num_batches # 平均损失 = 总损失 / 批次数量
correct /= size # 准确率 = 正确数 / 总样本数
print(f"Test result: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")
# -------------------------- 6. 配置训练超参数并启动训练 --------------------------
# 损失函数:多分类任务用CrossEntropyLoss(内置Softmax,无需手动添加)
loss_fn = nn.CrossEntropyLoss()
# 优化器:Adam优化器,学习率0.0001(小学习率避免震荡)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# 训练轮次:50轮(MNIST数据简单,50轮足够收敛)
epochs = 50
# 启动训练循环:每轮训练后测试
for t in range(epochs):
print(f"Epoch {t + 1}\n--------------")
train(train_dataloader, model, loss_fn, optimizer) # 单轮训练
test(test_dataloader, model, loss_fn) # 单轮测试
print("Training Done!") # 训练结束
四、训练过程与结果分析


五、总结
本次基于自定义 ResNet 的 MNIST 手写数字识别实践,以 “简化、适配、高效” 为核心,既验证了残差网络在简单任务中的有效性,也为初学者提供了 “从模型设计到训练落地” 的完整实战路径。实践表明:残差网络的价值不仅在于 “深层”,更在于 “通过跳跃连接解决训练难题”;针对任务特点的轻量化设计,往往比直接套用复杂预训练模型更具性价比。后续可基于此框架,扩展至更复杂的图像分类任务(如 Fashion-MNIST、CIFAR-10),进一步深化对残差网络的理解与应用。
更多推荐
所有评论(0)