深度学习第三步:自己能写Unet(基于Unet的图像分割任务)
本文介绍了医学图像分割任务的基本实现方法,主要包括: 任务定义:输入医学灰度图像,输出像素级二值分割mask 工程结构:包含数据加载、U-Net模型、损失函数和训练循环的完整框架 关键技术: 数据预处理:确保图像和mask严格对齐 标准U-Net实现:包含Encoder-Decoder结构和跳跃连接 损失函数:采用BCE+Dice组合损失 训练流程:展示了从数据加载到模型训练的基本过程
:演示代码只是为了易懂,虽然也能跑通,但是你真的做医学图像分割任务,肯定是不行的。后面会上传一个真正论文级的图像分割任务。
https://github.com/Zhoukun357/Unet-based_Medical_Image_Segmentation





Taining Loss基本上是教科书级别,Validation Dice / IoU:整体合理,有一点点“早期不稳定。
几乎没有明显过拟合。这说明:
-
学习率是合适的(不大不小)
-
反向传播稳定
-
网络确实学到了东西
但是模型这个效果肯定不会很好,因为只用了Dice / BCE 。而且是最简单的Unet网络,所以只能达到覆盖肿瘤,但看起来像一堆的效果。这是合理的,你如果有兴趣可以尝试加:边界感知损失,多尺度监督等。
目标
我们要做一个医学图像分割任务:
-
输入:医学图像(CT / MRI / 超声,灰度)
-
输出:像素级分割 mask(0/1)
-
模型:U-Net(标准医学版)
一、什么是「医学分割任务」
1. 医学图像分割 ≠ 分类
可以这样区分:
分类:整张图一个 label
分割:每一个像素一个 label
2. 数据长什么样?
一条训练样本 = 一对图像
image.png -> 原始医学图像 mask.png -> 分割标签(0 / 1)
对应张量:
image: [1, H, W] # 灰度 mask: [1, H, W] # 0 or 1
3. 网络学的是什么?
输入:CT 图像 输出:每个像素是「肿瘤 or 背景」的概率
二、完整工程结构
先给一个“全局地图”,否则会迷路
unet_seg/ ├── dataset.py # 数据加载 ├── model.py # U-Net 网络 ├── train.py # 训练循环 ├── loss.py # Dice / BCE ├── utils.py └── data/ ├── images/ └── masks/
三、Dataset:医学分割第一道坎(重点)
1. Dataset 要返回什么?
一句话原则:
Dataset 返回的每一项,必须能直接喂给网络和 loss
2.Dataset 实现
# dataset.py
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
class MedicalSegDataset(Dataset):
def __init__(self, img_dir, mask_dir):
self.img_dir = img_dir
self.mask_dir = mask_dir
self.names = os.listdir(img_dir)
self.img_tf = T.Compose([
T.Resize((256, 256)),
T.ToTensor(), # [1, H, W], [0,1]
])
self.mask_tf = T.Compose([
T.Resize((256, 256)),
T.ToTensor(), # 0 or 1
])
def __len__(self):
return len(self.names)
def __getitem__(self, idx):
name = self.names[idx]
img = Image.open(os.path.join(self.img_dir, name)).convert("L")
mask = Image.open(os.path.join(self.mask_dir, name)).convert("L")
img = self.img_tf(img)
mask = self.mask_tf(mask)
return img, mask
你可以重点强调:
-
image 和 mask 必须一一对应
-
mask 不是 RGB
-
不做乱七八糟增强(先跑通)
四、一个「正经医学版」U-Net
1.DoubleConv(医学分割的灵魂)
# model.py
import torch
import torch.nn as nn
class DoubleConv(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
)
def forward(self, x):
return self.net(x)
2. Encoder / Decoder Block
class Down(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv = DoubleConv(in_ch, out_ch)
self.pool = nn.MaxPool2d(2)
def forward(self, x):
feat = self.conv(x)
x = self.pool(feat)
return x, feat
class Up(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
self.conv = DoubleConv(in_ch, out_ch)
def forward(self, x, skip):
x = self.up(x)
x = torch.cat([x, skip], dim=1)
x = self.conv(x)
return x
3.完整 U-Net
class UNet(nn.Module):
def __init__(self):
super().__init__()
self.down1 = Down(1, 64)
self.down2 = Down(64, 128)
self.down3 = Down(128, 256)
self.down4 = Down(256, 512)
self.bottleneck = DoubleConv(512, 1024)
self.up1 = Up(1024, 512)
self.up2 = Up(512, 256)
self.up3 = Up(256, 128)
self.up4 = Up(128, 64)
self.out = nn.Conv2d(64, 1, kernel_size=1)
def forward(self, x):
x, f1 = self.down1(x)
x, f2 = self.down2(x)
x, f3 = self.down3(x)
x, f4 = self.down4(x)
x = self.bottleneck(x)
x = self.up1(x, f4)
x = self.up2(x, f3)
x = self.up3(x, f2)
x = self.up4(x, f1)
return torch.sigmoid(self.out(x))
五、Loss
1. Dice Loss
# loss.py
def dice_loss(pred, target, smooth=1e-5):
pred = pred.view(-1)
target = target.view(-1)
inter = (pred * target).sum()
union = pred.sum() + target.sum()
dice = (2 * inter + smooth) / (union + smooth)
return 1 - dice
提醒:
Dice 关注“重叠程度”,不是像素个数
2.实战中常用组合
loss = BCE + Dice
六、训练循环(跑起来)
# train.py
import torch
from torch.utils.data import DataLoader
from dataset import MedicalSegDataset
from model import UNet
from loss import dice_loss
dataset = MedicalSegDataset("data/images", "data/masks")
loader = DataLoader(dataset, batch_size=4, shuffle=True)
model = UNet().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
bce = torch.nn.BCELoss()
for epoch in range(50):
model.train()
total_loss = 0
for img, mask in loader:
img = img.cuda()
mask = mask.cuda()
pred = model(img)
loss = bce(pred, mask) + dice_loss(pred, mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch}, Loss: {total_loss/len(loader):.4f}")
更多推荐
所有评论(0)