深度学习第二步:看懂Unet
U-Net是一种用于图像分割的经典网络结构,其核心创新在于对称的编码器-解码器架构和跳跃连接机制。编码器通过卷积和下采样提取高级语义特征,解码器则通过上采样和跳跃连接恢复空间分辨率。跳跃连接将编码器的低级特征与解码器的高级特征拼接,既保留了空间细节又融合了语义信息。U-Net没有引入新算子,仅使用卷积、ReLU、池化等基础操作,通过精巧的结构设计解决了普通CNN在分割任务中丢失空间信息的问题。Py
·
一、U-Net 没有新操作
相比于简单的CNN结构U-Net 没有任何新算子
它只用了:
-
Conv
-
ReLU
-
Pooling
-
Upsample
-
Concat
So难点只有一个:结构
二、U-Net 要解决什么问题?
普通 CNN 的问题
卷积 + 池化 → 特征越来越抽象 但:空间信息越来越少
这在分类任务没问题,但在 分割 / 像素级任务 是灾难:
!!! 最后不知道“这个像素属于哪一类”
U-Net 的核心思想就一句话
“抽象语义 + 精确位置,我全都要”
怎么做到?
1.编码器记语义
2.解码器恢复分辨率
3.中间用 skip connection 把空间细节拉回来
三、把 U-Net 画成「一条数据流」
你可以这样想象:
输入图像 ↓ [ 下采样 ] → 语义越来越强 ↓ [ Bottleneck ] ↓ [ 上采样 ] → 分辨率逐渐恢复 ↓ 输出分割图
补一句:
U-Net = 对称的 Encoder + Decoder
四、从「一张图」开始跟踪张量尺寸(关键)
假设输入:
[B, C, H, W] = [1, 1, 256, 256] # 灰度医学图像
1.Encoder(下采样路径)
Block 1
Conv → Conv → MaxPool
[1, 1, 256, 256]
→ Conv
→ [1, 64, 256, 256]
→ MaxPool
→ [1, 64, 128, 128]
保存一份 feature(跳跃连接用)
Block 2
[1, 64, 128, 128]
→ Conv
→ [1, 128, 128, 128]
→ MaxPool
→ [1, 128, 64, 64]
Block 3 / 4 同理
Encoder 做的事只有一件:空间 ↓,通道 ↑
2. Bottleneck(中间层)
[1, 512, 32, 32]
→ Conv
→ [1, 1024, 32, 32]
这是 最抽象、最“语义化”的表示
3. Decoder(上采样路径)——最容易懵的地方
先讲一句“规则”
Decoder 每一层都做三件事:
(下面的概念可以问一下deepseek,解释的很详细)
-
上采样
-
拼接 Encoder 的特征
-
卷积融合
举一个完整例子
输入:
[1, 1024, 32, 32]
Step 1:上采样
→ Upsample
→ [1, 512, 64, 64]
Step 2:拼接 Encoder 特征
Encoder 里保存的:
[1, 512, 64, 64]
拼接(concat):
[1, 512+512, 64, 64]
= [1, 1024, 64, 64]
!!!拼的是 channel 维度
Step 3:卷积融合
→ Conv
→ [1, 512, 64, 64]
4. Decoder 的一句话总结
Decoder:空间 ×2,通道 ÷2(先翻倍,再压缩)
五、Skip Connection 是什么?
1. 如果没有 skip connection
Decoder 只能靠低分辨率特征猜位置
→ 边界模糊
2. Skip Connection 的本质
Encoder 提供:精确空间信息 Decoder 提供:高级语义信息
!!! 拼在一起,让网络自己学怎么用
3. 为什么不用加法而用 concat?
一句话总结:
concat 不丢信息,加法会混信息
六、用 PyTorch 代码实现一个 U-Net(极简版)
1. Conv Block(最基本模块)
import torch
import torch.nn as nn
class DoubleConv(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.net(x)
2. Encoder Block
class Down(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv = DoubleConv(in_ch, out_ch)
self.pool = nn.MaxPool2d(2)
def forward(self, x):
feat = self.conv(x)
x = self.pool(feat)
return x, feat # feat 用于 skip
3. Decoder Block(重点)
class Up(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_ch, out_ch)
def forward(self, x, skip):
x = self.up(x)
x = torch.cat([x, skip], dim=1) # 通道维拼接
x = self.conv(x)
return x
4.一句话读懂 forward
x, f1 = down1(x)
x, f2 = down2(x)
x = bottleneck(x)
x = up1(x, f2)
x = up2(x, f1)
!!! 左边存,右边用
!!!完整版-代码
import torch
import torch.nn as nn
#Unet极速版
#Conv Block
class DoubleConv(nn.Module):
def __init__(self,in_ch,out_ch):
super().__init__()
self.net=nn.Sequential(
nn.Conv2d(in_ch,out_ch,3,padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch,out_ch,3,padding=1),
nn.ReLU(inplace=True)
)
def forward(self,x):
return self.net(x)
#Encoder Block
class Dowm(nn.Module):
def __init__(self,in_ch,out_ch):
super().__init__()
self.conv=DoubleConv(in_ch,out_ch)
self.pool=nn.MaxPool2d(2)
def forward(self,x):
feat=self.conv(x)
x=self.pool(feat)
return x,feat
#Decoder Block
class Up(nn.Module):
def __init__(self,in_ch,out_ch):
super().__init__()
self.up=nn.Upsample(scale_factor=2,mode='bilinear',align_corners=True)
self.conv=DoubleConv(in_ch,out_ch)
def forward(self,x,skip):
x=self.up(x)
x=torch.cat([x,skip],dim=1)
x=self.conv(x)
return x
更多推荐
所有评论(0)