YOLOv11改进 | 添加注意力篇 | 一文带你改进GAM、CBAM、CA、ECA等通道注意力机制和多头注意力机制

一、本文介绍

本文的内容不会过度的去解释原理,更多的是从从代码的使用上和实用的角度出发去写这篇教程。

欢迎大家订阅我的专栏一起学习YOLO!


目录

一、本文介绍

二、GAM

2.1 GAM的介绍

2.2 GAM的核心代码

三、CBAM

3.1 CBAM的介绍

3.2 CBAM核心代码

四、CA

4.1 CA的介绍

4.2 CA核心代码

五、ECA

5.1 ECA的介绍

5.2 ECA核心代码

六、注意力机制的添加方法

6.1 修改一

6.2 修改二

6.3 修改三

6.4 修改四

七、yaml文件

八、本文总结


二、 GAM

2.1 GAM的介绍

​​ 官方论文地址: 官方论文地址点击此处即可跳转

官方代码地址: 官方代码地址点击此处即可跳转

​​


简单介绍: GAM旨在通过设计一种机制,减少信息损失并放大全局维度互动特征,从而解决传统注意力机制在通道和空间两个维度上保留信息不足的问题。GAM采用了顺序的通道- 空间注意力机制 ,并对子模块进行了重新设计。具体来说,通道注意力子模块使用3D排列来跨三个维度保留信息,并通过一个两层的MLP增强跨维度的通道-空间依赖性。在空间注意力子模块中,为了更好地关注空间信息,采用了两个卷积层进行空间信息融合,同时去除了可能导致信息减少的最大池化操作,通过使用分组卷积和通道混洗在ResNet50中避免参数数量显著增加。GAM在不同的 神经网络 架构上稳定提升 性能 ,特别是对于ResNet18,GAM以更少的参数和更好的效率超过了ABN, 其简单原理结构图如下所示。

​​


2.2 GAM的核心代码

  1. import torch
  2. import torch.nn as nn
  3. '''
  4. https://arxiv.org/abs/2112.05561
  5. '''
  6. class GAM(nn.Module):
  7. def __init__(self, in_channels, rate=4):
  8. super().__init__()
  9. out_channels = in_channels
  10. in_channels = int(in_channels)
  11. out_channels = int(out_channels)
  12. inchannel_rate = int(in_channels/rate)
  13. self.linear1 = nn.Linear(in_channels, inchannel_rate)
  14. self.relu = nn.ReLU(inplace=True)
  15. self.linear2 = nn.Linear(inchannel_rate, in_channels)
  16. self.conv1=nn.Conv2d(in_channels, inchannel_rate,kernel_size=7,padding=3,padding_mode='replicate')
  17. self.conv2=nn.Conv2d(inchannel_rate, out_channels,kernel_size=7,padding=3,padding_mode='replicate')
  18. self.norm1 = nn.BatchNorm2d(inchannel_rate)
  19. self.norm2 = nn.BatchNorm2d(out_channels)
  20. self.sigmoid = nn.Sigmoid()
  21. def forward(self,x):
  22. b, c, h, w = x.shape
  23. # B,C,H,W ==> B,H*W,C
  24. x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
  25. # B,H*W,C ==> B,H,W,C
  26. x_att_permute = self.linear2(self.relu(self.linear1(x_permute))).view(b, h, w, c)
  27. # B,H,W,C ==> B,C,H,W
  28. x_channel_att = x_att_permute.permute(0, 3, 1, 2)
  29. x = x * x_channel_att
  30. x_spatial_att = self.relu(self.norm1(self.conv1(x)))
  31. x_spatial_att = self.sigmoid(self.norm2(self.conv2(x_spatial_att)))
  32. out = x * x_spatial_att
  33. return out
  34. if __name__ == '__main__':
  35. img = torch.rand(1,64,32,48)
  36. b, c, h, w = img.shape
  37. net = GAM(in_channels=c, out_channels=c)
  38. output = net(img)
  39. print(output.shape)


三、CBAM

3.1 CBAM的介绍

​​

官方论文地址: 官方论文地址点击此处即可跳转

官方代码地址: 官方代码地址点击此处即可跳转

​​

简单介绍: CBAM的主要思想是通过关注重要的特征并抑制不必要的特征来增强网络的表示能力。模块首先应用通道注意力,关注"重要的"特征,然后应用空间注意力,关注这些特征的"重要位置"。通过这种方式,CBAM有效地帮助网络聚焦于图像中的关键信息,提高了特征的表示力度, 下图为其简单原理结构图

​​


3.2 CBAM核心代码

  1. import torch
  2. import torch.nn as nn
  3. class ChannelAttention(nn.Module):
  4. """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""
  5. def __init__(self, channels: int) -> None:
  6. """Initializes the class and sets the basic configurations and instance variables required."""
  7. super().__init__()
  8. self.pool = nn.AdaptiveAvgPool2d(1)
  9. self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
  10. self.act = nn.Sigmoid()
  11. def forward(self, x: torch.Tensor) -> torch.Tensor:
  12. """Applies forward pass using activation on convolutions of the input, optionally using batch normalization."""
  13. return x * self.act(self.fc(self.pool(x)))
  14. class SpatialAttention(nn.Module):
  15. """Spatial-attention module."""
  16. def __init__(self, kernel_size=7):
  17. """Initialize Spatial-attention module with kernel size argument."""
  18. super().__init__()
  19. assert kernel_size in (3, 7), "kernel size must be 3 or 7"
  20. padding = 3 if kernel_size == 7 else 1
  21. self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
  22. self.act = nn.Sigmoid()
  23. def forward(self, x):
  24. """Apply channel and spatial attention on input for feature recalibration."""
  25. return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
  26. class CBAM(nn.Module):
  27. """Convolutional Block Attention Module."""
  28. def __init__(self, c1, kernel_size=7):
  29. """Initialize CBAM with given input channel (c1) and kernel size."""
  30. super().__init__()
  31. self.channel_attention = ChannelAttention(c1)
  32. self.spatial_attention = SpatialAttention(kernel_size)
  33. def forward(self, x):
  34. """Applies the forward pass through C1 module."""
  35. return self.spatial_attention(self.channel_attention(x))


四、CA

4.1 CA的介绍

​​

官方论文地址: 官方论文地址点击此处即可跳转

官方代码地址: 官方代码地址点击此处即可跳转

​​


简单介绍: 坐标注意力是一种结合了通道注意力和位置信息的注意力机制,旨在提升移动网络的性能。它通过将特征张量沿两个空间方向进行1D全局池化,分别捕获沿垂直和水平方向的特征,保留了精确的位置信息并捕获了长距离依赖性。这两个方向的特征图被单独编码成方向感知和位置敏感的注意力图,然后这些注意力图通过乘法作用于输入特征图,以突出感兴趣的对象表示。坐标注意力的引入,使得 模型 能够更准确地定位和识别感兴趣的对象,同时由于其轻量级和灵活性,它可以轻松集成到现有的移动网络架构中,几乎不会增加计算开销。

​​


4.2 CA核心代码

  1. import torch
  2. import torch.nn as nn
  3. import math
  4. import torch.nn.functional as F
  5. class h_sigmoid(nn.Module):
  6. def __init__(self, inplace=True):
  7. super(h_sigmoid, self).__init__()
  8. self.relu = nn.ReLU6(inplace=inplace)
  9. def forward(self, x):
  10. return self.relu(x + 3) / 6
  11. class h_swish(nn.Module):
  12. def __init__(self, inplace=True):
  13. super(h_swish, self).__init__()
  14. self.sigmoid = h_sigmoid(inplace=inplace)
  15. def forward(self, x):
  16. return x * self.sigmoid(x)
  17. class CoordAtt(nn.Module):
  18. def __init__(self, inp, reduction=32):
  19. super(CoordAtt, self).__init__()
  20. oup = inp
  21. self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
  22. self.pool_w = nn.AdaptiveAvgPool2d((1, None))
  23. mip = max(8, inp // reduction)
  24. self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
  25. self.bn1 = nn.BatchNorm2d(mip)
  26. self.act = h_swish()
  27. self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
  28. self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
  29. def forward(self, x):
  30. identity = x
  31. n,c,h,w = x.size()
  32. x_h = self.pool_h(x)
  33. x_w = self.pool_w(x).permute(0, 1, 3, 2)
  34. y = torch.cat([x_h, x_w], dim=2)
  35. y = self.conv1(y)
  36. y = self.bn1(y)
  37. y = self.act(y)
  38. x_h, x_w = torch.split(y, [h, w], dim=2)
  39. x_w = x_w.permute(0, 1, 3, 2)
  40. a_h = self.conv_h(x_h).sigmoid()
  41. a_w = self.conv_w(x_w).sigmoid()
  42. out = identity * a_w * a_h
  43. return out

五、ECA

5.1 ECA的介绍

​​

官方论文地址: 官方论文地址点击此处即可跳转

官方代码地址: 官方代码地址点击此处即可跳转

​​

简单介绍:ECA(Efficient Channel Attention)注意力机制的原理可以总结为:避免通道注意力模块中的降维操作,通过采用局部跨通道交互策略,利用1D卷积实现高效的通道注意力计算。这种方法保持了性能的同时显著减少了模型的复杂性,通过自适应选择卷积核大小,确定了局部跨通道交互的覆盖范围。 ECA模块通过少量参数和低计算成本,实现了在ResNets和MobileNetV2等主干网络上的显著性能提升,且相对于其他注意力模块具有更高的效率和更好的性能。


5.2 ECA核心代码

  1. import torch
  2. from torch import nn
  3. from torch.nn.parameter import Parameter
  4. class ECA(nn.Module):
  5. """Constructs a ECA module.
  6. Args:
  7. channel: Number of channels of the input feature map
  8. k_size: Adaptive selection of kernel size
  9. """
  10. def __init__(self, channel, k_size=3):
  11. super(ECA, self).__init__()
  12. self.avg_pool = nn.AdaptiveAvgPool2d(1)
  13. self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
  14. self.sigmoid = nn.Sigmoid()
  15. def forward(self, x):
  16. # feature descriptor on the global spatial information
  17. y = self.avg_pool(x)
  18. # Two different branches of ECA module
  19. y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
  20. # Multi-scale information fusion
  21. y = self.sigmoid(y)
  22. return x * y.expand_as(x)


六、注意力机制的添加方法

6.1 修改一

第一还是建立文件,我们找到如下 ultralytics /nn文件夹下建立一个目录名字呢就是'Addmodules'文件夹 然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可,我们可以将上述任意一个注意力机制放到里,也可以将所有的注意力机制都放在里面。

​​


6.2 修改二

第二步我们在该目录下创建一个新的py文件名字为'__init__.py ,然后在其内部导入我们的检测头如下图所示。

​​


6.3 修改三

第三步我门中到如下文件'ultralytics/nn/tasks.py'进行导入和注册我们的模块( !

​​​


6.4 修改四

第四步我们找到'ultralytics/nn/tasks.py'文件中的'def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)'方法,然后按照图片进行添加修改即可。

​​


七、yaml文件

  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. # YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
  3. # Parameters
  4. nc: 80 # number of classes
  5. scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  6. # [depth, width, max_channels]
  7. n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  8. s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  9. m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  10. l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  11. x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
  12. # YOLO11n backbone
  13. backbone:
  14. # [from, repeats, module, args]
  15. - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  16. - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  17. - [-1, 2, C3k2, [256, False, 0.25]]
  18. - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  19. - [-1, 2, C3k2, [512, False, 0.25]]
  20. - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  21. - [-1, 2, C3k2, [512, True]]
  22. - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  23. - [-1, 2, C3k2, [1024, True]]
  24. - [-1, 1, SPPF, [1024, 5]] # 9
  25. - [-1, 2, C2PSA, [1024]] # 10
  26. # YOLO11n head
  27. head:
  28. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  29. - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  30. - [-1, 2, C3k2, [512, False]] # 13
  31. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  32. - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  33. - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  34. - [-1, 1, Conv, [256, 3, 2]]
  35. - [[-1, 13], 1, Concat, [1]] # cat head P4
  36. - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  37. - [-1, 1, Conv, [512, 3, 2]]
  38. - [[-1, 10], 1, Concat, [1]] # cat head P5
  39. - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  40. - [-1, 1, CBAM, []] # 23 这里默认先用的CBAM大家使用那个只需要把其他的注释掉即可,这里是在大目标检测曾层输出位置添加一个注意力机制。
  41. # - [-1, 1, ECA, []] # 23
  42. # - [-1, 1, GAM, []] # 23
  43. # - [-1, 1, CoordAtt, []] # 23
  44. - [[16, 19, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

想要学习添加更多添加位置更多机制欢迎大家订阅专栏~


八、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~