YOLOv11改进 | 添加注意力篇 | 实现级联群体注意力机制CGAttention改进C2PSA机制 (全网独家首发)

一、本文介绍

本文给大家带来的改进机制是实现级联群体 注意力机制 CascadedGroupAttention ,其主要思想为增强输入到注意力头的特征的多样性。与以前的 自注意力 不同,它为每个头提供不同的输入分割,并跨头级联输出特征。这种方法不仅减少了多头注意力中的计算冗余,而且通过增加网络深度来提升 模型 容量, 亲测在我的25个类别的数据上,大部分的类别均有一定的涨点效果 ,仅有部分的类别保持不变,同时给该注意力机制含有二次创新的机会

欢迎大家订阅我的专栏一起学习YOLO!


目录

一、本文介绍

二、 CascadedGroupAttention的基本原理

三、CGA的核心代码

四、CGA的添加方式

4.1 修改一

4.2 修改二

4.3 修改三

4.4 修改四

五、CGA的yaml文件和运行记录

5.1 CGA的yaml文件一

5.2 CGA的yaml文件二

5.3 CGA的训练过程截图

五、本文总结


二、 CascadedGroupAttention的基本原理

官方论文地址: 官方论文地址点击即可跳转

官方代码地址: 官方代码地址点击即可跳转


Cascaded Group Attention (CGA) 是在文章 "EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention" 中提出的一种新型注意力机制。其核心思想是增强输入到注意力头的特征的多样性。与以前的自注意力不同,它为每个头提供不同的输入分割,并跨头级联输出特征。这种方法不仅减少了多头注意力中的计算冗余,而且通过增加网络深度来提升模型容量。

具体来说,CGA 将输入特征分成不同的部分,每部分输入到一个注意力头。每个头计算其自注意力映射,然后将所有头的输出级联起来,并通过一个线性层将它们投影回输入的维度。通过这样的方式,CGA 在不增加额外参数的情况下提高了模型的计算效率。另外,通过串联的方式,每个头的输出都会添加到下一个头的输入中,从而逐步精化特征表示。

Cascaded Group Attention 的优点包括:
1. 提高了注意力图的多样性。
2. 减少了计算冗余,因为它减少了 QKV 层中输入和输出通道的数量。
3. 增加了网络深度,从而进一步提高了模型容量,同时只增加了很小的延迟开销,因为每个头的 QK 通道维度较小。

这张图描绘了 "EfficientViT" 模型中 "Cascaded Group Attention" (CGA) 模块的架构。

CGA模块位于图中的(c)部分,可以看到它的作用是处理输入特征,并提供分级的注意力机制。在这个模块中,输入首先被分割成多个部分,每个部分对应一个注意力头。每个头独立地计算其自注意力,并产生一个输出。然后,所有头的输出被级联(concatenate)在一起,通过一个线性投影层形成最终的输出。这种设计允许模型在不同的层次上捕捉特征,通过级联增强了特征之间的交互,同时提高了计算效率。

级联组注意力的关键点在于每个注意力头只关注输入的一部分,然后把所有头的注意力合并起来,来获取一个全面的特征表示。这样做的好处是减少了计算重复并增加了注意力的多样性,因为不同的头可能会关注输入的不同方面。这种方法提高了模型的内存和计算效率,同时保持或增强模型的 性能


三、 CGA的核心代码

代码使用方式看章节四!

  1. # https://github.com/microsoft/Cream/blob/ef68993c764f241a768cd69a087ed567dec6cb40/EfficientViT/classification/model/efficientvit.py#L104-L181
  2. import itertools
  3. import torch
  4. from torch import nn
  5. __all__ = ['C2PSA_CGA', 'LocalWindowAttention']
  6. class Conv2d_BN(torch.nn.Sequential):
  7. def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
  8. groups=1, bn_weight_init=1, resolution=-10000):
  9. super().__init__()
  10. self.add_module('c', torch.nn.Conv2d(
  11. a, b, ks, stride, pad, dilation, groups, bias=False))
  12. self.add_module('bn', torch.nn.BatchNorm2d(b))
  13. torch.nn.init.constant_(self.bn.weight, bn_weight_init)
  14. torch.nn.init.constant_(self.bn.bias, 0)
  15. @torch.no_grad()
  16. def switch_to_deploy(self):
  17. c, bn = self._modules.values()
  18. w = bn.weight / (bn.running_var + bn.eps)**0.5
  19. w = c.weight * w[:, None, None, None]
  20. b = bn.bias - bn.running_mean * bn.weight / \
  21. (bn.running_var + bn.eps)**0.5
  22. m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
  23. 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
  24. m.weight.data.copy_(w)
  25. m.bias.data.copy_(b)
  26. return m
  27. class CascadedGroupAttention(torch.nn.Module):
  28. r""" Cascaded Group Attention.
  29. Args:
  30. dim (int): Number of input channels.
  31. key_dim (int): The dimension for query and key.
  32. num_heads (int): Number of attention heads.
  33. attn_ratio (int): Multiplier for the query dim for value dimension.
  34. resolution (int): Input resolution, correspond to the window size.
  35. kernels (List[int]): The kernel size of the dw conv on query.
  36. """
  37. def __init__(self, dim, key_dim, num_heads=8,
  38. attn_ratio=4,
  39. resolution=14,
  40. kernels=[5, 5, 5, 5], ):
  41. super().__init__()
  42. self.num_heads = num_heads
  43. self.scale = key_dim ** -0.5
  44. self.key_dim = key_dim
  45. self.d = int(attn_ratio * key_dim)
  46. self.attn_ratio = attn_ratio
  47. qkvs = []
  48. dws = []
  49. for i in range(num_heads):
  50. qkvs.append(Conv2d_BN(dim // (num_heads), self.key_dim * 2 + self.d, resolution=resolution))
  51. dws.append(Conv2d_BN(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim,
  52. resolution=resolution))
  53. self.qkvs = torch.nn.ModuleList(qkvs)
  54. self.dws = torch.nn.ModuleList(dws)
  55. self.proj = torch.nn.Sequential(torch.nn.ReLU(), Conv2d_BN(
  56. self.d * num_heads, dim, bn_weight_init=0, resolution=resolution))
  57. points = list(itertools.product(range(resolution), range(resolution)))
  58. N = len(points)
  59. attention_offsets = {}
  60. idxs = []
  61. for p1 in points:
  62. for p2 in points:
  63. offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
  64. if offset not in attention_offsets:
  65. attention_offsets[offset] = len(attention_offsets)
  66. idxs.append(attention_offsets[offset])
  67. self.attention_biases = torch.nn.Parameter(
  68. torch.zeros(num_heads, len(attention_offsets)))
  69. self.register_buffer('attention_bias_idxs',
  70. torch.LongTensor(idxs).view(N, N))
  71. @torch.no_grad()
  72. def train(self, mode=True):
  73. super().train(mode)
  74. if mode and hasattr(self, 'ab'):
  75. del self.ab
  76. else:
  77. self.ab = self.attention_biases[:, self.attention_bias_idxs]
  78. def forward(self, x): # x (B,C,H,W)
  79. B, C, H, W = x.shape
  80. trainingab = self.attention_biases[:, self.attention_bias_idxs]
  81. feats_in = x.chunk(len(self.qkvs), dim=1)
  82. feats_out = []
  83. feat = feats_in[0]
  84. for i, qkv in enumerate(self.qkvs):
  85. if i > 0: # add the previous output to the input
  86. feat = feat + feats_in[i]
  87. feat = qkv(feat)
  88. q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1) # B, C/h, H, W
  89. q = self.dws[i](q)
  90. q, k, v = q.flatten(2), k.flatten(2), v.flatten(2) # B, C/h, N
  91. attn = (
  92. (q.transpose(-2, -1) @ k) * self.scale
  93. +
  94. (trainingab[i] if self.training else self.ab[i])
  95. )
  96. attn = attn.softmax(dim=-1) # BNN
  97. feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W) # BCHW
  98. feats_out.append(feat)
  99. x = self.proj(torch.cat(feats_out, 1))
  100. return x
  101. class LocalWindowAttention(torch.nn.Module):
  102. r""" Local Window Attention.
  103. Args:
  104. dim (int): Number of input channels.
  105. key_dim (int): The dimension for query and key.
  106. num_heads (int): Number of attention heads.
  107. attn_ratio (int): Multiplier for the query dim for value dimension.
  108. resolution (int): Input resolution.
  109. window_resolution (int): Local window resolution.
  110. kernels (List[int]): The kernel size of the dw conv on query.
  111. """
  112. def __init__(self, dim, num_heads=4,
  113. attn_ratio=4,
  114. resolution=14,
  115. window_resolution=7,
  116. kernels=[5, 5, 5, 5], ):
  117. super().__init__()
  118. key_dim = dim // 16 # 必须放缩16倍否则会报错
  119. self.dim = dim
  120. self.num_heads = num_heads
  121. self.resolution = resolution
  122. assert window_resolution > 0, 'window_size must be greater than 0'
  123. self.window_resolution = window_resolution
  124. self.attn = CascadedGroupAttention(dim, key_dim, num_heads,
  125. attn_ratio=attn_ratio,
  126. resolution=window_resolution,
  127. kernels=kernels, )
  128. def forward(self, x):
  129. B, C, H, W = x.shape
  130. if H <= self.window_resolution and W <= self.window_resolution:
  131. x = self.attn(x)
  132. else:
  133. x = x.permute(0, 2, 3, 1)
  134. pad_b = (self.window_resolution - H %
  135. self.window_resolution) % self.window_resolution
  136. pad_r = (self.window_resolution - W %
  137. self.window_resolution) % self.window_resolution
  138. padding = pad_b > 0 or pad_r > 0
  139. if padding:
  140. x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b))
  141. pH, pW = H + pad_b, W + pad_r
  142. nH = pH // self.window_resolution
  143. nW = pW // self.window_resolution
  144. # window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw
  145. x = x.view(B, nH, self.window_resolution, nW, self.window_resolution, C).transpose(2, 3).reshape(
  146. B * nH * nW, self.window_resolution, self.window_resolution, C
  147. ).permute(0, 3, 1, 2)
  148. x = self.attn(x)
  149. # window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC
  150. x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution,
  151. C).transpose(2, 3).reshape(B, pH, pW, C)
  152. if padding:
  153. x = x[:, :H, :W].contiguous()
  154. x = x.permute(0, 3, 1, 2)
  155. return x
  156. def autopad(k, p=None, d=1): # kernel, padding, dilation
  157. """Pad to 'same' shape outputs."""
  158. if d > 1:
  159. k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
  160. if p is None:
  161. p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
  162. return p
  163. class Conv(nn.Module):
  164. """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
  165. default_act = nn.SiLU() # default activation
  166. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
  167. """Initialize Conv layer with given arguments including activation."""
  168. super().__init__()
  169. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
  170. self.bn = nn.BatchNorm2d(c2)
  171. self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
  172. def forward(self, x):
  173. """Apply convolution, batch normalization and activation to input tensor."""
  174. return self.act(self.bn(self.conv(x)))
  175. def forward_fuse(self, x):
  176. """Perform transposed convolution of 2D data."""
  177. return self.act(self.conv(x))
  178. class PSABlock(nn.Module):
  179. """
  180. PSABlock class implementing a Position-Sensitive Attention block for neural networks.
  181. This class encapsulates the functionality for applying multi-head attention and feed-forward neural network layers
  182. with optional shortcut connections.
  183. Attributes:
  184. attn (Attention): Multi-head attention module.
  185. ffn (nn.Sequential): Feed-forward neural network module.
  186. add (bool): Flag indicating whether to add shortcut connections.
  187. Methods:
  188. forward: Performs a forward pass through the PSABlock, applying attention and feed-forward layers.
  189. Examples:
  190. Create a PSABlock and perform a forward pass
  191. """
  192. def __init__(self, c, attn_ratio=0.5, num_heads=4, shortcut=True) -> None:
  193. """Initializes the PSABlock with attention and feed-forward layers for enhanced feature extraction."""
  194. super().__init__()
  195. self.attn = LocalWindowAttention(c)
  196. self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False))
  197. self.add = shortcut
  198. def forward(self, x):
  199. """Executes a forward pass through PSABlock, applying attention and feed-forward layers to the input tensor."""
  200. x = x + self.attn(x) if self.add else self.attn(x)
  201. x = x + self.ffn(x) if self.add else self.ffn(x)
  202. return x
  203. class C2PSA_CGA(nn.Module):
  204. """
  205. C2PSA module with attention mechanism for enhanced feature extraction and processing.
  206. This module implements a convolutional block with attention mechanisms to enhance feature extraction and processing
  207. capabilities. It includes a series of PSABlock modules for self-attention and feed-forward operations.
  208. Attributes:
  209. c (int): Number of hidden channels.
  210. cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.
  211. cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.
  212. m (nn.Sequential): Sequential container of PSABlock modules for attention and feed-forward operations.
  213. Methods:
  214. forward: Performs a forward pass through the C2PSA module, applying attention and feed-forward operations.
  215. Notes:
  216. This module essentially is the same as PSA module, but refactored to allow stacking more PSABlock modules.
  217. Examples:
  218. """
  219. def __init__(self, c1, c2, n=1, e=0.5):
  220. """Initializes the C2PSA module with specified input/output channels, number of layers, and expansion ratio."""
  221. super().__init__()
  222. assert c1 == c2
  223. self.c = int(c1 * e)
  224. self.cv1 = Conv(c1, 2 * self.c, 1, 1)
  225. self.cv2 = Conv(2 * self.c, c1, 1)
  226. self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n)))
  227. def forward(self, x):
  228. """Processes the input tensor 'x' through a series of PSA blocks and returns the transformed tensor."""
  229. a, b = self.cv1(x).split((self.c, self.c), dim=1)
  230. b = self.m(b)
  231. return self.cv2(torch.cat((a, b), 1))
  232. if __name__ == "__main__":
  233. # Generating Sample image
  234. image_size = (1, 64, 224, 224)
  235. image = torch.rand(*image_size)
  236. # Model
  237. model = C2PSA_CGA(64, 64)
  238. out = model(image)
  239. print(out.size())


四、 CGA 的添加方式


4.1 修改一

第一还是建立文件,我们找到如下 ultralytics /nn文件夹下建立一个目录名字呢就是'Addmodules'文件夹( !然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。


4.2 修改二

第二步我们在该目录下创建一个新的py文件名字为'__init__.py'( ,然后在其内部导入我们的检测头如下图所示。


4.3 修改三

第三步我门中到如下文件'ultralytics/nn/tasks.py'进行导入和注册我们的模块( !


4.4 修改四

按照我的添加在parse_model里添加即可。


五、 CGA 的yaml文件和运行记录

5.1 CGA 的yaml文件一(推荐)

此版本训练信息:YOLO11-C2PSA-CGA summary: 340 layers, 2,567,615 parameters, 2,567,599 gradients, 6.4 GFLOPs

  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. # YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
  3. # Parameters
  4. nc: 80 # number of classes
  5. scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  6. # [depth, width, max_channels]
  7. n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  8. s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  9. m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  10. l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  11. x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
  12. # YOLO11n backbone
  13. backbone:
  14. # [from, repeats, module, args]
  15. - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  16. - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  17. - [-1, 2, C3k2, [256, False, 0.25]]
  18. - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  19. - [-1, 2, C3k2, [512, False, 0.25]]
  20. - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  21. - [-1, 2, C3k2, [512, True]]
  22. - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  23. - [-1, 2, C3k2, [1024, True]]
  24. - [-1, 1, SPPF, [1024, 5]] # 9
  25. - [-1, 2, C2PSA_CGA, [1024]] # 10
  26. # YOLO11n head
  27. head:
  28. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  29. - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  30. - [-1, 2, C3k2, [512, False]] # 13
  31. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  32. - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  33. - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  34. - [-1, 1, Conv, [256, 3, 2]]
  35. - [[-1, 13], 1, Concat, [1]] # cat head P4
  36. - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  37. - [-1, 1, Conv, [512, 3, 2]]
  38. - [[-1, 10], 1, Concat, [1]] # cat head P5
  39. - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  40. - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)


5.2 CGA 的yaml文件二

此版本的训练信息:YOLO11-CGA summary: 418 layers, 2,718,839 parameters, 2,718,823 gradients, 6.7 GFLOPs

  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. # YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
  3. # Parameters
  4. nc: 80 # number of classes
  5. scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  6. # [depth, width, max_channels]
  7. n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  8. s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  9. m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  10. l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  11. x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
  12. # YOLO11n backbone
  13. backbone:
  14. # [from, repeats, module, args]
  15. - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  16. - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  17. - [-1, 2, C3k2, [256, False, 0.25]]
  18. - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  19. - [-1, 2, C3k2, [512, False, 0.25]]
  20. - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  21. - [-1, 2, C3k2, [512, True]]
  22. - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  23. - [-1, 2, C3k2, [1024, True]]
  24. - [-1, 1, SPPF, [1024, 5]] # 9
  25. - [-1, 2, C2PSA, [1024]] # 10
  26. # YOLO11n head
  27. head:
  28. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  29. - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  30. - [-1, 2, C3k2, [512, False]] # 13
  31. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  32. - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  33. - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  34. - [-1, 1, LocalWindowAttention, []] # 17 (P3/8-small) 小目标检测层输出位置增加注意力机制
  35. - [-1, 1, Conv, [256, 3, 2]]
  36. - [[-1, 13], 1, Concat, [1]] # cat head P4
  37. - [-1, 2, C3k2, [512, False]] # 20 (P4/16-medium)
  38. - [-1, 1, LocalWindowAttention, []] # 21 (P4/16-medium) 中目标检测层输出位置增加注意力机制
  39. - [-1, 1, Conv, [512, 3, 2]]
  40. - [[-1, 10], 1, Concat, [1]] # cat head P5
  41. - [-1, 2, C3k2, [1024, True]] # 24 (P5/32-large)
  42. - [-1, 1, LocalWindowAttention, []] # 25 (P5/32-large) 大目标检测层输出位置增加注意力机制
  43. # 注意力机制我这里其实是添加了三个但是实际一般生效就只添加一个就可以了,所以大家可以自行注释来尝试, 上面三个仅建议大家保留一个, 但是from位置要对齐.
  44. # 具体在那一层用注意力机制可以根据自己的数据集场景进行选择。
  45. # 如果你自己配置注意力位置注意from[17, 21, 25]位置要对应上对应的检测层!
  46. - [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)


5.3 CGA 的训练过程截图


五、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~