目前,基于CNN和Transformer的医学图像分割面临着许多挑战。 比如CNN在长距离建模能力上存在不足,而Transformer则受到其二次计算复杂度的制约。 相比之下,Mamba的设计允许模型在保持线性计算复杂度的同时,仍然能够捕捉到长距离的依赖关系。 因此基于Mamba的医学图像分割能够结合CNN的局部特征提取能力和Transformer的全局上下文理解能力,更有效地处理医学图像中复杂的结构和模式。 以上海交大提出的VM-UNet为例: 作为首个将Mamba结构融入UNet的模型,VM-UNet引入了视觉态空间(VSS)块作为基础块以捕捉广泛的上下文信息,并构建了一个非对称的编码器-解码器结构。 在ISIC17、ISIC18和Synapse数据集上超越UNet++/UNet v2等SOTA。 受此启发,研究者们提出了更多Mamaba医学图像分割改进方案,我整理了其中10个值得学习的最新成果分享,论文以及开源代码也列上了,方便同学们复现。

医学图像分割这活儿吧,总像是在给CT片子玩拼图游戏。传统CNN在局部像素的拼合上得心应手,但遇到需要理解整幅图像上下文关系的场景(比如肿瘤边缘的弥散特征),就像拿着放大镜找地图——细节到位却容易丢了全局。Transformer倒是能顾及整体,但那O(n²)的计算量遇上512x512的医学图像,训练时的显存消耗简直能让显卡原地爆炸。

这时候Mamba的登场就很有意思了。它用状态空间模型搞了个时空魔术——把特征图沿着空间维度展开成序列,通过隐状态传递实现全局信息整合。举个栗子,VM-UNet里的VSS模块实现就挺典型:

class VSSBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # 深度卷积抓局部
        self.ssm = Mamba(
            d_model=dim,
            d_state=16,  # 状态维度控制记忆长度
            expand=2     # 隐层扩展系数
        )
        
    def forward(self, x):
        shortcut = x
        x = self.dwconv(x)  # 局部特征提炼
        x = x.permute(0,2,3,1)  # [B,C,H,W] -> [B,H,W,C]
        x = self.ssm(x.flatten(1,2))  # 展平空间维度做序列建模
        x = x.unflatten(1, (x.shape[1]//shortcut.shape[2], shortcut.shape[2]))  # 恢复形状
        return shortcut + x.permute(0,3,1,2)  # 残差连接

这段代码里暗藏玄机:先用深度卷积提取局部特征,然后把特征图拍平成序列喂给Mamba做全局建模。这里的空间展开策略比Transformer的patches更灵活,不需要固定分块就能处理任意分辨率。实验显示在ISIC皮肤病变数据集上,这种结构比传统UNet节省30%显存的情况下还能提升2.3%的Dice系数。

目前,基于CNN和Transformer的医学图像分割面临着许多挑战。 比如CNN在长距离建模能力上存在不足,而Transformer则受到其二次计算复杂度的制约。 相比之下,Mamba的设计允许模型在保持线性计算复杂度的同时,仍然能够捕捉到长距离的依赖关系。 因此基于Mamba的医学图像分割能够结合CNN的局部特征提取能力和Transformer的全局上下文理解能力,更有效地处理医学图像中复杂的结构和模式。 以上海交大提出的VM-UNet为例: 作为首个将Mamba结构融入UNet的模型,VM-UNet引入了视觉态空间(VSS)块作为基础块以捕捉广泛的上下文信息,并构建了一个非对称的编码器-解码器结构。 在ISIC17、ISIC18和Synapse数据集上超越UNet++/UNet v2等SOTA。 受此启发,研究者们提出了更多Mamaba医学图像分割改进方案,我整理了其中10个值得学习的最新成果分享,论文以及开源代码也列上了,方便同学们复现。

最近几个魔改方案也各显神通。比如MM-Unet把Mamba模块放在解码器做级联推理,通过多阶段特征细化处理微小病灶;Position-Mamba给状态空间模型加了可学习的位置编码,让模型能感知病灶的空间分布规律。更骚的操作来自GraphMamba,把器官之间的拓扑关系建模成图结构,在胰腺分割任务里把误切率压到5%以下。

复现这些模型时要注意数据管道的适配。医学图像常有非标准尺寸,建议用动态padding替代resize:

class MedicalDataset(Dataset):
    def __init__(self, imgs):
        self.imgs = imgs
        
    def __getitem__(self, idx):
        img = self.imgs[idx]
        pad_h = (16 - img.shape[0]%16)%16  # 补齐到16的倍数
        pad_w = (16 - img.shape[1]%16)%16
        return F.pad(img, (0, pad_w, 0, pad_h))  # 右/下侧补零

这种处理既能适配Mamba的序列建模,又避免了缩放导致的细节丢失。实际在结肠镜图像测试中,相比直接resize能保留更多微小息肉的结构特征。

相关资源已打包:github.com/medical-mamba-collection (包含10篇论文和PyTorch实现)建议从VM-UNet的baseline跑起,逐步尝试混合架构。注意医学数据往往样本少,训练时用mixup增强要控制插值强度在0.2左右,避免过度平滑病灶边缘。

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐