Adding CBAM to YOLOv8
Introduction
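While browsing GitHub I came across someone asking how to do this, and a maintainer replied with the steps, so I am recording them here to try out and verify later.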
Steps
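Make sure the CBAM module is defined:
First, make sure the CBAM module is defined correctly in your codebase. Your models/common.py (or a similar file) should contain something like the following. Note that models/common.py is the YOLOv5 layout; the ultralytics package keeps its building blocks under ultralytics/nn/modules/, and its conv.py already defines a CBAM class with a different constructor, so watch out for a name clash.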
```python
import torch
import torch.nn as nn


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial attention."""

    def __init__(self, channels, reduction=16):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(channels, reduction)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        x = self.channel_attention(x) * x  # reweight channels with an (N, C, 1, 1) map
        x = self.spatial_attention(x) * x  # reweight positions with an (N, 1, H, W) map
        return x


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared bottleneck MLP, implemented with bias-free 1x1 convolutions
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // reduction, in_planes, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)  # (N, C, 1, 1) channel weights in (0, 1)


class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention, self).__init__()
        # 7x7 conv over the stacked channel-wise mean and max maps (2 -> 1 channels)
        self.conv1 = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)  # (N, 1, H, W) spatial weights in (0, 1)
```
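As a quick check on how light this module is, its parameter count can be worked out by hand from the layers above: the channel MLP has two bias-free 1x1 convs with C * C/r weights each, and the spatial branch has one bias-free 7x7 conv from 2 channels to 1 (98 weights). The helper cbam_param_count below is a hypothetical illustration written for this post, not part of any library, and it assumes the default reduction=16 used above.

```python
def cbam_param_count(channels: int, reduction: int = 16, kernel_size: int = 7) -> int:
    """Count the learnable weights of the CBAM class shown above (hypothetical helper).

    Assumes that layout: two bias-free 1x1 convs in the channel MLP and one
    bias-free kernel_size x kernel_size conv (2 -> 1 channels) in the spatial branch.
    """
    hidden = channels // reduction
    # Channel attention: Conv2d(C, C/r, 1) + Conv2d(C/r, C, 1), no biases
    channel_mlp = channels * hidden + hidden * channels
    # Spatial attention: Conv2d(2, 1, kernel_size), no bias
    spatial_conv = 2 * 1 * kernel_size * kernel_size
    return channel_mlp + spatial_conv


for c in (256, 512, 1024):
    print(c, cbam_param_count(c))
```

For 256 channels this gives 8290 weights; with PyTorch installed, sum(p.numel() for p in CBAM(256).parameters()) on the class above should report the same number. Keep in mind that for yolov8n the width multiple of 0.25 turns a YAML argument of 1024 into 256 actual channels, so the modules that get built are smaller than the YAML numbers suggest.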
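Modify the YAML configuration:
Make sure your YAML model config references the CBAM module correctly. For example, starting from the stock yolov8.yaml (its nc and scales section stays unchanged), CBAM layers can be inserted into the head like this: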
```yaml
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9

# YOLOv8.0n head
head:
  - [-1, 1, CBAM, [1024]]
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 13
  - [-1, 1, CBAM, [512]]
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 17 (P3/8-small)
  - [-1, 1, CBAM, [256]]
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 21 (P4/16-medium)
  - [-1, 1, CBAM, [512]]
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2f, [1024]] # 25 (P5/32-large)
  - [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)
```
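Because every inserted CBAM layer shifts the index of all layers after it, the Concat and Detect "from" entries have to be renumbered (in the stock head, Detect reads [15, 18, 21]; here it reads [17, 21, 25]). The short script below is a hypothetical checker, not an Ultralytics tool: it hard-codes the head above as (from, module) pairs, assumes the backbone occupies layers 0-9, and resolves each relative -1 into an absolute layer index.

```python
# Head of the YAML above as (from, module) pairs; the backbone is layers 0-9.
HEAD = [
    (-1, "CBAM"),
    (-1, "nn.Upsample"),
    ([-1, 6], "Concat"),
    (-1, "C2f"),
    (-1, "CBAM"),
    (-1, "nn.Upsample"),
    ([-1, 4], "Concat"),
    (-1, "C2f"),
    (-1, "CBAM"),
    (-1, "Conv"),
    ([-1, 13], "Concat"),
    (-1, "C2f"),
    (-1, "CBAM"),
    (-1, "Conv"),
    ([-1, 9], "Concat"),
    (-1, "C2f"),
    ([17, 21, 25], "Detect"),
]
BACKBONE_LEN = 10


def index_head(head, start=BACKBONE_LEN):
    """Map each head layer to its absolute index in the full model (hypothetical helper)."""
    return {start + i: module for i, (_, module) in enumerate(head)}


def resolve_sources(head, start=BACKBONE_LEN):
    """Turn relative (negative) and absolute 'from' entries into absolute layer indices."""
    resolved = {}
    for i, (src, _) in enumerate(head):
        idx = start + i
        srcs = src if isinstance(src, list) else [src]
        resolved[idx] = [idx + s if s < 0 else s for s in srcs]
    return resolved


layers = index_head(HEAD)
print({i: layers[i] for i in (13, 17, 21, 25)})
print(resolve_sources(HEAD)[26])
```

Run as-is, it confirms that layers 13, 17, 21 and 25 are the C2f outputs named in the YAML comments and that Detect reads exactly those. It also shows that none of the CBAM layers (10, 14, 18, 22) feeds Detect directly; each one only influences the next layer in the neck.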
Update the parse_model method:
Make sure the parse_model method in tasks.py can parse the CBAM module correctly, for example by adding a branch like this:
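```python
# Inside parse_model() in tasks.py, as one more branch of the module if/elif chain.
# CBAM must be imported in tasks.py: parse_model looks YAML module names up in globals().
elif m in {CBAM}:
    c1, c2 = ch[f], args[0]  # input channels from the 'from' layer; args[0] is the YAML value
    if c2 != nc:  # same width / max_channels scaling as the Conv branch
        c2 = make_divisible(min(c2, max_channels) * width, 8)
    args = [c1, *args[1:]]  # build CBAM(c1, ...) with the real input channel count
```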
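The branch above scales the YAML channel argument the same way the Conv branch does. The sketch below reproduces that arithmetic in plain Python so it can be checked without building a model. make_divisible here mirrors the round-up-to-a-multiple rule of the Ultralytics helper of the same name, and SCALES copies the [depth, width, max_channels] constants from the stock yolov8.yaml; both are re-implemented for illustration only, and scaled_channels is a hypothetical helper.

```python
import math

# [depth, width, max_channels] per model scale, as listed in the stock yolov8.yaml
SCALES = {
    "n": (0.33, 0.25, 1024),
    "s": (0.33, 0.50, 1024),
    "m": (0.67, 0.75, 768),
    "l": (1.00, 1.00, 512),
    "x": (1.00, 1.25, 512),
}


def make_divisible(x, divisor=8):
    """Round x up to the nearest multiple of divisor (mirrors the Ultralytics rule)."""
    return math.ceil(x / divisor) * divisor


def scaled_channels(c2, scale):
    """Channel count the CBAM/Conv branch computes for YAML arg c2 at a given scale."""
    _, width, max_channels = SCALES[scale]
    return make_divisible(min(c2, max_channels) * width, 8)


for s in SCALES:
    print(s, [scaled_channels(c, s) for c in (256, 512, 1024)])
```

Because CBAM[1024] sits right after SPPF[1024] and each CBAM[512] or CBAM[256] right after a C2f with the same argument, the c2 computed for every CBAM layer equals the scaled channel count of the layer feeding it (for example 256 for n and 576 for m). That is what CBAM needs, since it does not change the channel count. If a CBAM argument in the YAML ever disagreed with the preceding layer, the value recorded in ch would be wrong and later layers would get mismatched shapes; setting c2 = c1 in the branch is a simpler alternative, though it is not part of the steps given in the issue.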
References
1. YOLOv8 network adds CBAM module · Issue #10758 · ultralytics/ultralytics