目录

一、任务背景与模型选择

二、核心原理:残差块与轻量级 ResNet 设计

1. 残差块设计(ResBlock)

2. 整体网络结构

三、完整代码实现与逐段解析

1. 环境依赖

2. 完整代码与逐段解析

四、训练过程与结果分析

五、总结


一、任务背景与模型选择

手写数字识别是计算机视觉领域的 “入门标杆任务”,其核心是从灰度图像中提取特征并分类(0-9 共 10 类)。常用数据集 MNIST 包含 60000 张训练集和 10000 张测试集,每张图片为 28×28 的灰度图,虽数据规模小、特征简单,但传统 CNN 在深层训练时易出现 “梯度消失” 或 “退化问题”。

本次实践未采用预训练的 ResNet-18(原模型为 3 通道 RGB 输入,适配 MNIST 需额外处理),而是自定义轻量级 ResNet

  • 针对 MNIST 1 通道灰度图设计输入层,避免通道转换冗余;
  • 简化残差块结构,减少计算量(适配 CPU/GPU 轻量化训练);
  • 保留 ResNet 核心的 “跳跃连接”,解决深层网络训练问题。

该方案的优势在于:无需迁移学习适配,从零训练即可快速收敛,同时让初学者直观理解残差网络的核心逻辑。

二、核心原理:残差块与轻量级 ResNet 设计

ResNet 的核心是残差块(Residual Block),通过 “跳跃连接(Skip Connection)” 让梯度直接回传,避免梯度消失。本次自定义的 ResNet 模型针对 MNIST 数据特点做了 3 点优化:

1. 残差块设计(ResBlock)

传统 ResNet 残差块多为 “3×3 卷积→BN→ReLU” 的组合,本次简化为 “5×3 卷积组合”,在保证特征提取能力的同时减少参数:

  • 输入特征图与输出特征图通道数一致(通过channels_in控制),确保 “跳跃连接” 时可直接元素相加;
  • 先通过 5×5 卷积扩大感受野(捕捉数字轮廓),再通过 3×3 卷积细化特征,最后与原始输入相加并 ReLU 激活。

残差块前向传播公式:

其中x为原始输入(跳跃连接路径),为残差路径。

2. 整体网络结构

自定义 ResNet 针对 28×28 灰度图设计,共包含 “特征提取层→残差块→分类层” 三部分,具体结构如下:

三、完整代码实现与逐段解析

本次实践基于 PyTorch 框架,代码包含 “数据加载→模型定义→训练测试→结果输出” 全流程,可直接复制运行(自动下载 MNIST 数据集)。

1. 环境依赖

确保安装以下库(Python 3.8+,PyTorch 1.10+):

pip install torch torchvision matplotlib

2. 完整代码与逐段解析

# -------------------------- 1. 导入必要库 --------------------------
import torch
from torch import nn  # 神经网络核心模块
from torch.utils.data import DataLoader  # 批量加载数据
from torchvision import datasets  # 加载MNIST数据集
from torchvision.transforms import ToTensor  # 图像转为Tensor
from matplotlib import pyplot as plt  # 可选:可视化数据(本文暂未用)

# -------------------------- 2. 加载并预处理MNIST数据集 --------------------------
# 加载训练集:root为数据保存路径,train=True表示训练集,download=True自动下载
train_data = datasets.MNIST(
    root='data',  # 数据保存在./data目录下(不存在则自动创建)
    train=True,
    download=True,
    transform=ToTensor(),  # 转为Tensor:维度H×W×C→C×H×W,值归一化到0-1
)

# 加载测试集:train=False表示测试集
test_data = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor(),
)

# 创建数据加载器:按批次加载数据,训练集打乱(shuffle=True)提升泛化能力
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)  # 测试集无需打乱

# 查看数据形状:验证输入格式是否正确
for x, y in test_dataloader:
    print(f"输入图像形状 [批次大小, 通道数, 高度, 宽度]: {x.shape}")  # 输出:torch.Size([64, 1, 28, 28])
    print(f"标签形状: {y.shape},标签数据类型: {y.dtype}")  # 输出:torch.Size([64]) torch.int64
    break  # 仅查看第一个批次

# -------------------------- 3. 配置训练设备(CPU/GPU自动适配) --------------------------
device = (
    "cuda"  # NVIDIA GPU
    if torch.cuda.is_available()
    else "mps"  # 苹果M系列芯片GPU
    if torch.backends.mps.is_available()
    else "cpu"  # 无GPU则用CPU
)
print(f"使用训练设备: {device}")

# -------------------------- 4. 定义残差块(ResBlock)与完整ResNet模型 --------------------------
# 残差块:ResNet的核心组件,实现跳跃连接
class ResBlock(nn.Module):
    def __init__(self, channels_in):
        super().__init__()  # 继承nn.Module的初始化方法
        # 残差路径:两次卷积(5×5→3×3)
        self.conv1 = nn.Conv2d(channels_in, 32, 5, padding=2)  # 5×5卷积,输出32通道,padding=2保证尺寸不变
        self.conv2 = nn.Conv2d(32, channels_in, 3, padding=1)  # 3×3卷积,输出通道数与输入一致(适配跳跃连接)
        self.relu = nn.ReLU()  # 激活函数

    def forward(self, x):
        # 残差路径计算
        out = self.conv1(x)
        out = self.conv2(out)
        # 跳跃连接:残差路径结果 + 原始输入,再激活
        return self.relu(out + x)

# 完整ResNet模型:针对MNIST设计的轻量级版本
class ResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()  # 激活函数(复用)
        # 初始卷积层:1→64通道,5×5卷积(捕捉大尺度边缘)
        self.conv1 = nn.Conv2d(1, 64, 5, 1, 2)  # 输入1通道(灰度图),输出64通道,padding=2保证尺寸不变
        # 二次卷积层:64→128通道,3×3卷积(细化局部特征)
        self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)  # padding=1保证尺寸不变
        self.maxpool = nn.MaxPool2d(2)  # 2×2最大池化(降维减参)
        # 残差块:分别对应64通道和128通道的特征图
        self.resblock1 = ResBlock(channels_in=64)
        self.resblock2 = ResBlock(channels_in=128)
        # 全连接层:输入为展平后的特征数(128×7×7=6272),输出10类
        self.full_c = nn.Linear(6272, 10)

    def forward(self, x):
        size = x.shape[0]  # 获取批次大小(用于后续展平)
        # 第一层:卷积→池化→激活
        x = self.maxpool(self.conv1(x))  # Conv1→MaxPool:28×28→14×14
        x = self.relu(x)
        x = self.resblock1(x)  # 残差块1:处理64通道特征图
        # 第二层:卷积→池化→激活
        x = self.maxpool(self.conv2(x))  # Conv2→MaxPool:14×14→7×7
        x = self.relu(x)
        x = self.resblock2(x)  # 残差块2:处理128通道特征图
        # 展平特征图:从[batch, 128, 7, 7]→[batch, 6272]
        x = x.view(size, -1)  # -1表示自动计算剩余维度
        # 全连接层分类
        x = self.full_c(x)
        return x

# 初始化模型并转移到目标设备
model = ResNet().to(device)
print("\n自定义残差神经网络结构:")
print(model)  # 打印模型结构,验证是否正确

# -------------------------- 5. 定义训练与测试函数 --------------------------
def train(dataloader, model, loss_fn, optimizer):
    """训练函数:单轮训练,更新模型参数"""
    model.train()  # 开启训练模式(启用Dropout/BN更新等)
    batch_size_num = 0  # 批次计数器,用于打印日志
    for x, y in dataloader:
        # 将数据转移到训练设备(CPU/GPU)
        x, y = x.to(device), y.to(device)
        # 前向传播:计算模型预测值
        pred = model(x)
        # 计算损失(多分类任务用CrossEntropyLoss)
        loss = loss_fn(pred, y)

        # 反向传播:更新参数
        optimizer.zero_grad()  # 梯度清零(避免累积)
        loss.backward()  # 计算梯度
        optimizer.step()  # 根据梯度更新参数

        # 记录损失值,每100个批次打印一次日志
        loss_value = loss.item()  # 从Tensor中提取损失值(避免计算图占用内存)
        batch_size_num += 1
        if batch_size_num % 100 == 0:
            print(f"loss: {loss_value:>7f} [batch: {batch_size_num}]")

def test(dataloader, model, loss_fn):
    """测试函数:评估模型在测试集上的性能(无参数更新)"""
    size = len(dataloader.dataset)  # 测试集总样本数
    num_batches = len(dataloader)  # 测试集总批次
    model.eval()  # 开启评估模式(冻结BN/Dropout)
    test_loss, correct = 0, 0  # 累计测试损失和正确预测数

    # 关闭梯度计算(测试阶段无需反向传播,节省内存和时间)
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            # 累计损失(按批次累加)
            test_loss += loss_fn(pred, y).item()
            # 计算正确预测数:pred.argmax(1)取每行最大值索引(预测类别),与标签比较
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    # 计算平均损失和准确率
    test_loss /= num_batches  # 平均损失 = 总损失 / 批次数量
    correct /= size  # 准确率 = 正确数 / 总样本数
    print(f"Test result: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")

# -------------------------- 6. 配置训练超参数并启动训练 --------------------------
# 损失函数:多分类任务用CrossEntropyLoss(内置Softmax,无需手动添加)
loss_fn = nn.CrossEntropyLoss()
# 优化器:Adam优化器,学习率0.0001(小学习率避免震荡)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# 训练轮次:50轮(MNIST数据简单,50轮足够收敛)
epochs = 50

# 启动训练循环:每轮训练后测试
for t in range(epochs):
    print(f"Epoch {t + 1}\n--------------")
    train(train_dataloader, model, loss_fn, optimizer)  # 单轮训练
    test(test_dataloader, model, loss_fn)  # 单轮测试
print("Training Done!")  # 训练结束

四、训练过程与结果分析

五、总结

本次基于自定义 ResNet 的 MNIST 手写数字识别实践,以 “简化、适配、高效” 为核心,既验证了残差网络在简单任务中的有效性,也为初学者提供了 “从模型设计到训练落地” 的完整实战路径。实践表明:残差网络的价值不仅在于 “深层”,更在于 “通过跳跃连接解决训练难题”;针对任务特点的轻量化设计,往往比直接套用复杂预训练模型更具性价比。后续可基于此框架,扩展至更复杂的图像分类任务(如 Fashion-MNIST、CIFAR-10),进一步深化对残差网络的理解与应用。

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐