一、U-Net 没有新操作

相比于简单的CNN结构U-Net 没有任何新算子

它只用了:

  • Conv

  • ReLU

  • Pooling

  • Upsample

  • Concat

So难点只有一个:结构

二、U-Net 要解决什么问题?

普通 CNN 的问题

卷积 + 池化 → 特征越来越抽象
但:空间信息越来越少

这在分类任务没问题,但在 分割 / 像素级任务 是灾难:

!!! 最后不知道“这个像素属于哪一类”

U-Net 的核心思想就一句话

“抽象语义 + 精确位置,我全都要”

怎么做到?

1.编码器记语义

2.解码器恢复分辨率

3.中间用 skip connection 把空间细节拉回来

三、把 U-Net 画成「一条数据流」

你可以这样想象:

输入图像
   ↓
[ 下采样 ]  →  语义越来越强
   ↓
[ Bottleneck ]
   ↓
[ 上采样 ]  →  分辨率逐渐恢复
   ↓
输出分割图

补一句:

U-Net = 对称的 Encoder + Decoder

四、从「一张图」开始跟踪张量尺寸(关键)

假设输入:

[B, C, H, W] = [1, 1, 256, 256]   # 灰度医学图像

1.Encoder(下采样路径)

Block 1

Conv → Conv → MaxPool
[1, 1, 256, 256]
→ Conv
→ [1, 64, 256, 256]
→ MaxPool
→ [1, 64, 128, 128]

保存一份 feature(跳跃连接用)

Block 2

[1, 64, 128, 128]
→ Conv
→ [1, 128, 128, 128]
→ MaxPool
→ [1, 128, 64, 64]

Block 3 / 4 同理

Encoder 做的事只有一件:空间 ↓,通道 ↑

2. Bottleneck(中间层)

[1, 512, 32, 32]
→ Conv
→ [1, 1024, 32, 32]

这是 最抽象、最“语义化”的表示

3. Decoder(上采样路径)——最容易懵的地方

先讲一句“规则”

Decoder 每一层都做三件事:

(下面的概念可以问一下deepseek,解释的很详细)

  1. 上采样

  2. 拼接 Encoder 的特征

  3. 卷积融合

举一个完整例子

输入:
[1, 1024, 32, 32]
Step 1:上采样
→ Upsample
→ [1, 512, 64, 64]
Step 2:拼接 Encoder 特征

Encoder 里保存的:

[1, 512, 64, 64]

拼接(concat):

[1, 512+512, 64, 64]
= [1, 1024, 64, 64]

!!!拼的是 channel 维度

Step 3:卷积融合

→ Conv
→ [1, 512, 64, 64]

4. Decoder 的一句话总结

Decoder:空间 ×2,通道 ÷2(先翻倍,再压缩)

五、Skip Connection 是什么?

1. 如果没有 skip connection

Decoder 只能靠低分辨率特征猜位置
→ 边界模糊

2. Skip Connection 的本质

Encoder 提供:精确空间信息
Decoder 提供:高级语义信息

!!! 拼在一起,让网络自己学怎么用

3. 为什么不用加法而用 concat?

一句话总结:

concat 不丢信息,加法会混信息

六、用 PyTorch 代码实现一个 U-Net(极简版)

1. Conv Block(最基本模块)

import torch
import torch.nn as nn
​
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
​
    def forward(self, x):
        return self.net(x)

2. Encoder Block

class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = DoubleConv(in_ch, out_ch)
        self.pool = nn.MaxPool2d(2)
​
    def forward(self, x):
        feat = self.conv(x)
        x = self.pool(feat)
        return x, feat  # feat 用于 skip

3. Decoder Block(重点)

class Up(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_ch, out_ch)
​
    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)  # 通道维拼接
        x = self.conv(x)
        return x

4.一句话读懂 forward

x, f1 = down1(x)
x, f2 = down2(x)
​
x = bottleneck(x)
​
x = up1(x, f2)
x = up2(x, f1)

!!! 左边存,右边用

!!!完整版-代码

import torch
import torch.nn as nn

#Unet极速版

#Conv Block
class DoubleConv(nn.Module):
    def __init__(self,in_ch,out_ch):
        super().__init__()
        self.net=nn.Sequential(
            nn.Conv2d(in_ch,out_ch,3,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch,out_ch,3,padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self,x):
        return self.net(x)

#Encoder Block
class Dowm(nn.Module):
    def __init__(self,in_ch,out_ch):
        super().__init__()
        self.conv=DoubleConv(in_ch,out_ch)
        self.pool=nn.MaxPool2d(2)
    
    def forward(self,x):
        feat=self.conv(x)
        x=self.pool(feat)
        return x,feat
    
#Decoder Block
class Up(nn.Module):
    def __init__(self,in_ch,out_ch):
        super().__init__()
        self.up=nn.Upsample(scale_factor=2,mode='bilinear',align_corners=True)
        self.conv=DoubleConv(in_ch,out_ch)
    
    def forward(self,x,skip):
        x=self.up(x)
        x=torch.cat([x,skip],dim=1)
        x=self.conv(x)
        return x

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐