YOLOv11改进 | 主干/Backbone篇 | RepViT从视觉变换器(ViT)的视角重新审视CNN的目标检测网络(适配yolov11全系列)

一、本文介绍

本文给大家来的改进机制是 RepViT ,用其替换我们整个主干网络,其是今年最新推出的主干网络,其主要思想是将轻量级视觉 变换器 (ViT)的设计原则应用于传统的轻量级 卷积神经网络 (CNN)。我将其替换整个YOLOv11的Backbone,实现了大幅度涨点。我对修改后的网络(我用的最轻量的版本),在一个包含1000张图片包含大中小的检测目标的数据集上(共有20+类别),进行训练测试, 发现所有的目标上均有一定程度的涨点效果 ,下面我会附上基础版本和修改版本的训练对比图。

(本文内容可根据yolov11的N、S、M、L、X进行二次缩放,轻量化更上一层)。


目录

一、本文介绍

二、RepViT基本原理

三、RepViT的核心代码

四、手把手教你添加RepViT网络结构

4.1 修改一

4.2 修改二

4.3 修改三

4.4 修改四

4.5 修改五

4.6 修改六

4.7 修改七

4.8 修改八

注意!!! 额外的修改!

打印计算量问题解决方案

注意事项!!!

五、RepViT的yaml文件

5.1 RepViT的yaml文件版本1

5.2 训练文件

六、成功运行记录

七、本文总结


二、RepViT基本原理

官方论文地址: 官方论文地址点击即可跳转

官方代码地址: 官方代码地址点击即可跳转


RepViT: Revisiting Mobile CNN From ViT Perspective 这篇论文探讨了如何改进轻量级卷积 神经网络 (CNN)以提高其在移动设备上的性能和效率。作者们发现,虽然轻量级视觉变换器( ViT )因其能够学习全局表示而表现出色,但轻量级CNN和轻量级ViT之间的架构差异尚未得到充分研究。因此,他们通过整合轻量级ViT的高效架构设计,逐步改进标准轻量级CNN(特别是MobileNetV3),从而创造了一系列全新的纯CNN模型,称为RepViT。这些模型在各种视觉任务上表现出色,比现有的轻量级ViT更高效。

其主要的改进机制包括:

  1. 结构性重组 :通过结构性重组(Structural Re-parameterization, SR),引入多分支拓扑结构,以提高训练时的性能。

  2. 扩展比率调整 :调整卷积层中的扩展比率,以减少参数冗余和延迟,同时提高网络宽度以增强 模型 性能。

  3. 宏观设计优化 :对网络的宏观架构进行优化,包括早期卷积层的设计、更深的下采样层、简化的分类器,以及整体阶段比例的调整。

  4. 微观设计调整 :在微观架构层面进行优化,包括卷积核大小的选择和压缩激励(SE)层的最佳放置。

这些创新机制共同推动了轻量级CNN的性能和效率,使其更适合在移动设备上使用,下面的是官方论文中的结构图,我们对其进行简单的分析。

这张图片是论文中的图3,展示了RepViT架构的总览。RepViT有四个阶段,输入图像的分辨率依次为

每个阶段的通道维度用 Ci​ 表示,批处理大小用 B 表示。

  • Stem :用于预处理输入图像的模块。
  • Stage1-4 :每个阶段由多个RepViTBlock组成,以及一个可选的RepViTSEBlock,包含深度可分离卷积(3x3DW),1x1卷积,压缩激励模块(SE)和前馈网络(FFN)。每个阶段通过下采样减少空间维度。
  • Pooling :全局平均池化层,用于减少特征图的空间维度。
  • FC :全连接层,用于最终的类别预测。

总结: 大家可以将RepViT看成是MobileNet系列的改进版本


三、RepViT的核心代码

下面的代码是整个RepViT的核心代码,其中有个版本,对应的GFLOPs也不相同,使用方式看章节四。

  1. from symbol import factor
  2. import torch.nn as nn
  3. from timm.models.layers import SqueezeExcite
  4. import torch
  5. __all__ = ['repvit_m0_6','repvit_m0_9', 'repvit_m1_0', 'repvit_m1_1', 'repvit_m1_5', 'repvit_m2_3']
  6. def _make_divisible(v, divisor, min_value=None):
  7. """
  8. This function is taken from the original tf repo.
  9. It ensures that all layers have a channel number that is divisible by 8
  10. It can be seen here:
  11. https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
  12. :param v:
  13. :param divisor:
  14. :param min_value:
  15. :return:
  16. """
  17. if min_value is None:
  18. min_value = divisor
  19. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  20. # Make sure that round down does not go down by more than 10%.
  21. if new_v < 0.9 * v:
  22. new_v += divisor
  23. return new_v
  24. class Conv2d_BN(torch.nn.Sequential):
  25. def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
  26. groups=1, bn_weight_init=1, resolution=-10000):
  27. super().__init__()
  28. self.add_module('c', torch.nn.Conv2d(
  29. a, b, ks, stride, pad, dilation, groups, bias=False))
  30. self.add_module('bn', torch.nn.BatchNorm2d(b))
  31. torch.nn.init.constant_(self.bn.weight, bn_weight_init)
  32. torch.nn.init.constant_(self.bn.bias, 0)
  33. @torch.no_grad()
  34. def fuse_self(self):
  35. c, bn = self._modules.values()
  36. w = bn.weight / (bn.running_var + bn.eps) ** 0.5
  37. w = c.weight * w[:, None, None, None]
  38. b = bn.bias - bn.running_mean * bn.weight / \
  39. (bn.running_var + bn.eps) ** 0.5
  40. m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
  41. 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation,
  42. groups=self.c.groups,
  43. device=c.weight.device)
  44. m.weight.data.copy_(w)
  45. m.bias.data.copy_(b)
  46. return m
  47. class Residual(torch.nn.Module):
  48. def __init__(self, m, drop=0.):
  49. super().__init__()
  50. self.m = m
  51. self.drop = drop
  52. def forward(self, x):
  53. if self.training and self.drop > 0:
  54. return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
  55. device=x.device).ge_(self.drop).div(1 - self.drop).detach()
  56. else:
  57. return x + self.m(x)
  58. @torch.no_grad()
  59. def fuse_self(self):
  60. if isinstance(self.m, Conv2d_BN):
  61. m = self.m.fuse_self()
  62. assert (m.groups == m.in_channels)
  63. identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
  64. identity = torch.nn.functional.pad(identity, [1, 1, 1, 1])
  65. m.weight += identity.to(m.weight.device)
  66. return m
  67. elif isinstance(self.m, torch.nn.Conv2d):
  68. m = self.m
  69. assert (m.groups != m.in_channels)
  70. identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
  71. identity = torch.nn.functional.pad(identity, [1, 1, 1, 1])
  72. m.weight += identity.to(m.weight.device)
  73. return m
  74. else:
  75. return self
  76. class RepVGGDW(torch.nn.Module):
  77. def __init__(self, ed) -> None:
  78. super().__init__()
  79. self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
  80. self.conv1 = torch.nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
  81. self.dim = ed
  82. self.bn = torch.nn.BatchNorm2d(ed)
  83. def forward(self, x):
  84. return self.bn((self.conv(x) + self.conv1(x)) + x)
  85. @torch.no_grad()
  86. def fuse_self(self):
  87. conv = self.conv.fuse_self()
  88. conv1 = self.conv1
  89. conv_w = conv.weight
  90. conv_b = conv.bias
  91. conv1_w = conv1.weight
  92. conv1_b = conv1.bias
  93. conv1_w = torch.nn.functional.pad(conv1_w, [1, 1, 1, 1])
  94. identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device),
  95. [1, 1, 1, 1])
  96. final_conv_w = conv_w + conv1_w + identity
  97. final_conv_b = conv_b + conv1_b
  98. conv.weight.data.copy_(final_conv_w)
  99. conv.bias.data.copy_(final_conv_b)
  100. bn = self.bn
  101. w = bn.weight / (bn.running_var + bn.eps) ** 0.5
  102. w = conv.weight * w[:, None, None, None]
  103. b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
  104. (bn.running_var + bn.eps) ** 0.5
  105. conv.weight.data.copy_(w)
  106. conv.bias.data.copy_(b)
  107. return conv
  108. class RepViTBlock(nn.Module):
  109. def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
  110. super(RepViTBlock, self).__init__()
  111. assert stride in [1, 2]
  112. self.identity = stride == 1 and inp == oup
  113. assert (hidden_dim == 2 * inp)
  114. if stride == 2:
  115. self.token_mixer = nn.Sequential(
  116. Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
  117. SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
  118. Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
  119. )
  120. self.channel_mixer = Residual(nn.Sequential(
  121. # pw
  122. Conv2d_BN(oup, 2 * oup, 1, 1, 0),
  123. nn.GELU() if use_hs else nn.GELU(),
  124. # pw-linear
  125. Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
  126. ))
  127. else:
  128. self.token_mixer = nn.Sequential(
  129. RepVGGDW(inp),
  130. SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
  131. )
  132. self.channel_mixer = Residual(nn.Sequential(
  133. # pw
  134. Conv2d_BN(inp, hidden_dim, 1, 1, 0),
  135. nn.GELU() if use_hs else nn.GELU(),
  136. # pw-linear
  137. Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
  138. ))
  139. def forward(self, x):
  140. return self.channel_mixer(self.token_mixer(x))
  141. class RepViT(nn.Module):
  142. def __init__(self, cfgs, factor):
  143. super(RepViT, self).__init__()
  144. # setting of inverted residual blocks
  145. cfgs = [sublist[:2] + [_make_divisible(int(sublist[2] * factor) , 8)] + sublist[3:] for sublist in cfgs]
  146. self.cfgs = cfgs
  147. # building first layer
  148. input_channel = self.cfgs[0][2]
  149. patch_embed = torch.nn.Sequential(Conv2d_BN(3, input_channel // 2, 3, 2, 1), torch.nn.GELU(),
  150. Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1))
  151. layers = [patch_embed]
  152. # building inverted residual blocks
  153. block = RepViTBlock
  154. for k, t, c, use_se, use_hs, s in self.cfgs:
  155. output_channel = _make_divisible(c , 8)
  156. exp_size = _make_divisible(input_channel * t, 8)
  157. layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
  158. input_channel = output_channel
  159. self.features = nn.ModuleList(layers)
  160. self.width_list = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]
  161. def forward(self, x):
  162. # x = self.features(x
  163. results = [None, None, None, None]
  164. temp = None
  165. i = None
  166. for index, f in enumerate(self.features):
  167. x = f(x)
  168. if index == 0:
  169. temp = x.size(1)
  170. i = 0
  171. elif x.size(1) == temp:
  172. results[i] = x
  173. else:
  174. temp = x.size(1)
  175. i = i + 1
  176. return results
  177. def repvit_m0_6(factor):
  178. """
  179. Constructs a MobileNetV3-Large model
  180. """
  181. cfgs = [
  182. [3, 2, 40, 1, 0, 1],
  183. [3, 2, 40, 0, 0, 1],
  184. [3, 2, 80, 0, 0, 2],
  185. [3, 2, 80, 1, 0, 1],
  186. [3, 2, 80, 0, 0, 1],
  187. [3, 2, 160, 0, 1, 2],
  188. [3, 2, 160, 1, 1, 1],
  189. [3, 2, 160, 0, 1, 1],
  190. [3, 2, 160, 1, 1, 1],
  191. [3, 2, 160, 0, 1, 1],
  192. [3, 2, 160, 1, 1, 1],
  193. [3, 2, 160, 0, 1, 1],
  194. [3, 2, 160, 1, 1, 1],
  195. [3, 2, 160, 0, 1, 1],
  196. [3, 2, 160, 0, 1, 1],
  197. [3, 2, 320, 0, 1, 2],
  198. [3, 2, 320, 1, 1, 1],
  199. ]
  200. model = RepViT(cfgs, factor)
  201. return model
  202. def repvit_m0_9(factor):
  203. """
  204. Constructs a MobileNetV3-Large model
  205. """
  206. cfgs = [
  207. # k, t, c, SE, HS, s
  208. [3, 2, 48, 1, 0, 1],
  209. [3, 2, 48, 0, 0, 1],
  210. [3, 2, 48, 0, 0, 1],
  211. [3, 2, 96, 0, 0, 2],
  212. [3, 2, 96, 1, 0, 1],
  213. [3, 2, 96, 0, 0, 1],
  214. [3, 2, 96, 0, 0, 1],
  215. [3, 2, 192, 0, 1, 2],
  216. [3, 2, 192, 1, 1, 1],
  217. [3, 2, 192, 0, 1, 1],
  218. [3, 2, 192, 1, 1, 1],
  219. [3, 2, 192, 0, 1, 1],
  220. [3, 2, 192, 1, 1, 1],
  221. [3, 2, 192, 0, 1, 1],
  222. [3, 2, 192, 1, 1, 1],
  223. [3, 2, 192, 0, 1, 1],
  224. [3, 2, 192, 1, 1, 1],
  225. [3, 2, 192, 0, 1, 1],
  226. [3, 2, 192, 1, 1, 1],
  227. [3, 2, 192, 0, 1, 1],
  228. [3, 2, 192, 1, 1, 1],
  229. [3, 2, 192, 0, 1, 1],
  230. [3, 2, 192, 0, 1, 1],
  231. [3, 2, 384, 0, 1, 2],
  232. [3, 2, 384, 1, 1, 1],
  233. [3, 2, 384, 0, 1, 1]
  234. ]
  235. model = RepViT(cfgs, factor)
  236. return model
  237. def repvit_m1_0(factor):
  238. """
  239. Constructs a MobileNetV3-Large model
  240. """
  241. cfgs = [
  242. # k, t, c, SE, HS, s
  243. [3, 2, 56, 1, 0, 1],
  244. [3, 2, 56, 0, 0, 1],
  245. [3, 2, 56, 0, 0, 1],
  246. [3, 2, 112, 0, 0, 2],
  247. [3, 2, 112, 1, 0, 1],
  248. [3, 2, 112, 0, 0, 1],
  249. [3, 2, 112, 0, 0, 1],
  250. [3, 2, 224, 0, 1, 2],
  251. [3, 2, 224, 1, 1, 1],
  252. [3, 2, 224, 0, 1, 1],
  253. [3, 2, 224, 1, 1, 1],
  254. [3, 2, 224, 0, 1, 1],
  255. [3, 2, 224, 1, 1, 1],
  256. [3, 2, 224, 0, 1, 1],
  257. [3, 2, 224, 1, 1, 1],
  258. [3, 2, 224, 0, 1, 1],
  259. [3, 2, 224, 1, 1, 1],
  260. [3, 2, 224, 0, 1, 1],
  261. [3, 2, 224, 1, 1, 1],
  262. [3, 2, 224, 0, 1, 1],
  263. [3, 2, 224, 1, 1, 1],
  264. [3, 2, 224, 0, 1, 1],
  265. [3, 2, 224, 0, 1, 1],
  266. [3, 2, 448, 0, 1, 2],
  267. [3, 2, 448, 1, 1, 1],
  268. [3, 2, 448, 0, 1, 1]
  269. ]
  270. model = RepViT(cfgs,factor=factor)
  271. return model
  272. def repvit_m1_1(factor):
  273. """
  274. Constructs a MobileNetV3-Large model
  275. """
  276. cfgs = [
  277. # k, t, c, SE, HS, s
  278. [3, 2, 64, 1, 0, 1],
  279. [3, 2, 64, 0, 0, 1],
  280. [3, 2, 64, 0, 0, 1],
  281. [3, 2, 128, 0, 0, 2],
  282. [3, 2, 128, 1, 0, 1],
  283. [3, 2, 128, 0, 0, 1],
  284. [3, 2, 128, 0, 0, 1],
  285. [3, 2, 256, 0, 1, 2],
  286. [3, 2, 256, 1, 1, 1],
  287. [3, 2, 256, 0, 1, 1],
  288. [3, 2, 256, 1, 1, 1],
  289. [3, 2, 256, 0, 1, 1],
  290. [3, 2, 256, 1, 1, 1],
  291. [3, 2, 256, 0, 1, 1],
  292. [3, 2, 256, 1, 1, 1],
  293. [3, 2, 256, 0, 1, 1],
  294. [3, 2, 256, 1, 1, 1],
  295. [3, 2, 256, 0, 1, 1],
  296. [3, 2, 256, 1, 1, 1],
  297. [3, 2, 256, 0, 1, 1],
  298. [3, 2, 256, 0, 1, 1],
  299. [3, 2, 512, 0, 1, 2],
  300. [3, 2, 512, 1, 1, 1],
  301. [3, 2, 512, 0, 1, 1]
  302. ]
  303. model = RepViT(cfgs,factor=factor)
  304. return model
  305. def repvit_m1_5(factor):
  306. """
  307. Constructs a MobileNetV3-Large model
  308. """
  309. cfgs = [
  310. # k, t, c, SE, HS, s
  311. [3, 2, 64, 1, 0, 1],
  312. [3, 2, 64, 0, 0, 1],
  313. [3, 2, 64, 1, 0, 1],
  314. [3, 2, 64, 0, 0, 1],
  315. [3, 2, 64, 0, 0, 1],
  316. [3, 2, 128, 0, 0, 2],
  317. [3, 2, 128, 1, 0, 1],
  318. [3, 2, 128, 0, 0, 1],
  319. [3, 2, 128, 1, 0, 1],
  320. [3, 2, 128, 0, 0, 1],
  321. [3, 2, 128, 0, 0, 1],
  322. [3, 2, 256, 0, 1, 2],
  323. [3, 2, 256, 1, 1, 1],
  324. [3, 2, 256, 0, 1, 1],
  325. [3, 2, 256, 1, 1, 1],
  326. [3, 2, 256, 0, 1, 1],
  327. [3, 2, 256, 1, 1, 1],
  328. [3, 2, 256, 0, 1, 1],
  329. [3, 2, 256, 1, 1, 1],
  330. [3, 2, 256, 0, 1, 1],
  331. [3, 2, 256, 1, 1, 1],
  332. [3, 2, 256, 0, 1, 1],
  333. [3, 2, 256, 1, 1, 1],
  334. [3, 2, 256, 0, 1, 1],
  335. [3, 2, 256, 1, 1, 1],
  336. [3, 2, 256, 0, 1, 1],
  337. [3, 2, 256, 1, 1, 1],
  338. [3, 2, 256, 0, 1, 1],
  339. [3, 2, 256, 1, 1, 1],
  340. [3, 2, 256, 0, 1, 1],
  341. [3, 2, 256, 1, 1, 1],
  342. [3, 2, 256, 0, 1, 1],
  343. [3, 2, 256, 1, 1, 1],
  344. [3, 2, 256, 0, 1, 1],
  345. [3, 2, 256, 1, 1, 1],
  346. [3, 2, 256, 0, 1, 1],
  347. [3, 2, 256, 0, 1, 1],
  348. [3, 2, 512, 0, 1, 2],
  349. [3, 2, 512, 1, 1, 1],
  350. [3, 2, 512, 0, 1, 1],
  351. [3, 2, 512, 1, 1, 1],
  352. [3, 2, 512, 0, 1, 1]
  353. ]
  354. model = RepViT(cfgs,factor=factor)
  355. return model
  356. def repvit_m2_3(factor):
  357. """
  358. Constructs a MobileNetV3-Large model
  359. """
  360. cfgs = [
  361. # k, t, c, SE, HS, s
  362. [3, 2, 80, 1, 0, 1],
  363. [3, 2, 80, 0, 0, 1],
  364. [3, 2, 80, 1, 0, 1],
  365. [3, 2, 80, 0, 0, 1],
  366. [3, 2, 80, 1, 0, 1],
  367. [3, 2, 80, 0, 0, 1],
  368. [3, 2, 80, 0, 0, 1],
  369. [3, 2, 160, 0, 0, 2],
  370. [3, 2, 160, 1, 0, 1],
  371. [3, 2, 160, 0, 0, 1],
  372. [3, 2, 160, 1, 0, 1],
  373. [3, 2, 160, 0, 0, 1],
  374. [3, 2, 160, 1, 0, 1],
  375. [3, 2, 160, 0, 0, 1],
  376. [3, 2, 160, 0, 0, 1],
  377. [3, 2, 320, 0, 1, 2],
  378. [3, 2, 320, 1, 1, 1],
  379. [3, 2, 320, 0, 1, 1],
  380. [3, 2, 320, 1, 1, 1],
  381. [3, 2, 320, 0, 1, 1],
  382. [3, 2, 320, 1, 1, 1],
  383. [3, 2, 320, 0, 1, 1],
  384. [3, 2, 320, 1, 1, 1],
  385. [3, 2, 320, 0, 1, 1],
  386. [3, 2, 320, 1, 1, 1],
  387. [3, 2, 320, 0, 1, 1],
  388. [3, 2, 320, 1, 1, 1],
  389. [3, 2, 320, 0, 1, 1],
  390. [3, 2, 320, 1, 1, 1],
  391. [3, 2, 320, 0, 1, 1],
  392. [3, 2, 320, 1, 1, 1],
  393. [3, 2, 320, 0, 1, 1],
  394. [3, 2, 320, 1, 1, 1],
  395. [3, 2, 320, 0, 1, 1],
  396. [3, 2, 320, 1, 1, 1],
  397. [3, 2, 320, 0, 1, 1],
  398. [3, 2, 320, 1, 1, 1],
  399. [3, 2, 320, 0, 1, 1],
  400. [3, 2, 320, 1, 1, 1],
  401. [3, 2, 320, 0, 1, 1],
  402. [3, 2, 320, 1, 1, 1],
  403. [3, 2, 320, 0, 1, 1],
  404. [3, 2, 320, 1, 1, 1],
  405. [3, 2, 320, 0, 1, 1],
  406. [3, 2, 320, 1, 1, 1],
  407. [3, 2, 320, 0, 1, 1],
  408. [3, 2, 320, 1, 1, 1],
  409. [3, 2, 320, 0, 1, 1],
  410. [3, 2, 320, 1, 1, 1],
  411. [3, 2, 320, 0, 1, 1],
  412. # [3, 2, 320, 1, 1, 1],
  413. # [3, 2, 320, 0, 1, 1],
  414. [3, 2, 320, 0, 1, 1],
  415. [3, 2, 640, 0, 1, 2],
  416. [3, 2, 640, 1, 1, 1],
  417. [3, 2, 640, 0, 1, 1],
  418. # [3, 2, 640, 1, 1, 1],
  419. # [3, 2, 640, 0, 1, 1]
  420. ]
  421. model = RepViT(cfgs,factor=factor)
  422. return model
  423. if __name__ == '__main__':
  424. model = repvit_m0_6(factor=0.25)
  425. inputs = torch.randn((1, 3, 640, 640))
  426. for i in model(inputs):
  427. print(i.size())


四、手把手教你添加RepViT网络结构

4.1 修改一

第一步还是建立文件,我们找到如下ultralytics/nn文件夹下建立一个目录名字呢就是'Addmodules'文件夹( !然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可


4.2 修改二

第二步我们在该目录下创建一个新的py文件名字为'__init__.py'( ,然后在其内部导入我们的检测头如下图所示。


4.3 修改三

第三步我门中到如下文件'ultralytics/nn/tasks.py'进行导入和注册我们的模块( !


4.4 修改四

添加如下两行代码!!!


4.5 修改五

找到七百多行大概把具体看图片,按照图片来修改就行,添加红框内的部分,注意没有()只是 函数 名。

  1. elif m in {自行添加对应的模型即可,下面都是一样的}:
  2. m = m(*args)
  3. c2 = m.width_list # 返回通道列表
  4. backbone = True


4.6 修改六

下面的两个红框内都是需要改动的。

  1. if isinstance(c2, list):
  2. m_ = m
  3. m_.backbone = True
  4. else:
  5. m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
  6. t = str(m)[8:-2].replace('__main__.', '') # module type
  7. m.np = sum(x.numel() for x in m_.parameters()) # number params
  8. m_.i, m_.f, m_.type = i + 4 if backbone else i, f, t # attach index, 'from' index, type


4.7 修改七

如下的也需要修改,全部按照我的来。

代码如下把原先的代码替换了即可。

  1. if verbose:
  2. LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}') # print
  3. save.extend(x % (i + 4 if backbone else i) for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  4. layers.append(m_)
  5. if i == 0:
  6. ch = []
  7. if isinstance(c2, list):
  8. ch.extend(c2)
  9. if len(c2) != 5:
  10. ch.insert(0, 0)
  11. else:
  12. ch.append(c2)


4.8 修改八

修改八和前面的都不太一样,需要修改前向传播中的一个部分, 已经离开了parse_model方法了。

可以在图片中开代码行数,没有离开task.py文件都是同一个文件。 同时这个部分有好几个前向传播都很相似,大家不要看错了, 是70多行左右的!!!,同时我后面提供了代码,大家直接复制粘贴即可,有时间我针对这里会出一个视频。

​​

代码如下->

  1. def _predict_once(self, x, profile=False, visualize=False, embed=None):
  2. """
  3. Perform a forward pass through the network.
  4. Args:
  5. x (torch.Tensor): The input tensor to the model.
  6. profile (bool): Print the computation time of each layer if True, defaults to False.
  7. visualize (bool): Save the feature maps of the model if True, defaults to False.
  8. embed (list, optional): A list of feature vectors/embeddings to return.
  9. Returns:
  10. (torch.Tensor): The last output of the model.
  11. """
  12. y, dt, embeddings = [], [], [] # outputs
  13. for m in self.model:
  14. if m.f != -1: # if not from previous layer
  15. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  16. if profile:
  17. self._profile_one_layer(m, x, dt)
  18. if hasattr(m, 'backbone'):
  19. x = m(x)
  20. if len(x) != 5: # 0 - 5
  21. x.insert(0, None)
  22. for index, i in enumerate(x):
  23. if index in self.save:
  24. y.append(i)
  25. else:
  26. y.append(None)
  27. x = x[-1] # 最后一个输出传给下一层
  28. else:
  29. x = m(x) # run
  30. y.append(x if m.i in self.save else None) # save output
  31. if visualize:
  32. feature_visualization(x, m.type, m.i, save_dir=visualize)
  33. if embed and m.i in embed:
  34. embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
  35. if m.i == max(embed):
  36. return torch.unbind(torch.cat(embeddings, 1), dim=0)
  37. return x

到这里就完成了修改部分,但是这里面细节很多,大家千万要注意不要替换多余的代码,导致报错,也不要拉下任何一部,都会导致运行失败,而且报错很难排查!!!很难排查!!!


注意!!! 额外的修改!

关注我的其实都知道,我大部分的修改都是一样的,这个网络需要额外的修改一步,就是s一个参数,将下面的s改为640!!!即可完美运行!!


打印计算量问题解决方案

我们找到如下文件'ultralytics/utils/torch_utils.py'按照如下的图片进行修改,否则容易打印不出来计算量。


注意事项!!!

如果大家在验证的时候报错形状不匹配的错误可以固定 验证集 的图片尺寸,方法如下 ->

找到下面这个文件ultralytics/ models /yolo/detect/train.py然后其中有一个类是DetectionTrainer class中的build_dataset函数中的一个参数rect=mode == 'val'改为rect=False


五、RepViT的yaml文件

复制如下yaml文件进行运行!!!


5.1 RepViT 的yaml文件版本1

此版本训练信息:YOLO11-RepViT summary: 559 layers, 2,118,115 parameters, 2,118,099 gradients, 5.4 GFLOPs

使用说明:# 下面 [-1, 1, LSKNet, [0.25]] 参数位置的0.25是通道放缩的系数, YOLOv11N是0.25 YOLOv11S是0.5 YOLOv11M是1. YOLOv11l是1 YOLOv11是1.5大家根据自己训练的YOLO版本设定即可.

# 本文支持版本有 __all__ = ['repvit_m0_6','repvit_m0_9', 'repvit_m1_0', 'repvit_m1_1', 'repvit_m1_5', 'repvit_m2_3']

  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. # YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
  3. # Parameters
  4. nc: 80 # number of classes
  5. scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  6. # [depth, width, max_channels]
  7. n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  8. s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  9. m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  10. l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  11. x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
  12. # 下面 [-1, 1, repvit_m0_6, [0.25]] 参数位置的0.25是通道放缩的系数, YOLOv11N是0.25 YOLOv11S是0.5 YOLOv11M是1. YOLOv11l是1 YOLOv111.5大家根据自己训练的YOLO版本设定即可.
  13. # 本文支持版本有 __all__ = ['repvit_m0_6','repvit_m0_9', 'repvit_m1_0', 'repvit_m1_1', 'repvit_m1_5', 'repvit_m2_3']
  14. # YOLO11n backbone
  15. backbone:
  16. # [from, repeats, module, args]
  17. - [-1, 1, repvit_m0_6, [0.5]] # 0-4 P1/2 这里是四层大家不要被yaml文件限制住了思维.
  18. - [-1, 1, SPPF, [1024, 5]] # 5
  19. - [-1, 2, C2PSA, [1024]] # 6
  20. # YOLO11n head
  21. head:
  22. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  23. - [[-1, 3], 1, Concat, [1]] # cat backbone P4
  24. - [-1, 2, C3k2, [512, False]] # 9
  25. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  26. - [[-1, 2], 1, Concat, [1]] # cat backbone P3
  27. - [-1, 2, C3k2, [256, False]] # 12 (P3/8-small)
  28. - [-1, 1, Conv, [256, 3, 2]]
  29. - [[-1, 9], 1, Concat, [1]] # cat head P4
  30. - [-1, 2, C3k2, [512, False]] # 15 (P4/16-medium)
  31. - [-1, 1, Conv, [512, 3, 2]]
  32. - [[-1, 6], 1, Concat, [1]] # cat head P5
  33. - [-1, 2, C3k2, [1024, True]] # 18 (P5/32-large)
  34. - [[12, 15, 18], 1, Detect, [nc]] # Detect(P3, P4, P5)


5.2 训练文件

  1. import warnings
  2. warnings.filterwarnings('ignore')
  3. from ultralytics import YOLO
  4. if __name__ == '__main__':
  5. model = YOLO('ultralytics/cfg/models/v8/yolov8-C2f-FasterBlock.yaml')
  6. # model.load('yolov8n.pt') # loading pretrain weights
  7. model.train(data=r'替换数据集yaml文件地址',
  8. # 如果大家任务是其它的'ultralytics/cfg/default.yaml'找到这里修改task可以改成detect, segment, classify, pose
  9. cache=False,
  10. imgsz=640,
  11. epochs=150,
  12. single_cls=False, # 是否是单类别检测
  13. batch=4,
  14. close_mosaic=10,
  15. workers=0,
  16. device='0',
  17. optimizer='SGD', # using SGD
  18. # resume='', # 如过想续训就设置last.pt的地址
  19. amp=False, # 如果出现训练损失为Nan可以关闭amp
  20. project='runs/train',
  21. name='exp',
  22. )


六、成功运行记录

下面是成功运行的截图,已经完成了有1个epochs的训练,图片太大截不全第2个epochs,这里改完之后打印出了点问题,但是不影响任何功能,后期我找时间修复一下这个问题。

​​


七、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充, 目前本专栏免费阅读(暂时,大家尽早关注不迷路~) ,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

​​