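Image segmentation and semantic segmentation with UNet, DeepLabv3, FCN, ResNet and related networks, built on PyTorch. A complete project with network models, training code, and prediction code: download a dataset and it runs, ready to use with minimal setup.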

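Image segmentation has long been a hot topic in computer vision. Semantic, instance, and panoptic segmentation all have wide application, for example in medical image analysis and scene understanding for autonomous driving. This post walks through a practical, complete image segmentation project on PyTorch, using the classic UNet, DeepLabv3, and FCN models together with a ResNet backbone. Once the dataset is downloaded, the project is meant to run as-is.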

I. Network Models

(1) UNet

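UNet is the best-known model in medical image segmentation, and its architecture is shaped like the letter U. It consists of a contracting path (downsampling) and an expanding path (upsampling). The contracting path extracts features step by step, so the feature maps shrink spatially while the number of channels grows. The expanding path restores the small feature maps to the original image size and, at each level, merges in the features saved from the contracting path through skip connections, which is what makes precise segmentation possible.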
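To make the skip connection concrete, here is a minimal sketch of a single down/up step. The layer sizes and the 64x64 input are arbitrary values chosen for illustration, not part of the project code.

```python
import torch
import torch.nn as nn

# Toy input; the size and channel counts below are illustrative assumptions.
x = torch.randn(1, 3, 64, 64)

enc = nn.Conv2d(3, 16, kernel_size=3, padding=1)    # encoder feature at full resolution
pool = nn.MaxPool2d(kernel_size=2, stride=2)        # contracting path: halves H and W
deep = nn.Conv2d(16, 32, kernel_size=3, padding=1)  # deeper feature with more channels
up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # expanding path

skip = enc(x)              # shape (1, 16, 64, 64)
bottom = deep(pool(skip))  # shape (1, 32, 32, 32)
restored = up(bottom)      # shape (1, 32, 64, 64)

# Skip connection: concatenate encoder and decoder features along the channel axis
merged = torch.cat([skip, restored], dim=1)
print(merged.shape)        # torch.Size([1, 48, 64, 64])
```

The concatenated tensor carries both the fine spatial detail from the encoder and the coarser, more semantic features from deeper in the network.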

Below is a simplified UNet code example that shows only the core structure:

import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            # x1 arrives with in_channels channels; halve them so the concat with x2 totals in_channels
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
                                     diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

(2) DeepLabv3

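DeepLabv3 is a strong semantic segmentation model. It relies on atrous convolution (also called dilated convolution) to enlarge the receptive field without reducing resolution, and it adds an Atrous Spatial Pyramid Pooling (ASPP) module that extracts features at several scales so that objects of different sizes are handled better.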
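As a small illustration of atrous convolution, the sketch below applies a 3x3 convolution at several dilation rates. The rates, channel count, and input size are assumptions picked for the example. With `padding` equal to `dilation`, the spatial size is unchanged while the area covered by the kernel grows.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 8, 33, 33)  # toy feature map; the sizes are illustrative

shapes = []
for rate in (1, 6, 12, 18):    # example dilation rates, chosen for illustration
    conv = nn.Conv2d(8, 8, kernel_size=3, padding=rate, dilation=rate, bias=False)
    y = conv(x)
    # A 3x3 kernel with dilation d spans 2*d + 1 pixels along each axis
    extent = 2 * rate + 1
    shapes.append(tuple(y.shape))
    print(f'dilation={rate:2d}  kernel extent={extent:2d}  output={tuple(y.shape)}')
```

Every output keeps the 33x33 resolution, which is why ASPP can concatenate its parallel branches directly.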


Below is a simple implementation of the ASPP module:

class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels, atrous_rates):
        super(ASPP, self).__init__()
        self.aspp = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            ),
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=atrous_rates[0],
                          dilation=atrous_rates[0], bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            ),
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=atrous_rates[1],
                          dilation=atrous_rates[1], bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            ),
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=atrous_rates[2],
                          dilation=atrous_rates[2], bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
        ])
        self.global_pooling = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.conv1 = nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x1 = self.aspp[0](x)
        x2 = self.aspp[1](x)
        x3 = self.aspp[2](x)
        x4 = self.aspp[3](x)
        x5 = self.global_pooling(x)
        x5 = nn.functional.interpolate(x5, size=x.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        return self.dropout(x)

(3) FCN

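The Fully Convolutional Network (FCN) pioneered end-to-end semantic segmentation. It replaces the fully connected layers at the end of a conventional convolutional network with convolutional layers and outputs the segmentation result directly. Because no layer depends on a fixed input size, it can process images of arbitrary size, whereas earlier classification-style networks required the image to be cropped or warped to a fixed size.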

class FCN(nn.Module):
    def __init__(self, num_classes):
        super(FCN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
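            # More convolution and pooling layers, partly omitted here.
            # The omitted layers must end with 512 channels to match score_fr below.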
        )
        self.score_fr = nn.Conv2d(512, num_classes, kernel_size=1)
        self.upscore = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=64, stride=32, bias=False)

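    def forward(self, x):
        in_h, in_w = x.shape[2:]  # remember the input size for the final crop
        x = self.features(x)
        x = self.score_fr(x)
        x = self.upscore(x)
        # Crop the upsampled score map back to the input size (offset 19 kept from the original code)
        x = x[:, :, 19:19 + in_h, 19:19 + in_w].contiguous()
        return x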
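The claim that a fully convolutional network accepts images of any size can be checked with a toy model. The `TinyFCN` below is a hypothetical, much smaller network written only for this demonstration. It upsamples with bilinear interpolation instead of the transposed convolution used above, so it is a sketch of the idea rather than the project's FCN.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyFCN(nn.Module):
    """Hypothetical toy model: conv features, 1x1 classifier, bilinear upsampling."""

    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # A 1x1 convolution plays the role of the former fully connected classifier
        self.classifier = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        size = x.shape[2:]
        x = self.classifier(self.features(x))
        # Restore the input resolution so every pixel gets a class score
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)


model = TinyFCN(num_classes=21).eval()
out_shapes = []
with torch.no_grad():
    for h, w in [(64, 64), (97, 130)]:
        out = model(torch.randn(1, 3, h, w))
        out_shapes.append(tuple(out.shape))
        print((h, w), '->', tuple(out.shape))
```

Both inputs, including the odd-sized one, produce a score map with the same height and width as the input.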

(4) ResNet

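ResNet makes very deep networks much easier to train, easing the vanishing and exploding gradient problems that otherwise appear as depth grows. It does this with residual blocks: each block adds its input back to its output through a shortcut connection, so the layers only need to learn a residual mapping. ResNet is often used as the backbone of other models. In UNet, DeepLabv3, and FCN, replacing the plain convolutional layers with a ResNet can improve performance.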

Below is a simple ResNet residual block:

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

II. Training Code

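The training code is where the model learns features from the data. First we load the dataset; here we assume a standard image segmentation dataset format such as VOC.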

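from torchvision import transforms, datasets
from torchvision.transforms import InterpolationMode
from torch.utils.data import DataLoader


transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Masks store class indices, so resize with nearest-neighbor and keep the integer values
target_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=InterpolationMode.NEAREST),
    transforms.PILToTensor()
])

train_dataset = datasets.VOCSegmentation(root='./data', year='2012', image_set='train',
                                         download=True, transform=transform,
                                         target_transform=target_transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

Next, define the loss function and optimizer, taking the UNet model as an example:

import torch.optim as optim


model = UNet(n_channels=3, n_classes=21)
# VOC masks mark void/boundary pixels with 255; exclude them from the loss
criterion = nn.CrossEntropyLoss(ignore_index=255)
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    running_loss = 0.0
    # VOCSegmentation yields (image, mask) tuples, not dictionaries
    for inputs, labels in train_loader:
        labels = labels.squeeze(1).long()  # (N, 1, H, W) uint8 -> (N, H, W) int64
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')

# Save the weights so the prediction code can load them
torch.save(model.state_dict(), 'unet_model.pth')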
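For segmentation, `nn.CrossEntropyLoss` expects logits of shape (N, C, H, W) and an integer target of shape (N, H, W). The sketch below uses random tensors to show these shapes and the effect of `ignore_index`. Treating 255 as a "void" label follows the VOC mask convention, and the tensor sizes are arbitrary.

```python
import torch
import torch.nn as nn

num_classes = 21
logits = torch.randn(2, num_classes, 8, 8)         # (N, C, H, W) raw scores
target = torch.randint(0, num_classes, (2, 8, 8))  # (N, H, W) class indices, int64
target[:, 0, :] = 255                              # mark the first row as void

# Pixels labeled 255 contribute nothing to the loss or its average
criterion = nn.CrossEntropyLoss(ignore_index=255)
loss = criterion(logits, target)

# Same value as computing the loss on the valid rows only
loss_valid_only = nn.CrossEntropyLoss()(logits[:, :, 1:, :], target[:, 1:, :])
print(loss.item(), loss_valid_only.item())
```

The two printed values agree, which confirms that the ignored pixels are left out of the mean.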

III. Prediction Code

The prediction code uses the trained model to segment new images.

import cv2
import numpy as np
from torchvision import transforms


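def predict_image(model, image_path):
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f'Cannot read image: {image_path}')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    orig_h, orig_w = img.shape[:2]  # keep the original size for the final resize
    transform = transforms.Compose([
        transforms.ToPILImage(),  # Resize expects a PIL image or tensor, not a raw ndarray
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        output = model(tensor)
    prediction = torch.argmax(output, dim=1).squeeze(0).numpy()
    # Stretch class indices to 0-255 for viewing; skip when only background is predicted
    max_label = prediction.max()
    if max_label > 0:
        prediction = prediction * 255 / max_label
    prediction = prediction.astype(np.uint8)
    # Nearest-neighbor keeps label values intact when scaling back to the original size
    prediction = cv2.resize(prediction, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
    return prediction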


model = UNet(n_channels=3, n_classes=21)
model.load_state_dict(torch.load('unet_model.pth', map_location='cpu'))  # predict_image runs on CPU
model.eval()
prediction = predict_image(model, 'test_image.jpg')
cv2.imwrite('prediction.png', prediction)
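A grayscale stretch makes neighboring class indices hard to tell apart. One common alternative is to map each class index to a color through a lookup table, sketched below with NumPy. The palette is randomly generated for illustration and is not the official VOC color map; the tiny `mask` stands in for a real prediction.

```python
import numpy as np

num_classes = 21
rng = np.random.default_rng(0)

# Hypothetical palette: one RGB color per class, with background forced to black
palette = rng.integers(0, 256, size=(num_classes, 3), dtype=np.uint8)
palette[0] = 0

# Stand-in for an argmax prediction holding class indices
mask = np.array([[0, 1],
                 [15, 20]], dtype=np.uint8)

# Fancy indexing replaces every class index with its RGB color
color = palette[mask]
print(color.shape)  # (2, 2, 3)
```

If the result is written with `cv2.imwrite`, remember that OpenCV stores channels in BGR order, so convert or reverse the channel axis first.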

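That covers the main pieces of an image segmentation project built on PyTorch with UNet, DeepLabv3, FCN, and ResNet. The code is simple and easy to follow, and it runs once the dataset is downloaded. Whether you are a beginner learning image segmentation or an experienced developer putting together an application quickly, it is a convenient starting point, so give it a try!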
