Refactor `Detect()` anchors for ONNX <> OpenCV DNN compatibility (#4833)
Browse files
* refactor anchors and anchor_grid in Detect Layer
* fix CI failures by adding compatibility
* fix tf failure
* fix different devices errors
* Cleanup
* fix anchors overwriting issue
* better refactoring
* Remove self.anchor_grid shape check (redundant with self.grid check)
Also PEP8 / 120 line width
* Convert _make_grid() from static to dynamic method
* Remove anchor_grid.to(device)
clone() should already clone to same device as self.anchors
* fix different devices error
Co-authored-by: Glenn Jocher <[email protected]>
- models/common.py +2 -0
- models/experimental.py +4 -0
- models/tf.py +1 -1
- models/yolo.py +13 -9
- utils/autoanchor.py +4 -6
models/common.py
CHANGED
@@ -295,6 +295,8 @@ class AutoShape(nn.Module):
|
|
295 |
m = self.model.model[-1] # Detect()
|
296 |
m.stride = fn(m.stride)
|
297 |
m.grid = list(map(fn, m.grid))
|
|
|
|
|
298 |
return self
|
299 |
|
300 |
@torch.no_grad()
|
|
|
295 |
m = self.model.model[-1] # Detect()
|
296 |
m.stride = fn(m.stride)
|
297 |
m.grid = list(map(fn, m.grid))
|
298 |
+
if isinstance(m.anchor_grid, list):
|
299 |
+
m.anchor_grid = list(map(fn, m.anchor_grid))
|
300 |
return self
|
301 |
|
302 |
@torch.no_grad()
|
models/experimental.py
CHANGED
@@ -102,6 +102,10 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
|
|
102 |
for m in model.modules():
|
103 |
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
|
104 |
m.inplace = inplace # pytorch 1.7.0 compatibility
|
|
|
|
|
|
|
|
|
105 |
elif type(m) is Conv:
|
106 |
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
|
107 |
|
|
|
102 |
for m in model.modules():
|
103 |
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
|
104 |
m.inplace = inplace # pytorch 1.7.0 compatibility
|
105 |
+
if type(m) is Detect:
|
106 |
+
if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
|
107 |
+
delattr(m, 'anchor_grid')
|
108 |
+
setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
|
109 |
elif type(m) is Conv:
|
110 |
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
|
111 |
|
models/tf.py
CHANGED
@@ -193,7 +193,7 @@ class TFDetect(keras.layers.Layer):
|
|
193 |
self.na = len(anchors[0]) // 2 # number of anchors
|
194 |
self.grid = [tf.zeros(1)] * self.nl # init grid
|
195 |
self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
|
196 |
-
self.anchor_grid = tf.reshape(tf.convert_to_tensor(w.anchor_grid.numpy(), dtype=tf.float32),
|
197 |
[self.nl, 1, -1, 1, 2])
|
198 |
self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
|
199 |
self.training = False # set to False after building model
|
|
|
193 |
self.na = len(anchors[0]) // 2 # number of anchors
|
194 |
self.grid = [tf.zeros(1)] * self.nl # init grid
|
195 |
self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
|
196 |
+
self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]),
|
197 |
[self.nl, 1, -1, 1, 2])
|
198 |
self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
|
199 |
self.training = False # set to False after building model
|
models/yolo.py
CHANGED
@@ -44,9 +44,8 @@ class Detect(nn.Module):
|
|
44 |
self.nl = len(anchors) # number of detection layers
|
45 |
self.na = len(anchors[0]) // 2 # number of anchors
|
46 |
self.grid = [torch.zeros(1)] * self.nl # init grid
|
47 |
-
a = torch.tensor(anchors).float().view(self.nl, -1, 2)
|
48 |
-
self.register_buffer('anchors', a) # shape(nl,na,2)
|
49 |
-
self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
|
50 |
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
|
51 |
self.inplace = inplace # use in-place ops (e.g. slice assignment)
|
52 |
|
@@ -59,7 +58,7 @@ class Detect(nn.Module):
|
|
59 |
|
60 |
if not self.training: # inference
|
61 |
if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
|
62 |
-
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
|
63 |
|
64 |
y = x[i].sigmoid()
|
65 |
if self.inplace:
|
@@ -67,16 +66,19 @@ class Detect(nn.Module):
|
|
67 |
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
68 |
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
|
69 |
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
|
70 |
-
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
|
71 |
y = torch.cat((xy, wh, y[..., 4:]), -1)
|
72 |
z.append(y.view(bs, -1, self.no))
|
73 |
|
74 |
return x if self.training else (torch.cat(z, 1), x)
|
75 |
|
76 |
-
|
77 |
-
|
78 |
-
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
|
79 |
-
|
|
|
|
|
|
|
80 |
|
81 |
|
82 |
class Model(nn.Module):
|
@@ -239,6 +241,8 @@ class Model(nn.Module):
|
|
239 |
if isinstance(m, Detect):
|
240 |
m.stride = fn(m.stride)
|
241 |
m.grid = list(map(fn, m.grid))
|
|
|
|
|
242 |
return self
|
243 |
|
244 |
|
|
|
44 |
self.nl = len(anchors) # number of detection layers
|
45 |
self.na = len(anchors[0]) // 2 # number of anchors
|
46 |
self.grid = [torch.zeros(1)] * self.nl # init grid
|
47 |
+
self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
|
48 |
+
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
|
|
|
49 |
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
|
50 |
self.inplace = inplace # use in-place ops (e.g. slice assignment)
|
51 |
|
|
|
58 |
|
59 |
if not self.training: # inference
|
60 |
if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
|
61 |
+
self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
|
62 |
|
63 |
y = x[i].sigmoid()
|
64 |
if self.inplace:
|
|
|
66 |
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
67 |
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
|
68 |
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
|
69 |
+
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
70 |
y = torch.cat((xy, wh, y[..., 4:]), -1)
|
71 |
z.append(y.view(bs, -1, self.no))
|
72 |
|
73 |
return x if self.training else (torch.cat(z, 1), x)
|
74 |
|
75 |
+
def _make_grid(self, nx=20, ny=20, i=0):
|
76 |
+
d = self.anchors[i].device
|
77 |
+
yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
|
78 |
+
grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()
|
79 |
+
anchor_grid = (self.anchors[i].clone() * self.stride[i]) \
|
80 |
+
.view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()
|
81 |
+
return grid, anchor_grid
|
82 |
|
83 |
|
84 |
class Model(nn.Module):
|
|
|
241 |
if isinstance(m, Detect):
|
242 |
m.stride = fn(m.stride)
|
243 |
m.grid = list(map(fn, m.grid))
|
244 |
+
if isinstance(m.anchor_grid, list):
|
245 |
+
m.anchor_grid = list(map(fn, m.anchor_grid))
|
246 |
return self
|
247 |
|
248 |
|
utils/autoanchor.py
CHANGED
@@ -15,13 +15,12 @@ from utils.general import colorstr
|
|
15 |
|
16 |
def check_anchor_order(m):
|
17 |
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
|
18 |
-
a = m.anchor_grid.prod(-1).view(-1) # anchor area
|
19 |
da = a[-1] - a[0] # delta a
|
20 |
ds = m.stride[-1] - m.stride[0] # delta s
|
21 |
if da.sign() != ds.sign(): # same order
|
22 |
print('Reversing anchor order')
|
23 |
m.anchors[:] = m.anchors.flip(0)
|
24 |
-
m.anchor_grid[:] = m.anchor_grid.flip(0)
|
25 |
|
26 |
|
27 |
def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
@@ -41,12 +40,12 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
|
41 |
bpr = (best > 1. / thr).float().mean() # best possible recall
|
42 |
return bpr, aat
|
43 |
|
44 |
-
anchors = m.anchor_grid.clone().cpu().view(-1, 2) # current anchors
|
45 |
-
bpr, aat = metric(anchors)
|
46 |
print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
|
47 |
if bpr < 0.98: # threshold to recompute
|
48 |
print('. Attempting to improve anchors, please wait...')
|
49 |
-
na = m.anchor_grid.numel() // 2 # number of anchors
|
50 |
try:
|
51 |
anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
|
52 |
except Exception as e:
|
@@ -54,7 +53,6 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
|
54 |
new_bpr = metric(anchors)[0]
|
55 |
if new_bpr > bpr: # replace anchors
|
56 |
anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
|
57 |
-
m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid) # for inference
|
58 |
m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
|
59 |
check_anchor_order(m)
|
60 |
print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
|
|
|
15 |
|
16 |
def check_anchor_order(m):
|
17 |
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
|
18 |
+
a = m.anchors.prod(-1).view(-1) # anchor area
|
19 |
da = a[-1] - a[0] # delta a
|
20 |
ds = m.stride[-1] - m.stride[0] # delta s
|
21 |
if da.sign() != ds.sign(): # same order
|
22 |
print('Reversing anchor order')
|
23 |
m.anchors[:] = m.anchors.flip(0)
|
|
|
24 |
|
25 |
|
26 |
def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
|
|
40 |
bpr = (best > 1. / thr).float().mean() # best possible recall
|
41 |
return bpr, aat
|
42 |
|
43 |
+
anchors = m.anchors.clone() * m.stride.to(m.anchors.device).view(-1, 1, 1) # current anchors
|
44 |
+
bpr, aat = metric(anchors.cpu().view(-1, 2))
|
45 |
print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
|
46 |
if bpr < 0.98: # threshold to recompute
|
47 |
print('. Attempting to improve anchors, please wait...')
|
48 |
+
na = m.anchors.numel() // 2 # number of anchors
|
49 |
try:
|
50 |
anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
|
51 |
except Exception as e:
|
|
|
53 |
new_bpr = metric(anchors)[0]
|
54 |
if new_bpr > bpr: # replace anchors
|
55 |
anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
|
|
|
56 |
m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
|
57 |
check_anchor_order(m)
|
58 |
print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
|