henry000 commited on
Commit
d13852b
·
1 Parent(s): 68d5954

✨ [New] YOLOv7 structure! enable build model

Browse files
yolo/config/model/v7-base.yaml CHANGED
@@ -1,11 +1,17 @@
 
 
1
  anchor:
2
- reg_max: 16
 
 
 
3
  strides: [8, 16, 32]
4
 
5
  model:
6
  backbone:
7
  - Conv:
8
  args: {out_channels: 32, kernel_size: 3}
 
9
  - Conv:
10
  args: {out_channels: 64, kernel_size: 3, stride: 2}
11
  - Conv:
@@ -55,7 +61,7 @@ model:
55
  args: {out_channels: 128, kernel_size: 3}
56
  - Concat:
57
  source: [-1, -3, -5, -6]
58
- tags: 8x
59
  - Conv:
60
  args: {out_channels: 512, kernel_size: 1}
61
  - Pool:
@@ -86,7 +92,7 @@ model:
86
  source: [-1, -3, -5, -6]
87
  - Conv:
88
  args: {out_channels: 1024, kernel_size: 1}
89
- tags: 16x
90
  - Pool:
91
  args: {padding: 0}
92
  - Conv:
@@ -115,17 +121,18 @@ model:
115
  source: [-1, -3, -5, -6]
116
  - Conv:
117
  args: {out_channels: 1024, kernel_size: 1}
118
- tags: 32x
119
  head:
120
  - SPPCSPConv:
121
  args: {out_channels: 512}
 
122
  - Conv:
123
  args: {out_channels: 256, kernel_size: 1}
124
  - UpSample:
125
  args: {scale_factor: 2}
126
  - Conv:
127
  args: {out_channels: 256, kernel_size: 1}
128
- source: 16x
129
  - Concat:
130
  source: [-1, -2]
131
  - Conv:
@@ -145,13 +152,14 @@ model:
145
  source: [-1, -2, -3, -4, -5, -6]
146
  - Conv:
147
  args: {out_channels: 256, kernel_size: 1}
 
148
  - Conv:
149
  args: {out_channels: 128, kernel_size: 1}
150
  - UpSample:
151
  args: {scale_factor: 2}
152
  - Conv:
153
  args: {out_channels: 128, kernel_size: 1}
154
- source: 8x
155
  - Concat:
156
  source: [-1, -2]
157
  - Conv:
@@ -171,6 +179,7 @@ model:
171
  source: [-1, -2, -3, -4, -5, -6]
172
  - Conv:
173
  args: {out_channels: 128, kernel_size: 1}
 
174
  - Pool:
175
  args: {padding: 0}
176
  - Conv:
@@ -181,7 +190,7 @@ model:
181
  - Conv:
182
  args: {out_channels: 128, kernel_size: 3, stride: 2}
183
  - Concat:
184
- source: [-1, -3, 63]
185
  - Conv:
186
  args: {out_channels: 256, kernel_size: 1}
187
  - Conv:
@@ -199,6 +208,7 @@ model:
199
  source: [-1, -2, -3, -4, -5, -6]
200
  - Conv:
201
  args: {out_channels: 256, kernel_size: 1}
 
202
  - Pool:
203
  args: {padding: 0}
204
  - Conv:
@@ -209,7 +219,7 @@ model:
209
  - Conv:
210
  args: {out_channels: 256, kernel_size: 3, stride: 2}
211
  - Concat:
212
- source: [-1, -3, 51]
213
  - Conv:
214
  args: {out_channels: 512, kernel_size: 1}
215
  - Conv:
@@ -227,20 +237,19 @@ model:
227
  source: [-1, -2, -3, -4, -5, -6]
228
  - Conv:
229
  args: {out_channels: 512, kernel_size: 1}
 
230
  - RepConv:
231
  args: {out_channels: 256}
232
- source: 75
233
  - RepConv:
234
  args: {out_channels: 512}
235
- source: 88
236
  - RepConv:
237
  args: {out_channels: 1024}
238
- source: 101
239
- - IDetect:
240
  args:
241
- anchors:
242
- - [12,16, 19,36, 40,28] # P3/8
243
- - [36,75, 76,55, 72,146] # P4/16
244
- - [142,110, 192,243, 459,401] # P5/32
245
- source: [102, 103, 104]
246
  output: True
 
 
1
+ name: v7-base
2
+
3
  anchor:
4
+ anchor:
5
+ - [12,16, 19,36, 40,28] # P5/8
6
+ - [36,75, 76,55, 72,146] # P4/16
7
+ - [142,110, 192,243, 459,401] # P5/32
8
  strides: [8, 16, 32]
9
 
10
  model:
11
  backbone:
12
  - Conv:
13
  args: {out_channels: 32, kernel_size: 3}
14
+ source: 0
15
  - Conv:
16
  args: {out_channels: 64, kernel_size: 3, stride: 2}
17
  - Conv:
 
61
  args: {out_channels: 128, kernel_size: 3}
62
  - Concat:
63
  source: [-1, -3, -5, -6]
64
+ tags: B3
65
  - Conv:
66
  args: {out_channels: 512, kernel_size: 1}
67
  - Pool:
 
92
  source: [-1, -3, -5, -6]
93
  - Conv:
94
  args: {out_channels: 1024, kernel_size: 1}
95
+ tags: B4
96
  - Pool:
97
  args: {padding: 0}
98
  - Conv:
 
121
  source: [-1, -3, -5, -6]
122
  - Conv:
123
  args: {out_channels: 1024, kernel_size: 1}
124
+ tags: B5
125
  head:
126
  - SPPCSPConv:
127
  args: {out_channels: 512}
128
+ tags: N3
129
  - Conv:
130
  args: {out_channels: 256, kernel_size: 1}
131
  - UpSample:
132
  args: {scale_factor: 2}
133
  - Conv:
134
  args: {out_channels: 256, kernel_size: 1}
135
+ source: B4
136
  - Concat:
137
  source: [-1, -2]
138
  - Conv:
 
152
  source: [-1, -2, -3, -4, -5, -6]
153
  - Conv:
154
  args: {out_channels: 256, kernel_size: 1}
155
+ tags: N2
156
  - Conv:
157
  args: {out_channels: 128, kernel_size: 1}
158
  - UpSample:
159
  args: {scale_factor: 2}
160
  - Conv:
161
  args: {out_channels: 128, kernel_size: 1}
162
+ source: B3
163
  - Concat:
164
  source: [-1, -2]
165
  - Conv:
 
179
  source: [-1, -2, -3, -4, -5, -6]
180
  - Conv:
181
  args: {out_channels: 128, kernel_size: 1}
182
+ tags: P3
183
  - Pool:
184
  args: {padding: 0}
185
  - Conv:
 
190
  - Conv:
191
  args: {out_channels: 128, kernel_size: 3, stride: 2}
192
  - Concat:
193
+ source: [-1, -3, N2]
194
  - Conv:
195
  args: {out_channels: 256, kernel_size: 1}
196
  - Conv:
 
208
  source: [-1, -2, -3, -4, -5, -6]
209
  - Conv:
210
  args: {out_channels: 256, kernel_size: 1}
211
+ tags: P4
212
  - Pool:
213
  args: {padding: 0}
214
  - Conv:
 
219
  - Conv:
220
  args: {out_channels: 256, kernel_size: 3, stride: 2}
221
  - Concat:
222
+ source: [-1, -3, N3]
223
  - Conv:
224
  args: {out_channels: 512, kernel_size: 1}
225
  - Conv:
 
237
  source: [-1, -2, -3, -4, -5, -6]
238
  - Conv:
239
  args: {out_channels: 512, kernel_size: 1}
240
+ tags: P5
241
  - RepConv:
242
  args: {out_channels: 256}
243
+ source: P3
244
  - RepConv:
245
  args: {out_channels: 512}
246
+ source: P4
247
  - RepConv:
248
  args: {out_channels: 1024}
249
+ source: P5
250
+ - MultiheadDetection:
251
  args:
252
+ version: v7
253
+ source: [-3, -2, -1]
 
 
 
254
  output: True
255
+ tags: Main
yolo/model/module.py CHANGED
@@ -91,13 +91,40 @@ class Detection(nn.Module):
91
  return class_x, anchor_x, vector_x
92
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  class MultiheadDetection(nn.Module):
95
  """Mutlihead Detection module for Dual detect or Triple detect"""
96
 
97
  def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
98
  super().__init__()
 
 
 
 
 
99
  self.heads = nn.ModuleList(
100
- [Detection((in_channels[0], in_channel), num_classes, **head_kwargs) for in_channel in in_channels]
101
  )
102
 
103
  def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
@@ -320,6 +347,32 @@ class CBLinear(nn.Module):
320
  return x.split(self.out_channels, dim=1)
321
 
322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  class SPPELAN(nn.Module):
324
  """SPPELAN module comprising multiple pooling and convolution layers."""
325
 
@@ -360,3 +413,39 @@ class CBFuse(nn.Module):
360
  res = [F.interpolate(x[pick_id], size=target_size, mode=self.mode) for pick_id, x in zip(self.idx, x_list)]
361
  out = torch.stack(res + [target]).sum(dim=0)
362
  return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  return class_x, anchor_x, vector_x
92
 
93
 
94
+ class IDetection(nn.Module):
95
+ def __init__(self, in_channels: Tuple[int], num_classes: int, *args, anchor_num: int = 3, **kwargs):
96
+ super().__init__()
97
+
98
+ if isinstance(in_channels, tuple):
99
+ in_channels = in_channels[1]
100
+
101
+ out_channel = num_classes + 5
102
+ out_channels = out_channel * anchor_num
103
+ self.head_conv = nn.Conv2d(in_channels, out_channels, 1)
104
+
105
+ self.implicit_a = ImplicitA(in_channels)
106
+ self.implicit_m = ImplicitM(out_channels)
107
+
108
+ def forward(self, x):
109
+ x = self.implicit_a(x)
110
+ x = self.head_conv(x)
111
+ x = self.implicit_m(x)
112
+
113
+ return x
114
+
115
+
116
  class MultiheadDetection(nn.Module):
117
  """Mutlihead Detection module for Dual detect or Triple detect"""
118
 
119
  def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
120
  super().__init__()
121
+ DetectionHead = Detection
122
+
123
+ if head_kwargs.pop("version", None) == "v7":
124
+ DetectionHead = IDetection
125
+
126
  self.heads = nn.ModuleList(
127
+ [DetectionHead((in_channels[0], in_channel), num_classes, **head_kwargs) for in_channel in in_channels]
128
  )
129
 
130
  def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
 
347
  return x.split(self.out_channels, dim=1)
348
 
349
 
350
+ class SPPCSPConv(nn.Module):
351
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
352
+ def __init__(self, in_channels: int, out_channels: int, expand: float = 0.5, kernel_sizes: Tuple[int] = (5, 9, 13)):
353
+ super().__init__()
354
+ neck_channels = int(2 * out_channels * expand)
355
+ self.pre_conv = nn.Sequential(
356
+ Conv(in_channels, neck_channels, 1),
357
+ Conv(neck_channels, neck_channels, 3),
358
+ Conv(neck_channels, neck_channels, 1),
359
+ )
360
+ self.short_conv = Conv(in_channels, neck_channels, 1)
361
+ self.pools = nn.ModuleList([Pool(kernel_size=kernel_size, stride=1) for kernel_size in kernel_sizes])
362
+ self.post_conv = nn.Sequential(Conv(4 * neck_channels, neck_channels, 1), Conv(neck_channels, neck_channels, 3))
363
+ self.merge_conv = Conv(2 * neck_channels, out_channels, 1)
364
+
365
+ def forward(self, x):
366
+ features = [self.pre_conv(x)]
367
+ for pool in self.pools:
368
+ features.append(pool(features[-1]))
369
+ features = torch.cat(features, dim=1)
370
+ y1 = self.post_conv(features)
371
+ y2 = self.short_conv(x)
372
+ y = torch.cat((y1, y2), dim=1)
373
+ return self.merge_conv(y)
374
+
375
+
376
  class SPPELAN(nn.Module):
377
  """SPPELAN module comprising multiple pooling and convolution layers."""
378
 
 
413
  res = [F.interpolate(x[pick_id], size=target_size, mode=self.mode) for pick_id, x in zip(self.idx, x_list)]
414
  out = torch.stack(res + [target]).sum(dim=0)
415
  return out
416
+
417
+
418
+ class ImplicitA(nn.Module):
419
+ """
420
+ Implement YOLOR - implicit knowledge(Add), paper: https://arxiv.org/abs/2105.04206
421
+ """
422
+
423
+ def __init__(self, channel: int, mean: float = 0.0, std: float = 0.02):
424
+ super().__init__()
425
+ self.channel = channel
426
+ self.mean = mean
427
+ self.std = std
428
+
429
+ self.implicit = nn.Parameter(torch.empty(1, channel, 1, 1))
430
+ nn.init.normal_(self.implicit, mean=mean, std=self.std)
431
+
432
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
433
+ return self.implicit + x
434
+
435
+
436
+ class ImplicitM(nn.Module):
437
+ """
438
+ Implement YOLOR - implicit knowledge(multiply), paper: https://arxiv.org/abs/2105.04206
439
+ """
440
+
441
+ def __init__(self, channel: int, mean: float = 1.0, std: float = 0.02):
442
+ super().__init__()
443
+ self.channel = channel
444
+ self.mean = mean
445
+ self.std = std
446
+
447
+ self.implicit = nn.Parameter(torch.empty(1, channel, 1, 1))
448
+ nn.init.normal_(self.implicit, mean=self.mean, std=self.std)
449
+
450
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
451
+ return self.implicit * x