AkashDataScience commited on
Commit
9eb3c3e
·
1 Parent(s): 7117863
Files changed (1) hide show
  1. models/yolo.py +818 -0
models/yolo.py ADDED
@@ -0,0 +1,818 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import platform
4
+ import sys
5
+ from copy import deepcopy
6
+ from pathlib import Path
7
+
8
+ FILE = Path(__file__).resolve()
9
+ ROOT = FILE.parents[1] # YOLO root directory
10
+ if str(ROOT) not in sys.path:
11
+ sys.path.append(str(ROOT)) # add ROOT to PATH
12
+ if platform.system() != 'Windows':
13
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
14
+
15
+ from models.common import *
16
+ from models.experimental import *
17
+ from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
18
+ from utils.plots import feature_visualization
19
+ from utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
20
+ time_sync)
21
+ from utils.tal.anchor_generator import make_anchors, dist2bbox
22
+
23
+ try:
24
+ import thop # for FLOPs computation
25
+ except ImportError:
26
+ thop = None
27
+
28
+
29
+ class Detect(nn.Module):
30
+ # YOLO Detect head for detection models
31
+ dynamic = False # force grid reconstruction
32
+ export = False # export mode
33
+ shape = None
34
+ anchors = torch.empty(0) # init
35
+ strides = torch.empty(0) # init
36
+
37
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
38
+ super().__init__()
39
+ self.nc = nc # number of classes
40
+ self.nl = len(ch) # number of detection layers
41
+ self.reg_max = 16
42
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
43
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
44
+ self.stride = torch.zeros(self.nl) # strides computed during build
45
+
46
+ c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128)))) # channels
47
+ self.cv2 = nn.ModuleList(
48
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
49
+ self.cv3 = nn.ModuleList(
50
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
51
+ self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
52
+
53
+ def forward(self, x):
54
+ shape = x[0].shape # BCHW
55
+ for i in range(self.nl):
56
+ x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
57
+ if self.training:
58
+ return x
59
+ elif self.dynamic or self.shape != shape:
60
+ self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
61
+ self.shape = shape
62
+
63
+ box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
64
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
65
+ y = torch.cat((dbox, cls.sigmoid()), 1)
66
+ return y if self.export else (y, x)
67
+
68
+ def bias_init(self):
69
+ # Initialize Detect() biases, WARNING: requires stride availability
70
+ m = self # self.model[-1] # Detect() module
71
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
72
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
73
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
74
+ a[-1].bias.data[:] = 1.0 # box
75
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
76
+
77
+
78
+ class DDetect(nn.Module):
79
+ # YOLO Detect head for detection models
80
+ dynamic = False # force grid reconstruction
81
+ export = False # export mode
82
+ shape = None
83
+ anchors = torch.empty(0) # init
84
+ strides = torch.empty(0) # init
85
+
86
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
87
+ super().__init__()
88
+ self.nc = nc # number of classes
89
+ self.nl = len(ch) # number of detection layers
90
+ self.reg_max = 16
91
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
92
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
93
+ self.stride = torch.zeros(self.nl) # strides computed during build
94
+
95
+ c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), max((ch[0], min((self.nc * 2, 128)))) # channels
96
+ self.cv2 = nn.ModuleList(
97
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4), nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch)
98
+ self.cv3 = nn.ModuleList(
99
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
100
+ self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
101
+
102
+ def forward(self, x):
103
+ shape = x[0].shape # BCHW
104
+ for i in range(self.nl):
105
+ x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
106
+ if self.training:
107
+ return x
108
+ elif self.dynamic or self.shape != shape:
109
+ self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
110
+ self.shape = shape
111
+
112
+ box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
113
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
114
+ y = torch.cat((dbox, cls.sigmoid()), 1)
115
+ return y if self.export else (y, x)
116
+
117
+ def bias_init(self):
118
+ # Initialize Detect() biases, WARNING: requires stride availability
119
+ m = self # self.model[-1] # Detect() module
120
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
121
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
122
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
123
+ a[-1].bias.data[:] = 1.0 # box
124
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
125
+
126
+
127
+ class DualDetect(nn.Module):
128
+ # YOLO Detect head for detection models
129
+ dynamic = False # force grid reconstruction
130
+ export = False # export mode
131
+ shape = None
132
+ anchors = torch.empty(0) # init
133
+ strides = torch.empty(0) # init
134
+
135
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
136
+ super().__init__()
137
+ self.nc = nc # number of classes
138
+ self.nl = len(ch) // 2 # number of detection layers
139
+ self.reg_max = 16
140
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
141
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
142
+ self.stride = torch.zeros(self.nl) # strides computed during build
143
+
144
+ c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128)))) # channels
145
+ c4, c5 = max((ch[self.nl] // 4, self.reg_max * 4, 16)), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
146
+ self.cv2 = nn.ModuleList(
147
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch[:self.nl])
148
+ self.cv3 = nn.ModuleList(
149
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
150
+ self.cv4 = nn.ModuleList(
151
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, 4 * self.reg_max, 1)) for x in ch[self.nl:])
152
+ self.cv5 = nn.ModuleList(
153
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:])
154
+ self.dfl = DFL(self.reg_max)
155
+ self.dfl2 = DFL(self.reg_max)
156
+
157
+ def forward(self, x):
158
+ shape = x[0].shape # BCHW
159
+ d1 = []
160
+ d2 = []
161
+ for i in range(self.nl):
162
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
163
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
164
+ if self.training:
165
+ return [d1, d2]
166
+ elif self.dynamic or self.shape != shape:
167
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
168
+ self.shape = shape
169
+
170
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
171
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
172
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
173
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
174
+ y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1)]
175
+ return y if self.export else (y, [d1, d2])
176
+
177
+ def bias_init(self):
178
+ # Initialize Detect() biases, WARNING: requires stride availability
179
+ m = self # self.model[-1] # Detect() module
180
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
181
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
182
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
183
+ a[-1].bias.data[:] = 1.0 # box
184
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
185
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
186
+ a[-1].bias.data[:] = 1.0 # box
187
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
188
+
189
+
190
+ class DualDDetect(nn.Module):
191
+ # YOLO Detect head for detection models
192
+ dynamic = False # force grid reconstruction
193
+ export = False # export mode
194
+ shape = None
195
+ anchors = torch.empty(0) # init
196
+ strides = torch.empty(0) # init
197
+
198
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
199
+ super().__init__()
200
+ self.nc = nc # number of classes
201
+ self.nl = len(ch) // 2 # number of detection layers
202
+ self.reg_max = 16
203
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
204
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
205
+ self.stride = torch.zeros(self.nl) # strides computed during build
206
+
207
+ c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), max((ch[0], min((self.nc * 2, 128)))) # channels
208
+ c4, c5 = make_divisible(max((ch[self.nl] // 4, self.reg_max * 4, 16)), 4), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
209
+ self.cv2 = nn.ModuleList(
210
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4), nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch[:self.nl])
211
+ self.cv3 = nn.ModuleList(
212
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
213
+ self.cv4 = nn.ModuleList(
214
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3, g=4), nn.Conv2d(c4, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl:])
215
+ self.cv5 = nn.ModuleList(
216
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:])
217
+ self.dfl = DFL(self.reg_max)
218
+ self.dfl2 = DFL(self.reg_max)
219
+
220
+ def forward(self, x):
221
+ shape = x[0].shape # BCHW
222
+ d1 = []
223
+ d2 = []
224
+ for i in range(self.nl):
225
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
226
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
227
+ if self.training:
228
+ return [d1, d2]
229
+ elif self.dynamic or self.shape != shape:
230
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
231
+ self.shape = shape
232
+
233
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
234
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
235
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
236
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
237
+ y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1)]
238
+ return y if self.export else (y, [d1, d2])
239
+ #y = torch.cat((dbox2, cls2.sigmoid()), 1)
240
+ #return y if self.export else (y, d2)
241
+ #y1 = torch.cat((dbox, cls.sigmoid()), 1)
242
+ #y2 = torch.cat((dbox2, cls2.sigmoid()), 1)
243
+ #return [y1, y2] if self.export else [(y1, d1), (y2, d2)]
244
+ #return [y1, y2] if self.export else [(y1, y2), (d1, d2)]
245
+
246
+ def bias_init(self):
247
+ # Initialize Detect() biases, WARNING: requires stride availability
248
+ m = self # self.model[-1] # Detect() module
249
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
250
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
251
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
252
+ a[-1].bias.data[:] = 1.0 # box
253
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
254
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
255
+ a[-1].bias.data[:] = 1.0 # box
256
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
257
+
258
+
259
+ class TripleDetect(nn.Module):
260
+ # YOLO Detect head for detection models
261
+ dynamic = False # force grid reconstruction
262
+ export = False # export mode
263
+ shape = None
264
+ anchors = torch.empty(0) # init
265
+ strides = torch.empty(0) # init
266
+
267
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
268
+ super().__init__()
269
+ self.nc = nc # number of classes
270
+ self.nl = len(ch) // 3 # number of detection layers
271
+ self.reg_max = 16
272
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
273
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
274
+ self.stride = torch.zeros(self.nl) # strides computed during build
275
+
276
+ c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128)))) # channels
277
+ c4, c5 = max((ch[self.nl] // 4, self.reg_max * 4, 16)), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
278
+ c6, c7 = max((ch[self.nl * 2] // 4, self.reg_max * 4, 16)), max((ch[self.nl * 2], min((self.nc * 2, 128)))) # channels
279
+ self.cv2 = nn.ModuleList(
280
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch[:self.nl])
281
+ self.cv3 = nn.ModuleList(
282
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
283
+ self.cv4 = nn.ModuleList(
284
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, 4 * self.reg_max, 1)) for x in ch[self.nl:self.nl*2])
285
+ self.cv5 = nn.ModuleList(
286
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:self.nl*2])
287
+ self.cv6 = nn.ModuleList(
288
+ nn.Sequential(Conv(x, c6, 3), Conv(c6, c6, 3), nn.Conv2d(c6, 4 * self.reg_max, 1)) for x in ch[self.nl*2:self.nl*3])
289
+ self.cv7 = nn.ModuleList(
290
+ nn.Sequential(Conv(x, c7, 3), Conv(c7, c7, 3), nn.Conv2d(c7, self.nc, 1)) for x in ch[self.nl*2:self.nl*3])
291
+ self.dfl = DFL(self.reg_max)
292
+ self.dfl2 = DFL(self.reg_max)
293
+ self.dfl3 = DFL(self.reg_max)
294
+
295
+ def forward(self, x):
296
+ shape = x[0].shape # BCHW
297
+ d1 = []
298
+ d2 = []
299
+ d3 = []
300
+ for i in range(self.nl):
301
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
302
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
303
+ d3.append(torch.cat((self.cv6[i](x[self.nl*2+i]), self.cv7[i](x[self.nl*2+i])), 1))
304
+ if self.training:
305
+ return [d1, d2, d3]
306
+ elif self.dynamic or self.shape != shape:
307
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
308
+ self.shape = shape
309
+
310
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
311
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
312
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
313
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
314
+ box3, cls3 = torch.cat([di.view(shape[0], self.no, -1) for di in d3], 2).split((self.reg_max * 4, self.nc), 1)
315
+ dbox3 = dist2bbox(self.dfl3(box3), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
316
+ y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1), torch.cat((dbox3, cls3.sigmoid()), 1)]
317
+ return y if self.export else (y, [d1, d2, d3])
318
+
319
+ def bias_init(self):
320
+ # Initialize Detect() biases, WARNING: requires stride availability
321
+ m = self # self.model[-1] # Detect() module
322
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
323
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
324
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
325
+ a[-1].bias.data[:] = 1.0 # box
326
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
327
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
328
+ a[-1].bias.data[:] = 1.0 # box
329
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
330
+ for a, b, s in zip(m.cv6, m.cv7, m.stride): # from
331
+ a[-1].bias.data[:] = 1.0 # box
332
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
333
+
334
+
335
+ class TripleDDetect(nn.Module):
336
+ # YOLO Detect head for detection models
337
+ dynamic = False # force grid reconstruction
338
+ export = False # export mode
339
+ shape = None
340
+ anchors = torch.empty(0) # init
341
+ strides = torch.empty(0) # init
342
+
343
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
344
+ super().__init__()
345
+ self.nc = nc # number of classes
346
+ self.nl = len(ch) // 3 # number of detection layers
347
+ self.reg_max = 16
348
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
349
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
350
+ self.stride = torch.zeros(self.nl) # strides computed during build
351
+
352
+ c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), \
353
+ max((ch[0], min((self.nc * 2, 128)))) # channels
354
+ c4, c5 = make_divisible(max((ch[self.nl] // 4, self.reg_max * 4, 16)), 4), \
355
+ max((ch[self.nl], min((self.nc * 2, 128)))) # channels
356
+ c6, c7 = make_divisible(max((ch[self.nl * 2] // 4, self.reg_max * 4, 16)), 4), \
357
+ max((ch[self.nl * 2], min((self.nc * 2, 128)))) # channels
358
+ self.cv2 = nn.ModuleList(
359
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4),
360
+ nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch[:self.nl])
361
+ self.cv3 = nn.ModuleList(
362
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
363
+ self.cv4 = nn.ModuleList(
364
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3, g=4),
365
+ nn.Conv2d(c4, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl:self.nl*2])
366
+ self.cv5 = nn.ModuleList(
367
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:self.nl*2])
368
+ self.cv6 = nn.ModuleList(
369
+ nn.Sequential(Conv(x, c6, 3), Conv(c6, c6, 3, g=4),
370
+ nn.Conv2d(c6, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl*2:self.nl*3])
371
+ self.cv7 = nn.ModuleList(
372
+ nn.Sequential(Conv(x, c7, 3), Conv(c7, c7, 3), nn.Conv2d(c7, self.nc, 1)) for x in ch[self.nl*2:self.nl*3])
373
+ self.dfl = DFL(self.reg_max)
374
+ self.dfl2 = DFL(self.reg_max)
375
+ self.dfl3 = DFL(self.reg_max)
376
+
377
+ def forward(self, x):
378
+ shape = x[0].shape # BCHW
379
+ d1 = []
380
+ d2 = []
381
+ d3 = []
382
+ for i in range(self.nl):
383
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
384
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
385
+ d3.append(torch.cat((self.cv6[i](x[self.nl*2+i]), self.cv7[i](x[self.nl*2+i])), 1))
386
+ if self.training:
387
+ return [d1, d2, d3]
388
+ elif self.dynamic or self.shape != shape:
389
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
390
+ self.shape = shape
391
+
392
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
393
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
394
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
395
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
396
+ box3, cls3 = torch.cat([di.view(shape[0], self.no, -1) for di in d3], 2).split((self.reg_max * 4, self.nc), 1)
397
+ dbox3 = dist2bbox(self.dfl3(box3), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
398
+ #y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1), torch.cat((dbox3, cls3.sigmoid()), 1)]
399
+ #return y if self.export else (y, [d1, d2, d3])
400
+ y = torch.cat((dbox3, cls3.sigmoid()), 1)
401
+ return y if self.export else (y, d3)
402
+
403
+ def bias_init(self):
404
+ # Initialize Detect() biases, WARNING: requires stride availability
405
+ m = self # self.model[-1] # Detect() module
406
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
407
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
408
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
409
+ a[-1].bias.data[:] = 1.0 # box
410
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
411
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
412
+ a[-1].bias.data[:] = 1.0 # box
413
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
414
+ for a, b, s in zip(m.cv6, m.cv7, m.stride): # from
415
+ a[-1].bias.data[:] = 1.0 # box
416
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
417
+
418
+
419
+ class Segment(Detect):
420
+ # YOLO Segment head for segmentation models
421
+ def __init__(self, nc=80, nm=32, npr=256, ch=(), inplace=True):
422
+ super().__init__(nc, ch, inplace)
423
+ self.nm = nm # number of masks
424
+ self.npr = npr # number of protos
425
+ self.proto = Proto(ch[0], self.npr, self.nm) # protos
426
+ self.detect = Detect.forward
427
+
428
+ c4 = max(ch[0] // 4, self.nm)
429
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
430
+
431
+ def forward(self, x):
432
+ p = self.proto(x[0])
433
+ bs = p.shape[0]
434
+
435
+ mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
436
+ x = self.detect(self, x)
437
+ if self.training:
438
+ return x, mc, p
439
+ return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
440
+
441
+
442
+ class DSegment(DDetect):
443
+ # YOLO Segment head for segmentation models
444
+ def __init__(self, nc=80, nm=32, npr=256, ch=(), inplace=True):
445
+ super().__init__(nc, ch[:-1], inplace)
446
+ self.nl = len(ch)-1
447
+ self.nm = nm # number of masks
448
+ self.npr = npr # number of protos
449
+ self.proto = Conv(ch[-1], self.nm, 1) # protos
450
+ self.detect = DDetect.forward
451
+
452
+ c4 = max(ch[0] // 4, self.nm)
453
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch[:-1])
454
+
455
+ def forward(self, x):
456
+ p = self.proto(x[-1])
457
+ bs = p.shape[0]
458
+
459
+ mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
460
+ x = self.detect(self, x[:-1])
461
+ if self.training:
462
+ return x, mc, p
463
+ return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
464
+
465
+
466
+ class DualDSegment(DualDDetect):
467
+ # YOLO Segment head for segmentation models
468
+ def __init__(self, nc=80, nm=32, npr=256, ch=(), inplace=True):
469
+ super().__init__(nc, ch[:-2], inplace)
470
+ self.nl = (len(ch)-2) // 2
471
+ self.nm = nm # number of masks
472
+ self.npr = npr # number of protos
473
+ self.proto = Conv(ch[-2], self.nm, 1) # protos
474
+ self.proto2 = Conv(ch[-1], self.nm, 1) # protos
475
+ self.detect = DualDDetect.forward
476
+
477
+ c6 = max(ch[0] // 4, self.nm)
478
+ c7 = max(ch[self.nl] // 4, self.nm)
479
+ self.cv6 = nn.ModuleList(nn.Sequential(Conv(x, c6, 3), Conv(c6, c6, 3), nn.Conv2d(c6, self.nm, 1)) for x in ch[:self.nl])
480
+ self.cv7 = nn.ModuleList(nn.Sequential(Conv(x, c7, 3), Conv(c7, c7, 3), nn.Conv2d(c7, self.nm, 1)) for x in ch[self.nl:self.nl*2])
481
+
482
+ def forward(self, x):
483
+ p = [self.proto(x[-2]), self.proto2(x[-1])]
484
+ bs = p[0].shape[0]
485
+
486
+ mc = [torch.cat([self.cv6[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2),
487
+ torch.cat([self.cv7[i](x[self.nl+i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)] # mask coefficients
488
+ d = self.detect(self, x[:-2])
489
+ if self.training:
490
+ return d, mc, p
491
+ return (torch.cat([d[0][1], mc[1]], 1), (d[1][1], mc[1], p[1]))
492
+
493
+
494
+ class Panoptic(Detect):
495
+ # YOLO Panoptic head for panoptic segmentation models
496
+ def __init__(self, nc=80, sem_nc=93, nm=32, npr=256, ch=(), inplace=True):
497
+ super().__init__(nc, ch, inplace)
498
+ self.sem_nc = sem_nc
499
+ self.nm = nm # number of masks
500
+ self.npr = npr # number of protos
501
+ self.proto = Proto(ch[0], self.npr, self.nm) # protos
502
+ self.uconv = UConv(ch[0], ch[0]//4, self.sem_nc+self.nc)
503
+ self.detect = Detect.forward
504
+
505
+ c4 = max(ch[0] // 4, self.nm)
506
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
507
+
508
+
509
+ def forward(self, x):
510
+ p = self.proto(x[0])
511
+ s = self.uconv(x[0])
512
+ bs = p.shape[0]
513
+
514
+ mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
515
+ x = self.detect(self, x)
516
+ if self.training:
517
+ return x, mc, p, s
518
+ return (torch.cat([x, mc], 1), p, s) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p, s))
519
+
520
+
521
+ class BaseModel(nn.Module):
522
+ # YOLO base model
523
+ def forward(self, x, profile=False, visualize=False):
524
+ return self._forward_once(x, profile, visualize) # single-scale inference, train
525
+
526
+ def _forward_once(self, x, profile=False, visualize=False):
527
+ y, dt = [], [] # outputs
528
+ for m in self.model:
529
+ if m.f != -1: # if not from previous layer
530
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
531
+ if profile:
532
+ self._profile_one_layer(m, x, dt)
533
+ x = m(x) # run
534
+ y.append(x if m.i in self.save else None) # save output
535
+ if visualize:
536
+ feature_visualization(x, m.type, m.i, save_dir=visualize)
537
+ return x
538
+
539
+ def _profile_one_layer(self, m, x, dt):
540
+ c = m == self.model[-1] # is final layer, copy input as inplace fix
541
+ o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
542
+ t = time_sync()
543
+ for _ in range(10):
544
+ m(x.copy() if c else x)
545
+ dt.append((time_sync() - t) * 100)
546
+ if m == self.model[0]:
547
+ LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
548
+ LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
549
+ if c:
550
+ LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
551
+
552
+ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
553
+ LOGGER.info('Fusing layers... ')
554
+ for m in self.model.modules():
555
+ if isinstance(m, (RepConvN)) and hasattr(m, 'fuse_convs'):
556
+ m.fuse_convs()
557
+ m.forward = m.forward_fuse # update forward
558
+ if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
559
+ m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
560
+ delattr(m, 'bn') # remove batchnorm
561
+ m.forward = m.forward_fuse # update forward
562
+ self.info()
563
+ return self
564
+
565
+ def info(self, verbose=False, img_size=640): # print model information
566
+ model_info(self, verbose, img_size)
567
+
568
+ def _apply(self, fn):
569
+ # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
570
+ self = super()._apply(fn)
571
+ m = self.model[-1] # Detect()
572
+ if isinstance(m, (Detect, DualDetect, TripleDetect, DDetect, DualDDetect, TripleDDetect, Segment, DSegment, DualDSegment, Panoptic)):
573
+ m.stride = fn(m.stride)
574
+ m.anchors = fn(m.anchors)
575
+ m.strides = fn(m.strides)
576
+ # m.grid = list(map(fn, m.grid))
577
+ return self
578
+
579
+
580
+ class DetectionModel(BaseModel):
581
+ # YOLO detection model
582
+ def __init__(self, cfg='yolo.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
583
+ super().__init__()
584
+ if isinstance(cfg, dict):
585
+ self.yaml = cfg # model dict
586
+ else: # is *.yaml
587
+ import yaml # for torch hub
588
+ self.yaml_file = Path(cfg).name
589
+ with open(cfg, encoding='ascii', errors='ignore') as f:
590
+ self.yaml = yaml.safe_load(f) # model dict
591
+
592
+ # Define model
593
+ ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
594
+ if nc and nc != self.yaml['nc']:
595
+ LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
596
+ self.yaml['nc'] = nc # override yaml value
597
+ if anchors:
598
+ LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
599
+ self.yaml['anchors'] = round(anchors) # override yaml value
600
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
601
+ self.names = [str(i) for i in range(self.yaml['nc'])] # default names
602
+ self.inplace = self.yaml.get('inplace', True)
603
+
604
+ # Build strides, anchors
605
+ m = self.model[-1] # Detect()
606
+ if isinstance(m, (Detect, DDetect, Segment, DSegment, Panoptic)):
607
+ s = 256 # 2x min stride
608
+ m.inplace = self.inplace
609
+ forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, DSegment, Panoptic)) else self.forward(x)
610
+ m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
611
+ # check_anchor_order(m)
612
+ # m.anchors /= m.stride.view(-1, 1, 1)
613
+ self.stride = m.stride
614
+ m.bias_init() # only run once
615
+ if isinstance(m, (DualDetect, TripleDetect, DualDDetect, TripleDDetect, DualDSegment)):
616
+ s = 256 # 2x min stride
617
+ m.inplace = self.inplace
618
+ forward = lambda x: self.forward(x)[0][0] if isinstance(m, (DualDSegment)) else self.forward(x)[0]
619
+ m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
620
+ # check_anchor_order(m)
621
+ # m.anchors /= m.stride.view(-1, 1, 1)
622
+ self.stride = m.stride
623
+ m.bias_init() # only run once
624
+
625
+ # Init weights, biases
626
+ initialize_weights(self)
627
+ self.info()
628
+ LOGGER.info('')
629
+
630
+ def forward(self, x, augment=False, profile=False, visualize=False):
631
+ if augment:
632
+ return self._forward_augment(x) # augmented inference, None
633
+ return self._forward_once(x, profile, visualize) # single-scale inference, train
634
+
635
+ def _forward_augment(self, x):
636
+ img_size = x.shape[-2:] # height, width
637
+ s = [1, 0.83, 0.67] # scales
638
+ f = [None, 3, None] # flips (2-ud, 3-lr)
639
+ y = [] # outputs
640
+ for si, fi in zip(s, f):
641
+ xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
642
+ yi = self._forward_once(xi)[0] # forward
643
+ # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
644
+ yi = self._descale_pred(yi, fi, si, img_size)
645
+ y.append(yi)
646
+ y = self._clip_augmented(y) # clip augmented tails
647
+ return torch.cat(y, 1), None # augmented inference, train
648
+
649
+ def _descale_pred(self, p, flips, scale, img_size):
650
+ # de-scale predictions following augmented inference (inverse operation)
651
+ if self.inplace:
652
+ p[..., :4] /= scale # de-scale
653
+ if flips == 2:
654
+ p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
655
+ elif flips == 3:
656
+ p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
657
+ else:
658
+ x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
659
+ if flips == 2:
660
+ y = img_size[0] - y # de-flip ud
661
+ elif flips == 3:
662
+ x = img_size[1] - x # de-flip lr
663
+ p = torch.cat((x, y, wh, p[..., 4:]), -1)
664
+ return p
665
+
666
+ def _clip_augmented(self, y):
667
+ # Clip YOLO augmented inference tails
668
+ nl = self.model[-1].nl # number of detection layers (P3-P5)
669
+ g = sum(4 ** x for x in range(nl)) # grid points
670
+ e = 1 # exclude layer count
671
+ i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices
672
+ y[0] = y[0][:, :-i] # large
673
+ i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
674
+ y[-1] = y[-1][:, i:] # small
675
+ return y
676
+
677
+
678
+ Model = DetectionModel # retain YOLO 'Model' class for backwards compatibility
679
+
680
+
681
+ class SegmentationModel(DetectionModel):
682
+ # YOLO segmentation model
683
+ def __init__(self, cfg='yolo-seg.yaml', ch=3, nc=None, anchors=None):
684
+ super().__init__(cfg, ch, nc, anchors)
685
+
686
+
687
+ class ClassificationModel(BaseModel):
688
+ # YOLO classification model
689
+ def __init__(self, cfg=None, model=None, nc=1000, cutoff=10): # yaml, model, number of classes, cutoff index
690
+ super().__init__()
691
+ self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)
692
+
693
+ def _from_detection_model(self, model, nc=1000, cutoff=10):
694
+ # Create a YOLO classification model from a YOLO detection model
695
+ if isinstance(model, DetectMultiBackend):
696
+ model = model.model # unwrap DetectMultiBackend
697
+ model.model = model.model[:cutoff] # backbone
698
+ m = model.model[-1] # last layer
699
+ ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels # ch into module
700
+ c = Classify(ch, nc) # Classify()
701
+ c.i, c.f, c.type = m.i, m.f, 'models.common.Classify' # index, from, type
702
+ model.model[-1] = c # replace
703
+ self.model = model.model
704
+ self.stride = model.stride
705
+ self.save = []
706
+ self.nc = nc
707
+
708
+ def _from_yaml(self, cfg):
709
+ # Create a YOLO classification model from a *.yaml file
710
+ self.model = None
711
+
712
+
713
+ def parse_model(d, ch): # model_dict, input_channels(3)
714
+ # Parse a YOLO model.yaml dictionary
715
+ LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
716
+ anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
717
+ if act:
718
+ Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
719
+ RepConvN.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
720
+ LOGGER.info(f"{colorstr('activation:')} {act}") # print
721
+ na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
722
+ no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
723
+
724
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
725
+ for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
726
+ m = eval(m) if isinstance(m, str) else m # eval strings
727
+ for j, a in enumerate(args):
728
+ with contextlib.suppress(NameError):
729
+ args[j] = eval(a) if isinstance(a, str) else a # eval strings
730
+
731
+ n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
732
+ if m in {
733
+ Conv, AConv, ConvTranspose,
734
+ Bottleneck, SPP, SPPF, DWConv, BottleneckCSP, nn.ConvTranspose2d, DWConvTranspose2d, SPPCSPC, ADown,
735
+ RepNCSPELAN4, SPPELAN}:
736
+ c1, c2 = ch[f], args[0]
737
+ if c2 != no: # if not output
738
+ c2 = make_divisible(c2 * gw, 8)
739
+
740
+ args = [c1, c2, *args[1:]]
741
+ if m in {BottleneckCSP, SPPCSPC}:
742
+ args.insert(2, n) # number of repeats
743
+ n = 1
744
+ elif m is nn.BatchNorm2d:
745
+ args = [ch[f]]
746
+ elif m is Concat:
747
+ c2 = sum(ch[x] for x in f)
748
+ elif m is Shortcut:
749
+ c2 = ch[f[0]]
750
+ elif m is ReOrg:
751
+ c2 = ch[f] * 4
752
+ elif m is CBLinear:
753
+ c2 = args[0]
754
+ c1 = ch[f]
755
+ args = [c1, c2, *args[1:]]
756
+ elif m is CBFuse:
757
+ c2 = ch[f[-1]]
758
+ # TODO: channel, gw, gd
759
+ elif m in {Detect, DualDetect, TripleDetect, DDetect, DualDDetect, TripleDDetect, Segment, DSegment, DualDSegment, Panoptic}:
760
+ args.append([ch[x] for x in f])
761
+ # if isinstance(args[1], int): # number of anchors
762
+ # args[1] = [list(range(args[1] * 2))] * len(f)
763
+ if m in {Segment, DSegment, DualDSegment, Panoptic}:
764
+ args[2] = make_divisible(args[2] * gw, 8)
765
+ elif m is Contract:
766
+ c2 = ch[f] * args[0] ** 2
767
+ elif m is Expand:
768
+ c2 = ch[f] // args[0] ** 2
769
+ else:
770
+ c2 = ch[f]
771
+
772
+ m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
773
+ t = str(m)[8:-2].replace('__main__.', '') # module type
774
+ np = sum(x.numel() for x in m_.parameters()) # number params
775
+ m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
776
+ LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}') # print
777
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
778
+ layers.append(m_)
779
+ if i == 0:
780
+ ch = []
781
+ ch.append(c2)
782
+ return nn.Sequential(*layers), sorted(save)
783
+
784
+
785
+ if __name__ == '__main__':
786
+ parser = argparse.ArgumentParser()
787
+ parser.add_argument('--cfg', type=str, default='yolo.yaml', help='model.yaml')
788
+ parser.add_argument('--batch-size', type=int, default=1, help='total batch size for all GPUs')
789
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
790
+ parser.add_argument('--profile', action='store_true', help='profile model speed')
791
+ parser.add_argument('--line-profile', action='store_true', help='profile model speed layer by layer')
792
+ parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
793
+ opt = parser.parse_args()
794
+ opt.cfg = check_yaml(opt.cfg) # check YAML
795
+ print_args(vars(opt))
796
+ device = select_device(opt.device)
797
+
798
+ # Create model
799
+ im = torch.rand(opt.batch_size, 3, 640, 640).to(device)
800
+ model = Model(opt.cfg).to(device)
801
+ model.eval()
802
+
803
+ # Options
804
+ if opt.line_profile: # profile layer by layer
805
+ model(im, profile=True)
806
+
807
+ elif opt.profile: # profile forward-backward
808
+ results = profile(input=im, ops=[model], n=3)
809
+
810
+ elif opt.test: # test all models
811
+ for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
812
+ try:
813
+ _ = Model(cfg)
814
+ except Exception as e:
815
+ print(f'Error in {cfg}: {e}')
816
+
817
+ else: # report fused model summary
818
+ model.fuse()