glenn-jocher commited on
Commit
ec7a926
·
1 Parent(s): d989bc9

AutoAnchor update to display anchors/target

Browse files
Files changed (1) hide show
  1. utils/utils.py +7 -5
utils/utils.py CHANGED
@@ -84,15 +84,17 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
84
  r = wh[:, None] / k[None]
85
  x = torch.min(r, 1. / r).min(2)[0] # ratio metric
86
  best = x.max(1)[0] # best_x
87
- return (best > 1. / thr).float().mean() # best possible recall
 
 
88
 
89
- bpr = metric(m.anchor_grid.clone().cpu().view(-1, 2))
90
- print('Best Possible Recall (BPR) = %.4f' % bpr, end='')
91
- if bpr < 0.99: # threshold to recompute
92
  print('. Attempting to generate improved anchors, please wait...' % bpr)
93
  na = m.anchor_grid.numel() // 2 # number of anchors
94
  new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
95
- new_bpr = metric(new_anchors.reshape(-1, 2))
96
  if new_bpr > bpr: # replace anchors
97
  new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
98
  m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
 
84
  r = wh[:, None] / k[None]
85
  x = torch.min(r, 1. / r).min(2)[0] # ratio metric
86
  best = x.max(1)[0] # best_x
87
+ aat = (x > 1. / thr).float().sum(1).mean() # anchors above threshold
88
+ bpr = (best > 1. / thr).float().mean() # best possible recall
89
+ return bpr, aat
90
 
91
+ bpr, aat = metric(m.anchor_grid.clone().cpu().view(-1, 2))
92
+ print('anchors/target = %.2f, Best Possible Recall (BPR) = %.4f' % (aat, bpr), end='')
93
+ if bpr < 0.98: # threshold to recompute
94
  print('. Attempting to generate improved anchors, please wait...' % bpr)
95
  na = m.anchor_grid.numel() // 2 # number of anchors
96
  new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
97
+ new_bpr = metric(new_anchors.reshape(-1, 2))[0]
98
  if new_bpr > bpr: # replace anchors
99
  new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
100
  m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference