glenn-jocher commited on
Commit
2d41e70
·
unverified ·
1 Parent(s): 38ff499

Scipy kmeans-robust autoanchor update (#2470)

Browse files

Fix for https://github.com/ultralytics/yolov5/issues/2394

Files changed (1) hide show
  1. utils/autoanchor.py +11 -6
utils/autoanchor.py CHANGED
@@ -37,17 +37,21 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
37
  bpr = (best > 1. / thr).float().mean() # best possible recall
38
  return bpr, aat
39
 
40
- bpr, aat = metric(m.anchor_grid.clone().cpu().view(-1, 2))
 
41
  print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
42
  if bpr < 0.98: # threshold to recompute
43
  print('. Attempting to improve anchors, please wait...')
44
  na = m.anchor_grid.numel() // 2 # number of anchors
45
- new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
46
- new_bpr = metric(new_anchors.reshape(-1, 2))[0]
 
 
 
47
  if new_bpr > bpr: # replace anchors
48
- new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
49
- m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
50
- m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
51
  check_anchor_order(m)
52
  print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
53
  else:
@@ -119,6 +123,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
119
  print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')
120
  s = wh.std(0) # sigmas for whitening
121
  k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
 
122
  k *= s
123
  wh = torch.tensor(wh, dtype=torch.float32) # filtered
124
  wh0 = torch.tensor(wh0, dtype=torch.float32) # unfiltered
 
37
  bpr = (best > 1. / thr).float().mean() # best possible recall
38
  return bpr, aat
39
 
40
+ anchors = m.anchor_grid.clone().cpu().view(-1, 2) # current anchors
41
+ bpr, aat = metric(anchors)
42
  print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
43
  if bpr < 0.98: # threshold to recompute
44
  print('. Attempting to improve anchors, please wait...')
45
  na = m.anchor_grid.numel() // 2 # number of anchors
46
+ try:
47
+ anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
48
+ except Exception as e:
49
+ print(f'{prefix}ERROR: {e}')
50
+ new_bpr = metric(anchors)[0]
51
  if new_bpr > bpr: # replace anchors
52
+ anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
53
+ m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid) # for inference
54
+ m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
55
  check_anchor_order(m)
56
  print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
57
  else:
 
123
  print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')
124
  s = wh.std(0) # sigmas for whitening
125
  k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
126
+ assert len(k) == n, print(f'{prefix}ERROR: scipy.cluster.vq.kmeans requested {n} points but returned only {len(k)}')
127
  k *= s
128
  wh = torch.tensor(wh, dtype=torch.float32) # filtered
129
  wh0 = torch.tensor(wh0, dtype=torch.float32) # unfiltered