:演示代码只是为了易懂,虽然也能跑通,但是你真的做医学图像分割任务,肯定是不行的。后面会上传一个真正论文级的图像分割任务。

https://github.com/Zhoukun357/Unet-based_Medical_Image_Segmentation

Taining Loss基本上是教科书级别,Validation Dice / IoU:整体合理,有一点点“早期不稳定。

几乎没有明显过拟合。这说明:

  • 学习率是合适的(不大不小)

  • 反向传播稳定

  • 网络确实学到了东西

但是模型这个效果肯定不会很好,因为只用了Dice / BCE 。而且是最简单的Unet网络,所以只能达到覆盖肿瘤,但看起来像一堆的效果。这是合理的,你如果有兴趣可以尝试加:边界感知损失,多尺度监督等。

目标

我们要做一个医学图像分割任务:

  • 输入:医学图像(CT / MRI / 超声,灰度)

  • 输出:像素级分割 mask(0/1)

  • 模型:U-Net(标准医学版)

一、什么是「医学分割任务」

1. 医学图像分割 ≠ 分类

可以这样区分:

分类:整张图一个 label

分割:每一个像素一个 label

2. 数据长什么样?

一条训练样本 = 一对图像

image.png   -> 原始医学图像
mask.png    -> 分割标签(0 / 1)

对应张量:

image: [1, H, W]   # 灰度
mask:  [1, H, W]   # 0 or 1

3. 网络学的是什么?

输入:CT 图像
输出:每个像素是「肿瘤 or 背景」的概率

二、完整工程结构

先给一个“全局地图”,否则会迷路

unet_seg/
├── dataset.py        # 数据加载
├── model.py          # U-Net 网络
├── train.py          # 训练循环
├── loss.py           # Dice / BCE
├── utils.py
└── data/
    ├── images/
    └── masks/

三、Dataset:医学分割第一道坎(重点)

1. Dataset 要返回什么?

一句话原则:

Dataset 返回的每一项,必须能直接喂给网络和 loss

2.Dataset 实现

# dataset.py
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
​
class MedicalSegDataset(Dataset):
    def __init__(self, img_dir, mask_dir):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.names = os.listdir(img_dir)
​
        self.img_tf = T.Compose([
            T.Resize((256, 256)),
            T.ToTensor(),   # [1, H, W], [0,1]
        ])
​
        self.mask_tf = T.Compose([
            T.Resize((256, 256)),
            T.ToTensor(),   # 0 or 1
        ])
​
    def __len__(self):
        return len(self.names)
​
    def __getitem__(self, idx):
        name = self.names[idx]
​
        img = Image.open(os.path.join(self.img_dir, name)).convert("L")
        mask = Image.open(os.path.join(self.mask_dir, name)).convert("L")
​
        img = self.img_tf(img)
        mask = self.mask_tf(mask)
​
        return img, mask

你可以重点强调:

  • image 和 mask 必须一一对应

  • mask 不是 RGB

  • 不做乱七八糟增强(先跑通)

四、一个「正经医学版」U-Net

1.DoubleConv(医学分割的灵魂)

# model.py
import torch
import torch.nn as nn
​
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
​
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
​
    def forward(self, x):
        return self.net(x)

2. Encoder / Decoder Block

class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = DoubleConv(in_ch, out_ch)
        self.pool = nn.MaxPool2d(2)
​
    def forward(self, x):
        feat = self.conv(x)
        x = self.pool(feat)
        return x, feat
class Up(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
        self.conv = DoubleConv(in_ch, out_ch)
​
    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        return x

3.完整 U-Net

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
​
        self.down1 = Down(1, 64)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512)
​
        self.bottleneck = DoubleConv(512, 1024)
​
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
​
        self.out = nn.Conv2d(64, 1, kernel_size=1)
​
    def forward(self, x):
        x, f1 = self.down1(x)
        x, f2 = self.down2(x)
        x, f3 = self.down3(x)
        x, f4 = self.down4(x)
​
        x = self.bottleneck(x)
​
        x = self.up1(x, f4)
        x = self.up2(x, f3)
        x = self.up3(x, f2)
        x = self.up4(x, f1)
​
        return torch.sigmoid(self.out(x))

五、Loss

1. Dice Loss

# loss.py
def dice_loss(pred, target, smooth=1e-5):
    pred = pred.view(-1)
    target = target.view(-1)
​
    inter = (pred * target).sum()
    union = pred.sum() + target.sum()
​
    dice = (2 * inter + smooth) / (union + smooth)
    return 1 - dice

提醒:

Dice 关注“重叠程度”,不是像素个数

2.实战中常用组合

loss = BCE + Dice

六、训练循环(跑起来)

# train.py
import torch
from torch.utils.data import DataLoader
from dataset import MedicalSegDataset
from model import UNet
from loss import dice_loss
​
dataset = MedicalSegDataset("data/images", "data/masks")
loader = DataLoader(dataset, batch_size=4, shuffle=True)
​
model = UNet().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
bce = torch.nn.BCELoss()
​
for epoch in range(50):
    model.train()
    total_loss = 0
​
    for img, mask in loader:
        img = img.cuda()
        mask = mask.cuda()
​
        pred = model(img)
        loss = bce(pred, mask) + dice_loss(pred, mask)
​
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
​
        total_loss += loss.item()
​
    print(f"Epoch {epoch}, Loss: {total_loss/len(loader):.4f}")

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐