【小白也能学会】用PyTorch实现三层神经网络:代码逐行详解+实战指南
本文详细讲解了如何使用PyTorch实现一个基础的三层神经网络。首先介绍了神经网络的重要性及应用场景,然后通过代码示例展示了如何定义网络结构(784→256→128→64→10),包括初始化方法、全连接层和前向传播过程。重点解释了ReLU激活函数的作用、数据展平的必要性,以及为什么输出层不需要softmax。最后提供了完整的实战代码,帮助读者理解网络构建的核心概念,并为后续学习训练网络、调整结构和
·
本文用最通俗的语言讲解PyTorch神经网络实现,即使你是零基础小白,也能轻松理解!
一、前言:为什么要学习神经网络?
在人工智能时代,神经网络是机器学习的核心技术之一。无论是图像识别、语音助手还是自动驾驶,都离不开神经网络的支持。今天我们就来学习如何使用PyTorch框架实现一个基础的三层神经网络模型。
二、整体代码结构概览
我们先来看一下完整的神经网络代码):
### 改进的三层神经网络模型
class o1Net(nn.Module):
def __init__(self):
super().__init__()
# 定义全连接层
self.fc1 = nn.Linear(28 * 28, 256) # 输入是28x28的灰度图像,输出是256个神经元
self.fc2 = nn.Linear(256, 128) # 第二层,输入256,输出128
self.fc3 = nn.Linear(128, 64) # 第三层,输入128,输出64
self.fc4 = nn.Linear(64, 10) # 最后一层,输入64,输出10个类别
def forward(self, x):
x = torch.flatten(x, start_dim=1) # 展平数据
x = torch.relu(self.fc1(x)) # 第一层 + ReLU激活
x = torch.relu(self.fc2(x)) # 第二层 + ReLU激活
x = torch.relu(self.fc3(x)) # 第三层 + ReLU激活
x = self.fc4(x) # 第四层(输出层)
return x
三、逐行代码详解(小白友好版)
1. 类定义:class o1Net(nn.Module)
nn.Module:这是PyTorch中所有神经网络的基类,就像"汽车"是各种车型的基类一样o1Net:我们自定义的神经网络类名,可以按需修改
2. 初始化方法:__init__
def __init__(self):
super().__init__()
super().__init__():调用父类的初始化方法,这是Python的标准做法
3. 定义网络层:全连接层
self.fc1 = nn.Linear(28 * 28, 256)
nn.Linear:全连接层,就像神经网络中的"神经元"28 * 28:输入大小,对应MNIST数据集的28x28像素图像256:输出大小,即这一层有256个神经元
💡 通俗理解:想象一下,这一层有256个小计算器,每个计算器都会接收784个输入值(28x28),然后进行计算
self.fc2 = nn.Linear(256, 128) # 第二层
self.fc3 = nn.Linear(128, 64) # 第三层
self.fc4 = nn.Linear(64, 10) # 输出层
- 网络结构:784 → 256 → 128 → 64 → 10
- 最后一层输出10个值,对应10个数字类别(0-9)
4. 前向传播:forward方法
def forward(self, x):
x = torch.flatten(x, start_dim=1)
torch.flatten:将二维图像数据"压扁"成一维数组start_dim=1:从第1维度开始展平(跳过batch维度)
🎯 为什么需要展平?
图像本是二维的(28x28),但全连接层需要一维输入,所以需要展平为784个特征
x = torch.relu(self.fc1(x)) # 第一层 + ReLU激活
self.fc1(x):数据通过第一层torch.relu():ReLU激活函数,让神经网络能够学习复杂模式
🔥 ReLU的作用:就像"开关",负数值变为0,正数值保持不变,增加非线性能力
后续几层类似,最后输出层不需要激活函数,因为损失函数会处理。
四、重要概念解析
1. 为什么输出层不需要softmax?
代码中注释说明了:
# nn.CrossEntropyLoss会在内部进行softmax操作
CrossEntropyLoss:常用的分类损失函数,内部已经包含softmax- 如果使用其他损失函数,可能需要手动添加softmax
2. 神经网络层数怎么算?
虽然代码有4个nn.Linear,但通常:
fc1、fc2、fc3是隐藏层(3层)fc4是输出层- 所以这是"三层隐藏层+输出层"的网络
五、完整实战代码示例
import torch
import torch.nn as nn
import torch.optim as optim
# 定义网络
class o1Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28 * 28, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 64)
self.fc4 = nn.Linear(64, 10)
def forward(self, x):
x = torch.flatten(x, 1)
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = torch.relu(self.fc3(x))
x = self.fc4(x)
return x
# 创建网络实例
model = o1Net()
print(model) # 查看网络结构
# 模拟输入数据(随机生成4张28x28图像)
dummy_input = torch.randn(4, 1, 28, 28)
output = model(dummy_input)
print(f"输出形状: {output.shape}") # 应该是[4, 10]
六、总结与下一步学习
通过本文,你应该已经理解了:
- ✅ 如何用PyTorch定义神经网络类
- ✅ 全连接层(
nn.Linear)的作用和参数含义 - ✅ 前向传播的过程和数据变换
- ✅ ReLU激活函数的重要性
- ✅ 为什么输出层不需要softmax
下一步学习建议:
- 学习如何训练这个网络(损失函数、优化器)
- 尝试在MNIST数据集上实际训练和测试
- 调整网络结构(增加/减少层数、神经元数量)
- 尝试不同的激活函数(Sigmoid、Tanh等)
欢迎在评论区留言交流! 如果遇到任何问题,或者想要了解更多内容,欢迎提出,我会尽力解答!记得点赞收藏哦~
更多推荐

所有评论(0)