YOLOv11改进 | 主干/Backbone篇 | 视觉变换器SwinTransformer目标检测网络( 适配yolov11全系列模型)

一、本文介绍

本文给大家带来的改进机制是利用 Swin Transformer 替换 YOLOv11中的骨干网络 其是一个开创性的视觉 变换器 模型 ,它通过使用位移窗口来构建分层的 特征图 ,有效地适应了 计算机视觉 任务。与传统的变换器模型不同,Swin Transformer的自注意力计算仅限于局部窗口内,使得 计算复杂度与图像大小成线性关系,而非二次方 。这种设计不仅提高了模型的效率,还保持了强大的特征提取能力。Swin Transformer 的创新在于其能够在不同层次上捕捉图像的细节和全局信息,使其成为各种视觉任务的强大通用骨干网络。 亲测在小目标检测和大尺度目标检测的数据集上都有涨点效果。

(本文内容可根据yolov11的N、S、M、L、X进行二次缩放,轻量化更上一层)


目录

一、本文介绍

二、Swin Transformer原理

2.1 Swin Transformer的基本原理

2.2 层次化特征映射

2.3 局部自注意力计算

2.4 移动窗口自注意力

2.5 移动窗口分区

三、 Swin Transformer的完整代码

四、手把手教你添加Swin Transformer网络结构

修改一

修改二

修改三

修改四

修改五

修改六

修改七

修改八

五、Swin Transformer的yaml文件

六、成功运行记录

七、本文总结


二、Swin Transformer原理

论文地址: 官方论文地址

代码地址: 官方代码、地址


2.1 Swin Transformer的基本原理

Swin Transformer 是一个新的视觉变换器,能够作为通用的计算机视觉骨干网络。这个模型解决了将Transformer从语言处理领域适应到视觉任务中的挑战,主要是因为这两个领域之间存在差异,例如视觉实体的尺度变化大,以及图像中像素的高分辨率与文本中的单词相比。下图对比展示了Swin Transformer与Vision Transformer (ViT)的不同之处,清楚地展示了Swin Transformer在 构建特征映射和处理计算复杂度方面 的创新优势。

(a) Swin Transformer: 提出的Swin Transformer通过在更深层次合并图像小块(灰色部分所示)来构建层次化的特征映射。在每个局部窗口(红色部分所示)内只计算自注意力,因此它对输入图像大小有线性的计算复杂度。它可以作为通用的骨干网络,用于图像分类和密集识别任务,如分割和检测。

(b) Vision Transformer (ViT): 以前的视觉Transformer模型(如ViT)产生单一低分辨率的特征映射,并且由于全局自注意力的计算,其计算复杂度与输入图像大小呈二次方关系。

我们可以将Swin Transformer的基本原理分为以下几点:

1. 层次化特征映射: Swin Transformer通过合并图像的相邻小块(patches),在更深的Transformer层次中逐步构建层次化的特征映射。这样的层次化特征映射可以方便地利用密集预测的高级技术,如特征金字塔网络(Feature Pyramid Networks, FPN)或U-Net。

2. 局部自注意力计算: 为了实现线性计算复杂性,Swin Transformer在非重叠的局部窗口内计算自注意力,这些窗口是通过划分图像来创建的。每个窗口内的小块数量是固定的,因此计算复杂性与图像大小成线性关系。

3. 移动窗口自注意力(Shifted Window based Self-Attention): 标准的Transformer架构在全局范围内计算自注意力,即计算一个标记与所有其他标记之间的关系。这种全局计算导致与标记数量成二次方的计算复杂性,不适用于许多需要处理大规模高维数据的视觉问题。Swin Transformer通过一个基于移动窗口的多头自注意力(MSA)模块取代了传统的MSA模块。每个Swin Transformer块由一个基于移动窗口的MSA模块组成,然后是两层带有GELU非线性的MLP,之前是LayerNorm(LN)层,之后是残差连接。

4. 移动窗口分区: 为了在连续的Swin Transformer块中引入跨窗口连接的同时保持非重叠窗口的有效计算,提出了一种移动窗口分区方法。这种方法在连续的块之间交替使用两种分区配置。第一个模块使用常规的窗口分区策略,然后下一个模块采用的窗口配置与前一层相比,通过移动窗口偏移了一定距离,从而实现窗口的交替。

下图详细展示了 Swin Transformer的架构和两个连续Swin Transformer块的设计 。图中的W-MSA和SW-MSA分别代表带有常规和移动窗口配置的多头自注意力模块。这两种类型的注意力模块交替使用,允许模型在保持局部计算的同时,也能够捕捉更广泛的上下文信息。

(a) 架构(Architecture): 图展示了Swin Transformer的四个阶段。每个阶段都包含若干Swin Transformer块。输入图像首先通过“Patch Partition”被划分成小块,并通过“Linear Embedding”转换成向量序列。各个阶段通过“Patch Merging”操作降低特征图的分辨率,同时增加特征维数(例如,第一阶段输出的特征维数为C,第二阶段为2C,依此类推)。

(b) 两个连续的Swin Transformer块(Two Successive Swin Transformer Blocks): 每个Swin Transformer块由多头自注意力模块(W-MSA和SW-MSA)和多层感知机(MLP)组成,其中W-MSA使用常规窗口配置,而SW-MSA使用移动窗口配置。每个块内部,先是LayerNorm(LN)层,然后是自注意力模块,再是另一个LayerNorm层,最后是MLP。块之间通过残差连接进行连接,这样的设计可以避免深层网络中的梯度消失问题,并允许信息在网络中更有效地流动。


2.2 层次化特征映射

层次化特征映射可以使 Swin Transformer有效地处理不同分辨率的特征 ,并适用于各种视觉任务,如图像分类、对象检测和语义分割。这种层次化设计使Swin Transformer与以往基于Transformer的架构(这些架构产生单一分辨率的特征图并具有二次方复杂度)形成对比,后者不适合需要在像素级进行密集预测的视觉任务。Swin Transformer的 层次化特征映射 主要通过以下步骤实现:

1. 分块和线性嵌入: 首先,输入图像被分割成小块(通常是4x4像素大小),每个小块被视为一个“标记”,其特征是原始像素RGB值的串联。然后,一个线性嵌入层被应用于这些原始值特征,将其投影到任意维度(表示为C)。这些步骤构成了所谓的“第1阶段”。

2. 分块合并: 随着网络深入,通过合并层减少标记的数量,从而降低特征图的分辨率。例如,第一个合并层将每组2x2相邻小块的特征合并,并应用一个线性层到这些4C维度的串联特征上,这样做将标记的数量减少了4倍(分辨率降低了2倍),并将输出维度设为2C。这个过程在后续的“第2阶段”、“第3阶段”和“第4阶段”中重复,分别产生更低分辨率的输出。

3. 层次化特征图: 通过在更深的Transformer层合并相邻小块,Swin Transformer构建了层次化的特征映射。这些层次化特征映射允许模型方便地使用密集预测的高级技术,例如特征金字塔网络(FPN)或U-Net。

4. 计算效率: Swin Transformer在非重叠的局部窗口内局部计算自注意力,从而实现了线性的计算复杂度。每个窗口中的小块数量是固定的,因此复杂度与图像大小成线性关系。


2.3 局部自注意力计算

Swin Transformer的局部自注意力计算通过 在小窗口内计算自注意力以及通过移动窗口在连续层之间引入跨窗口的信息流通,使得计算更加高效,同时保留了模型捕捉长距离依赖的能力。Swin Transformer中的局部自注意力计算我们可以通过以下方式实现:

1. 替代标准多头自注意力模块: Swin Transformer使用基于移动窗口的多头自注意力(MSA)模块替代了传统Transformer块中的标准多头自注意力模块,其他层保持不变。每个Swin Transformer块由一个基于移动窗口的MSA模块组成,后跟一个两层的MLP,中间包含GELU非线性 激活函数 。在每个MSA模块和MLP之前都会应用一个LayerNorm(LN)层,每个模块之后都会应用残差连接。

2. 在各个窗口内计算自注意力: 在每一层中,采用常规的窗口分区方案,每个窗口内部独立计算自注意力。在下一层中,窗口分区会发生移动,形成新的窗口。新窗口中的自注意力计算会跨越之前层中窗口的边界,建立它们之间的连接。

3. 非重叠窗口中的自注意力: 为了有效的建模,Swin Transformer在非重叠的局部窗口内计算自注意力。这些窗口被安排以均匀非重叠的方式分割图像。假设每个窗口包含M×M个小块,全局MSA模块和基于窗口的MSA模块的计算复杂度分别为二次方和线性,当M固定时(默认设为7)。

4. 循环位移和掩码机制: 提出了一种通过循环位移来提高批量计算的效率的方法。通过这种位移,一个批次的窗口可能由几个在特征图中不相邻的子窗口组成,因此采用掩码机制限制在每个子窗口内计算自注意力。这种循环位移保持了批次窗口的数量与常规窗口分区相同。

5. 窗口间的位移: 为了在连续层之间实现更高效的硬件实现,Swin Transformer提出在连续层之间位移窗口,这样的位移允许跨窗口的连接,同时维持计算的高效性。

6. 相对位置偏置: 在计算自注意力时,Swin Transformer包括了相对位置偏置B,以增强模型对不同位置之间关系的学习能力。


2.4 移动窗口自注意力

移动窗口自注意力是Swin Transformer设计的核心元素,它 通过在局部窗口内计算自注意力 并在连续层之间引入窗口位移,以实现高效的计算和强大的建模能力。在Swin Transformer论文中,移动窗口自注意力(shifted window self-attention)的 主要特点 包括:

1. 替代多头自注意力模块: 在Swin Transformer块中,标准的 多头自注意力(MSA)模块 被基于移动窗口的MSA模块替换。这种基于移动窗口的MSA模块后跟一个两层的MLP,中间有GELU非线性激活函数。每个MSA模块和MLP之前都会应用一个LayerNorm(LN)层,每个模块之后都会应用残差连接。

2. 移动窗口分区: 在连续的Swin Transformer块中,窗口分区策略在每一层之间交替。在某一层中,采用常规窗口分区,而在下一层中,窗口分区会发生移动,从而形成新的窗口。这种移动窗口分区方法能够跨越前一层中窗口的边界,提供窗口间的连接。

3. 交替分区配置: 移动窗口分区方法在连续的Swin Transformer块中交替使用两种分区配置。例如,第一个模块从左上角像素开始使用常规窗口分区策略,接着下一个模块采用的窗口配置将与前一层相比移动一定距离。

4. 移动窗口自注意力的计算: 移动窗口自注意力计算的有效性不仅在图像分类、目标检测和语义分割任务中得到了验证,而且它的实现也被证明在所有MLP架构中有益。

5. 效率: 相比于滑动窗口方法,移动窗口方法具有更低的延迟,但在建模能力上却相似。此外,移动窗口方法也有助于提高批量计算的效率。

6. 连续块的计算: 在移动窗口分区方法中,连续的Swin Transformer块的 计算方式 如下: \hat{z_l} = W\text{-}MSA(LN(z_{l-1})) + z_{l-1} ,然后是MLP层,之后是 \\hat{z_{l+1}} = SW\text{-}MSA(LN(z_l)) + z_l 。这里, \hat{z_l}z_l 分别代表块l的(S)W-MSA模块和MLP模块的输出特征。

下面我给大家展示了所提出的Swin Transformer架构中用于 计算自注意力的移动窗口方法

在第l层(左侧),采用了常规窗口划分方案,并且在每个窗口内计算自注意力。在接下来的第l+1层(右侧),窗口划分被移动,结果在新的窗口中进行了自注意力计算。这些新窗口中的自注意力计算跨越了l层中之前窗口的边界,提供了它们之间的连接。这种移动窗口方法提高了效率, 因为它限制了自注意力计算在非重叠的局部窗口内,同时允许窗口间的交叉连接。


2.5 移动窗口分区

移动窗口分区是Swin Transformer中一项关键的创新,它 通过在连续层之间交替窗口的分区方式 有效地促进了信息在窗口之间的流动,同时保持了处理高分辨率图像时的计算效率。下面我将通过图片 解释如何使用循环位移来计算在移动窗口中的自注意力,以及如何高效地实施这一计算

(1)窗口分区(Window partition): 首先,图像被分成多个窗口。
(2)循环位移(Cyclic shift): 接着,为了计算自注意力,窗口内的像素或特征会进行循环位移。这样可以将本来不相邻的像素或特征暂时性地排列到同一个窗口内,使得可以在局部窗口中计算原本跨窗口的自注意力。
(3)掩码多头自注意力(Masked MSA): 在经过循环位移后,可以在这些临时形成的窗口上执行掩码多头自注意力操作,以此计算注意力得分和更新特征。
(4)逆循环位移(Reverse cyclic shift): 完成自注意力计算后,特征会进行逆循环位移,恢复到它们原来在图像中的位置。


三、 Swin Transformer的完整代码

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import torch.utils.checkpoint as checkpoint
  5. import numpy as np
  6. from timm.models.layers import DropPath, to_2tuple, trunc_normal_
  7. __all__ = ['SwinTransformer']
  8. class Mlp(nn.Module):
  9. """ Multilayer perceptron."""
  10. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
  11. super().__init__()
  12. out_features = out_features or in_features
  13. hidden_features = hidden_features or in_features
  14. self.fc1 = nn.Linear(in_features, hidden_features)
  15. self.act = act_layer()
  16. self.fc2 = nn.Linear(hidden_features, out_features)
  17. self.drop = nn.Dropout(drop)
  18. def forward(self, x):
  19. x = self.fc1(x)
  20. x = self.act(x)
  21. x = self.drop(x)
  22. x = self.fc2(x)
  23. x = self.drop(x)
  24. return x
  25. def window_partition(x, window_size):
  26. """
  27. Args:
  28. x: (B, H, W, C)
  29. window_size (int): window size
  30. Returns:
  31. windows: (num_windows*B, window_size, window_size, C)
  32. """
  33. B, H, W, C = x.shape
  34. x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
  35. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
  36. return windows
  37. def window_reverse(windows, window_size, H, W):
  38. """
  39. Args:
  40. windows: (num_windows*B, window_size, window_size, C)
  41. window_size (int): Window size
  42. H (int): Height of image
  43. W (int): Width of image
  44. Returns:
  45. x: (B, H, W, C)
  46. """
  47. B = int(windows.shape[0] / (H * W / window_size / window_size))
  48. x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
  49. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
  50. return x
  51. class WindowAttention(nn.Module):
  52. """ Window based multi-head self attention (W-MSA) module with relative position bias.
  53. It supports both of shifted and non-shifted window.
  54. Args:
  55. dim (int): Number of input channels.
  56. window_size (tuple[int]): The height and width of the window.
  57. num_heads (int): Number of attention heads.
  58. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  59. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
  60. attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
  61. proj_drop (float, optional): Dropout ratio of output. Default: 0.0
  62. """
  63. def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
  64. super().__init__()
  65. self.dim = dim
  66. self.window_size = window_size # Wh, Ww
  67. self.num_heads = num_heads
  68. head_dim = dim // num_heads
  69. self.scale = qk_scale or head_dim ** -0.5
  70. # define a parameter table of relative position bias
  71. self.relative_position_bias_table = nn.Parameter(
  72. torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
  73. # get pair-wise relative position index for each token inside the window
  74. coords_h = torch.arange(self.window_size[0])
  75. coords_w = torch.arange(self.window_size[1])
  76. coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
  77. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
  78. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
  79. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
  80. relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
  81. relative_coords[:, :, 1] += self.window_size[1] - 1
  82. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  83. relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  84. self.register_buffer("relative_position_index", relative_position_index)
  85. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  86. self.attn_drop = nn.Dropout(attn_drop)
  87. self.proj = nn.Linear(dim, dim)
  88. self.proj_drop = nn.Dropout(proj_drop)
  89. trunc_normal_(self.relative_position_bias_table, std=.02)
  90. self.softmax = nn.Softmax(dim=-1)
  91. def forward(self, x, mask=None):
  92. """ Forward function.
  93. Args:
  94. x: input features with shape of (num_windows*B, N, C)
  95. mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
  96. """
  97. B_, N, C = x.shape
  98. qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  99. q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
  100. q = q * self.scale
  101. attn = (q @ k.transpose(-2, -1))
  102. relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
  103. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
  104. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  105. attn = attn + relative_position_bias.unsqueeze(0)
  106. if mask is not None:
  107. nW = mask.shape[0]
  108. attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
  109. attn = attn.view(-1, self.num_heads, N, N)
  110. attn = self.softmax(attn)
  111. else:
  112. attn = self.softmax(attn)
  113. attn = self.attn_drop(attn)
  114. x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
  115. x = self.proj(x)
  116. x = self.proj_drop(x)
  117. return x
  118. class SwinTransformerBlock(nn.Module):
  119. """ Swin Transformer Block.
  120. Args:
  121. dim (int): Number of input channels.
  122. num_heads (int): Number of attention heads.
  123. window_size (int): Window size.
  124. shift_size (int): Shift size for SW-MSA.
  125. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  126. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  127. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  128. drop (float, optional): Dropout rate. Default: 0.0
  129. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  130. drop_path (float, optional): Stochastic depth rate. Default: 0.0
  131. act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
  132. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  133. """
  134. def __init__(self, dim, num_heads, window_size=7, shift_size=0,
  135. mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
  136. act_layer=nn.GELU, norm_layer=nn.LayerNorm):
  137. super().__init__()
  138. self.dim = dim
  139. self.num_heads = num_heads
  140. self.window_size = window_size
  141. self.shift_size = shift_size
  142. self.mlp_ratio = mlp_ratio
  143. assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
  144. self.norm1 = norm_layer(dim)
  145. self.attn = WindowAttention(
  146. dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
  147. qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
  148. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  149. self.norm2 = norm_layer(dim)
  150. mlp_hidden_dim = int(dim * mlp_ratio)
  151. self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
  152. self.H = None
  153. self.W = None
  154. def forward(self, x, mask_matrix):
  155. """ Forward function.
  156. Args:
  157. x: Input feature, tensor size (B, H*W, C).
  158. H, W: Spatial resolution of the input feature.
  159. mask_matrix: Attention mask for cyclic shift.
  160. """
  161. B, L, C = x.shape
  162. H, W = self.H, self.W
  163. assert L == H * W, "input feature has wrong size"
  164. shortcut = x
  165. x = self.norm1(x)
  166. x = x.view(B, H, W, C)
  167. # pad feature maps to multiples of window size
  168. pad_l = pad_t = 0
  169. pad_r = (self.window_size - W % self.window_size) % self.window_size
  170. pad_b = (self.window_size - H % self.window_size) % self.window_size
  171. x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
  172. _, Hp, Wp, _ = x.shape
  173. # cyclic shift
  174. if self.shift_size > 0:
  175. shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
  176. attn_mask = mask_matrix.type(x.dtype)
  177. else:
  178. shifted_x = x
  179. attn_mask = None
  180. # partition windows
  181. x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
  182. x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
  183. # W-MSA/SW-MSA
  184. attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
  185. # merge windows
  186. attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
  187. shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
  188. # reverse cyclic shift
  189. if self.shift_size > 0:
  190. x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
  191. else:
  192. x = shifted_x
  193. if pad_r > 0 or pad_b > 0:
  194. x = x[:, :H, :W, :].contiguous()
  195. x = x.view(B, H * W, C)
  196. # FFN
  197. x = shortcut + self.drop_path(x)
  198. x = x + self.drop_path(self.mlp(self.norm2(x)))
  199. return x
  200. class PatchMerging(nn.Module):
  201. """ Patch Merging Layer
  202. Args:
  203. dim (int): Number of input channels.
  204. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  205. """
  206. def __init__(self, dim, norm_layer=nn.LayerNorm):
  207. super().__init__()
  208. self.dim = dim
  209. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  210. self.norm = norm_layer(4 * dim)
  211. def forward(self, x, H, W):
  212. """ Forward function.
  213. Args:
  214. x: Input feature, tensor size (B, H*W, C).
  215. H, W: Spatial resolution of the input feature.
  216. """
  217. B, L, C = x.shape
  218. assert L == H * W, "input feature has wrong size"
  219. x = x.view(B, H, W, C)
  220. # padding
  221. pad_input = (H % 2 == 1) or (W % 2 == 1)
  222. if pad_input:
  223. x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
  224. x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
  225. x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
  226. x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
  227. x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
  228. x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
  229. x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
  230. x = self.norm(x)
  231. x = self.reduction(x)
  232. return x
  233. class BasicLayer(nn.Module):
  234. """ A basic Swin Transformer layer for one stage.
  235. Args:
  236. dim (int): Number of feature channels
  237. depth (int): Depths of this stage.
  238. num_heads (int): Number of attention head.
  239. window_size (int): Local window size. Default: 7.
  240. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
  241. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  242. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  243. drop (float, optional): Dropout rate. Default: 0.0
  244. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  245. drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
  246. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  247. downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
  248. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
  249. """
  250. def __init__(self,
  251. dim,
  252. depth,
  253. num_heads,
  254. window_size=7,
  255. mlp_ratio=4.,
  256. qkv_bias=True,
  257. qk_scale=None,
  258. drop=0.,
  259. attn_drop=0.,
  260. drop_path=0.,
  261. norm_layer=nn.LayerNorm,
  262. downsample=None,
  263. use_checkpoint=False):
  264. super().__init__()
  265. self.window_size = window_size
  266. self.shift_size = window_size // 2
  267. self.depth = depth
  268. self.use_checkpoint = use_checkpoint
  269. # build blocks
  270. self.blocks = nn.ModuleList([
  271. SwinTransformerBlock(
  272. dim=dim,
  273. num_heads=num_heads,
  274. window_size=window_size,
  275. shift_size=0 if (i % 2 == 0) else window_size // 2,
  276. mlp_ratio=mlp_ratio,
  277. qkv_bias=qkv_bias,
  278. qk_scale=qk_scale,
  279. drop=drop,
  280. attn_drop=attn_drop,
  281. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  282. norm_layer=norm_layer)
  283. for i in range(depth)])
  284. # patch merging layer
  285. if downsample is not None:
  286. self.downsample = downsample(dim=dim, norm_layer=norm_layer)
  287. else:
  288. self.downsample = None
  289. def forward(self, x, H, W):
  290. """ Forward function.
  291. Args:
  292. x: Input feature, tensor size (B, H*W, C).
  293. H, W: Spatial resolution of the input feature.
  294. """
  295. # calculate attention mask for SW-MSA
  296. Hp = int(np.ceil(H / self.window_size)) * self.window_size
  297. Wp = int(np.ceil(W / self.window_size)) * self.window_size
  298. img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
  299. h_slices = (slice(0, -self.window_size),
  300. slice(-self.window_size, -self.shift_size),
  301. slice(-self.shift_size, None))
  302. w_slices = (slice(0, -self.window_size),
  303. slice(-self.window_size, -self.shift_size),
  304. slice(-self.shift_size, None))
  305. cnt = 0
  306. for h in h_slices:
  307. for w in w_slices:
  308. img_mask[:, h, w, :] = cnt
  309. cnt += 1
  310. mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
  311. mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
  312. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  313. attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
  314. for blk in self.blocks:
  315. blk.H, blk.W = H, W
  316. if self.use_checkpoint:
  317. x = checkpoint.checkpoint(blk, x, attn_mask)
  318. else:
  319. x = blk(x, attn_mask)
  320. if self.downsample is not None:
  321. x_down = self.downsample(x, H, W)
  322. Wh, Ww = (H + 1) // 2, (W + 1) // 2
  323. return x, H, W, x_down, Wh, Ww
  324. else:
  325. return x, H, W, x, H, W
  326. class PatchEmbed(nn.Module):
  327. """ Image to Patch Embedding
  328. Args:
  329. patch_size (int): Patch token size. Default: 4.
  330. in_chans (int): Number of input image channels. Default: 3.
  331. embed_dim (int): Number of linear projection output channels. Default: 96.
  332. norm_layer (nn.Module, optional): Normalization layer. Default: None
  333. """
  334. def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
  335. super().__init__()
  336. patch_size = to_2tuple(patch_size)
  337. self.patch_size = patch_size
  338. self.in_chans = in_chans
  339. self.embed_dim = embed_dim
  340. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
  341. if norm_layer is not None:
  342. self.norm = norm_layer(embed_dim)
  343. else:
  344. self.norm = None
  345. def forward(self, x):
  346. """Forward function."""
  347. # padding
  348. _, _, H, W = x.size()
  349. if W % self.patch_size[1] != 0:
  350. x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
  351. if H % self.patch_size[0] != 0:
  352. x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
  353. x = self.proj(x) # B C Wh Ww
  354. if self.norm is not None:
  355. Wh, Ww = x.size(2), x.size(3)
  356. x = x.flatten(2).transpose(1, 2)
  357. x = self.norm(x)
  358. x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
  359. return x
  360. class SwinTransformer(nn.Module):
  361. """ Swin Transformer backbone.
  362. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
  363. https://arxiv.org/pdf/2103.14030
  364. Args:
  365. pretrain_img_size (int): Input image size for training the pretrained model,
  366. used in absolute postion embedding. Default 224.
  367. patch_size (int | tuple(int)): Patch size. Default: 4.
  368. in_chans (int): Number of input image channels. Default: 3.
  369. embed_dim (int): Number of linear projection output channels. Default: 96.
  370. depths (tuple[int]): Depths of each Swin Transformer stage.
  371. num_heads (tuple[int]): Number of attention head of each stage.
  372. window_size (int): Window size. Default: 7.
  373. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
  374. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
  375. qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
  376. drop_rate (float): Dropout rate.
  377. attn_drop_rate (float): Attention dropout rate. Default: 0.
  378. drop_path_rate (float): Stochastic depth rate. Default: 0.2.
  379. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
  380. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
  381. patch_norm (bool): If True, add normalization after patch embedding. Default: True.
  382. out_indices (Sequence[int]): Output from which stages.
  383. frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
  384. -1 means not freezing any parameters.
  385. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
  386. """
  387. def __init__(self,
  388. factor=0.5,
  389. depth_factor=0.5,
  390. pretrain_img_size=224,
  391. patch_size=4,
  392. in_chans=3,
  393. embed_dim=96,
  394. depths=[2, 2, 6, 2],
  395. num_heads=[3, 6, 12, 24],
  396. window_size=7,
  397. mlp_ratio=4.,
  398. qkv_bias=True,
  399. qk_scale=None,
  400. drop_rate=0.,
  401. attn_drop_rate=0.,
  402. drop_path_rate=0.2,
  403. norm_layer=nn.LayerNorm,
  404. ape=False,
  405. patch_norm=True,
  406. out_indices=(0, 1, 2, 3),
  407. frozen_stages=-1,
  408. use_checkpoint=False):
  409. super().__init__()
  410. embed_dim = int(embed_dim * factor)
  411. depths = [max(1, int(dim * depth_factor)) for dim in depths]
  412. self.pretrain_img_size = pretrain_img_size
  413. self.num_layers = len(depths)
  414. self.embed_dim = embed_dim
  415. self.ape = ape
  416. self.patch_norm = patch_norm
  417. self.out_indices = out_indices
  418. self.frozen_stages = frozen_stages
  419. # split image into non-overlapping patches
  420. self.patch_embed = PatchEmbed(
  421. patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
  422. norm_layer=norm_layer if self.patch_norm else None)
  423. # absolute position embedding
  424. if self.ape:
  425. pretrain_img_size = to_2tuple(pretrain_img_size)
  426. patch_size = to_2tuple(patch_size)
  427. patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
  428. self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
  429. trunc_normal_(self.absolute_pos_embed, std=.02)
  430. self.pos_drop = nn.Dropout(p=drop_rate)
  431. # stochastic depth
  432. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
  433. # build layers
  434. self.layers = nn.ModuleList()
  435. for i_layer in range(self.num_layers):
  436. layer = BasicLayer(
  437. dim=int(embed_dim * 2 ** i_layer),
  438. depth=depths[i_layer],
  439. num_heads=num_heads[i_layer],
  440. window_size=window_size,
  441. mlp_ratio=mlp_ratio,
  442. qkv_bias=qkv_bias,
  443. qk_scale=qk_scale,
  444. drop=drop_rate,
  445. attn_drop=attn_drop_rate,
  446. drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
  447. norm_layer=norm_layer,
  448. downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
  449. use_checkpoint=use_checkpoint)
  450. self.layers.append(layer)
  451. num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
  452. self.num_features = num_features
  453. # add a norm layer for each output
  454. for i_layer in out_indices:
  455. layer = norm_layer(num_features[i_layer])
  456. layer_name = f'norm{i_layer}'
  457. self.add_module(layer_name, layer)
  458. self.width_list = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]
  459. def forward(self, x):
  460. """Forward function."""
  461. x = self.patch_embed(x)
  462. Wh, Ww = x.size(2), x.size(3)
  463. if self.ape:
  464. # interpolate the position embedding to the corresponding size
  465. absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
  466. x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
  467. else:
  468. x = x.flatten(2).transpose(1, 2)
  469. x = self.pos_drop(x)
  470. outs = []
  471. for i in range(self.num_layers):
  472. layer = self.layers[i]
  473. x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
  474. if i in self.out_indices:
  475. norm_layer = getattr(self, f'norm{i}')
  476. x_out = norm_layer(x_out)
  477. out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
  478. outs.append(out)
  479. return outs


四、手把手教你添加Swin Transformer网络结构

这个主干的网络结构添加起来算是所有的改进机制里最麻烦的了,因为有一些网略结构可以用yaml文件搭建出来,有一些网络结构其中的一些细节根本没有办法用yaml文件去搭建,用yaml文件去搭建会损失一些细节部分(而且一个网络结构设计很多细节的结构修改方式都不一样,一个一个去修改大家难免会出错),所以这里让网络直接返回整个网络,然后修改部分 yolo代码以后就都以这种形式添加了,以后我提出的网络模型基本上都会通过这种方式修改,我也会进行一些模型细节改进。创新出新的网络结构大家直接拿来用就可以的。 下面开始添加教程->

(同时每一个后面都有代码,大家拿来复制粘贴替换即可,但是要看好了不要复制粘贴替换多了)


4.1 修改一

第一还是建立文件,我们找到如下 ultralytics /nn文件夹下建立一个目录名字呢就是'Addmodules'文件夹 然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。


4.2 修改二

第二步我们在该目录下创建一个新的py文件名字为'__init__.py'( ,然后在其内部导入我们的检测头如下图所示。


4.3 修改三

第三步我门中到如下文件'ultralytics/nn/tasks.py'进行导入和注册我们的模块( !


4.4 修改四

添加如下两行代码!!!


4.5 修改五

找到七百多行大概把具体看图片,按照图片来修改就行,添加红框内的部分,注意没有()只是 函数 名。

  1. elif m in {自行添加对应的模型即可,下面都是一样的}: # 这段代码是自己添加的原代码中没有
  2. m = m(*args)
  3. c2 = m.width_list # 返回通道列表
  4. backbone = True


4.6 修改六

下面的两个红框内都是需要改动的。

  1. if isinstance(c2, list):
  2. m_ = m
  3. m_.backbone = True
  4. else:
  5. m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
  6. t = str(m)[8:-2].replace('__main__.', '') # module type
  7. m.np = sum(x.numel() for x in m_.parameters()) # number params
  8. m_.i, m_.f, m_.type = i + 4 if backbone else i, f, t # attach index, 'from' index, type


4.7 修改七

如下的也需要修改,全部按照我的来。

代码如下把原先的代码替换了即可。

  1. if verbose:
  2. LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}') # print
  3. save.extend(x % (i + 4 if backbone else i) for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  4. layers.append(m_)
  5. if i == 0:
  6. ch = []
  7. if isinstance(c2, list):
  8. ch.extend(c2)
  9. if len(c2) != 5:
  10. ch.insert(0, 0)
  11. else:
  12. ch.append(c2)


4.8 修改八

修改七和前面的都不太一样,需要修改前向传播中的一个部分, 已经离开了parse_model方法了。

可以在图片中开代码行数,没有离开task.py文件都是同一个文件。 同时这个部分有好几个前向传播都很相似,大家不要看错了, 是70多行左右的!!!,同时我后面提供了代码,大家直接复制粘贴即可,有时间我针对这里会出一个视频。

​​

代码如下->

  1. def _predict_once(self, x, profile=False, visualize=False, embed=None):
  2. """
  3. Perform a forward pass through the network.
  4. Args:
  5. x (torch.Tensor): The input tensor to the model.
  6. profile (bool): Print the computation time of each layer if True, defaults to False.
  7. visualize (bool): Save the feature maps of the model if True, defaults to False.
  8. embed (list, optional): A list of feature vectors/embeddings to return.
  9. Returns:
  10. (torch.Tensor): The last output of the model.
  11. """
  12. y, dt, embeddings = [], [], [] # outputs
  13. for m in self.model:
  14. if m.f != -1: # if not from previous layer
  15. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  16. if profile:
  17. self._profile_one_layer(m, x, dt)
  18. if hasattr(m, 'backbone'):
  19. x = m(x)
  20. if len(x) != 5: # 0 - 5
  21. x.insert(0, None)
  22. for index, i in enumerate(x):
  23. if index in self.save:
  24. y.append(i)
  25. else:
  26. y.append(None)
  27. x = x[-1] # 最后一个输出传给下一层
  28. else:
  29. x = m(x) # run
  30. y.append(x if m.i in self.save else None) # save output
  31. if visualize:
  32. feature_visualization(x, m.type, m.i, save_dir=visualize)
  33. if embed and m.i in embed:
  34. embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
  35. if m.i == max(embed):
  36. return torch.unbind(torch.cat(embeddings, 1), dim=0)
  37. return x

到这里就完成了修改部分,但是这里面细节很多,大家千万要注意不要替换多余的代码,导致报错,也不要拉下任何一部,都会导致运行失败,而且报错很难排查!!!很难排查!!!


4.9 修改九

我们找到如下文件'ultralytics/utils/torch_utils.py'按照如下的图片进行修改,否则容易打印不出来计算量。


五、Swintransformer的yaml文件

复制如下yaml文件进行运行!!!

此版本训练信息:YOLO11-SwinTransformer summary: 325 layers, 2,514,792 parameters, 2,514,776 gradients, 6.1 GFLOPs

使用说明:# 下面 [-1, 1, LSKNet, [0.25,0.5]] 参数位置的0.25是通道放缩的系数, YOLOv11N是0.25 YOLOv11S是0.5 YOLOv11M是1. YOLOv11l是1 YOLOv11是1.5大家根据自己训练的YOLO版本设定即可.

#  0.5对应的是模型的深度系数

  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. # YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
  3. # Parameters
  4. nc: 80 # number of classes
  5. scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  6. # [depth, width, max_channels]
  7. n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  8. s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  9. m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  10. l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  11. x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
  12. # 下面 [-1, 1, SwinTransformer, [0.250.5]] 参数位置的0.25是通道放缩的系数, YOLOv11N是0.25 YOLOv11S是0.5 YOLOv11M是1. YOLOv11l是1 YOLOv111.5大家根据自己训练的YOLO版本设定即可.
  13. # 0.5对应的是模型的深度系数
  14. # YOLO11n backbone
  15. backbone:
  16. # [from, repeats, module, args]
  17. - [-1, 1, SwinTransformer, [0.25,0.5]] # 0-4 P1/2
  18. - [-1, 1, SPPF, [1024, 5]] # 5
  19. - [-1, 2, C2PSA, [1024]] # 6
  20. # YOLO11n head
  21. head:
  22. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  23. - [[-1, 3], 1, Concat, [1]] # cat backbone P4
  24. - [-1, 2, C3k2, [512, False]] # 9
  25. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  26. - [[-1, 2], 1, Concat, [1]] # cat backbone P3
  27. - [-1, 2, C3k2, [256, False]] # 12 (P3/8-small)
  28. - [-1, 1, Conv, [256, 3, 2]]
  29. - [[-1, 9], 1, Concat, [1]] # cat head P4
  30. - [-1, 2, C3k2, [512, False]] # 15 (P4/16-medium)
  31. - [-1, 1, Conv, [512, 3, 2]]
  32. - [[-1, 6], 1, Concat, [1]] # cat head P5
  33. - [-1, 2, C3k2, [1024, True]] # 18 (P5/32-large)
  34. - [[12, 15, 18], 1, Detect, [nc]] # Detect(P3, P4, P5)


六、成功运行记录

下面是成功运行的截图,已经完成了有1个epochs的训练,图片太大截不全第2个epochs了。


七、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充, 目前本专栏免费阅读(暂时,大家尽早关注不迷路~), 如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

​​