Commit
·
fc171e2
1
Parent(s):
1f1917e
check_anchor_order() update
Browse files- models/yolo.py +2 -1
- utils/utils.py +13 -1
models/yolo.py
CHANGED
@@ -61,8 +61,9 @@ class Model(nn.Module):
|
|
61 |
|
62 |
# Build strides, anchors
|
63 |
m = self.model[-1] # Detect()
|
64 |
-
m.stride = torch.tensor([
|
65 |
m.anchors /= m.stride.view(-1, 1, 1)
|
|
|
66 |
self.stride = m.stride
|
67 |
|
68 |
# Init weights, biases
|
|
|
61 |
|
62 |
# Build strides, anchors
|
63 |
m = self.model[-1] # Detect()
|
64 |
+
m.stride = torch.tensor([128 / x.shape[-2] for x in self.forward(torch.zeros(1, ch, 128, 128))]) # forward
|
65 |
m.anchors /= m.stride.view(-1, 1, 1)
|
66 |
+
check_anchor_order(m)
|
67 |
self.stride = m.stride
|
68 |
|
69 |
# Init weights, biases
|
utils/utils.py
CHANGED
@@ -58,7 +58,8 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
|
58 |
print('\nAnalyzing anchors... ', end='')
|
59 |
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
|
60 |
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
|
61 |
-
|
|
|
62 |
|
63 |
def metric(k): # compute metric
|
64 |
r = wh[:, None] / k[None]
|
@@ -77,12 +78,23 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
|
77 |
new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
|
78 |
m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
|
79 |
m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
|
|
|
80 |
print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
|
81 |
else:
|
82 |
print('Original anchors better than new anchors. Proceeding with original anchors.')
|
83 |
print('') # newline
|
84 |
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
def check_file(file):
|
87 |
# Searches for file if not found locally
|
88 |
if os.path.isfile(file):
|
|
|
58 |
print('\nAnalyzing anchors... ', end='')
|
59 |
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
|
60 |
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
|
61 |
+
scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
|
62 |
+
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh
|
63 |
|
64 |
def metric(k): # compute metric
|
65 |
r = wh[:, None] / k[None]
|
|
|
78 |
new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
|
79 |
m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
|
80 |
m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
|
81 |
+
check_anchor_order(m)
|
82 |
print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
|
83 |
else:
|
84 |
print('Original anchors better than new anchors. Proceeding with original anchors.')
|
85 |
print('') # newline
|
86 |
|
87 |
|
88 |
+
def check_anchor_order(m):
|
89 |
+
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
|
90 |
+
a = m.anchor_grid.prod(-1).view(-1) # anchor area
|
91 |
+
da = a[-1] - a[0] # delta a
|
92 |
+
ds = m.stride[-1] - m.stride[0] # delta s
|
93 |
+
if da.sign() != ds.sign(): # same order
|
94 |
+
m.anchors[:] = m.anchors.flip(0)
|
95 |
+
m.anchor_grid[:] = m.anchor_grid.flip(0)
|
96 |
+
|
97 |
+
|
98 |
def check_file(file):
|
99 |
# Searches for file if not found locally
|
100 |
if os.path.isfile(file):
|