glenn-jocher commited on
Commit
57a0ae3
·
1 Parent(s): 95c46f7

AutoAnchor implementation

Browse files
Files changed (1) hide show
  1. utils/utils.py +45 -50
utils/utils.py CHANGED
@@ -53,18 +53,23 @@ def check_img_size(img_size, s=32):
53
 
54
 
55
  def check_anchors(dataset, model, thr=4.0, imgsz=640):
56
- # Check best possible recall of dataset with current anchors
 
57
  anchors = model.module.model[-1].anchor_grid if hasattr(model, 'module') else model.model[-1].anchor_grid
58
  shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
59
  wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
60
  ratio = wh[:, None] / anchors.view(-1, 2).cpu()[None] # ratio
61
  m = torch.max(ratio, 1. / ratio).max(2)[0] # max ratio
62
  bpr = (m.min(1)[0] < thr).float().mean() # best possible recall
63
- mr = (m < thr).float().mean() # match ratio
64
- print(('AutoAnchor labels:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
65
- print((' ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr))
66
- assert bpr > 0.9, 'Best possible recall %.3g (BPR) below 0.9 threshold. Training cancelled. ' \
67
- 'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr
 
 
 
 
68
 
69
 
70
  def check_file(file):
@@ -689,14 +694,14 @@ def coco_single_class_labels(path='../coco/labels/train2014/', label_class=43):
689
  shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg')) # copy images
690
 
691
 
692
- def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=(640, 640), thr=0.20, gen=1000):
693
  """ Creates kmeans-evolved anchors from training dataset
694
 
695
  Arguments:
696
- path: path to dataset *.yaml
697
  n: number of anchors
698
- img_size: (min, max) image size used for multi-scale training (can be same values)
699
- thr: IoU threshold hyperparameter used for training (0.0 - 1.0)
700
  gen: generations to evolve anchors using genetic algorithm
701
 
702
  Return:
@@ -705,52 +710,41 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=(640, 640), thr=0.20
705
  Usage:
706
  from utils.utils import *; _ = kmean_anchors()
707
  """
 
 
 
 
 
 
 
708
 
709
- from utils.datasets import LoadImagesAndLabels
 
 
710
 
711
  def print_results(k):
712
  k = k[np.argsort(k.prod(1))] # sort small to large
713
- iou = wh_iou(wh, torch.Tensor(k))
714
- max_iou = iou.max(1)[0]
715
- bpr, aat = (max_iou > thr).float().mean(), (iou > thr).float().mean() * n # best possible recall, anch > thr
716
-
717
- # thr = 5.0
718
- # r = wh[:, None] / k[None]
719
- # ar = torch.max(r, 1. / r).max(2)[0]
720
- # max_ar = ar.min(1)[0]
721
- # bpr, aat = (max_ar < thr).float().mean(), (ar < thr).float().mean() * n # best possible recall, anch > thr
722
-
723
- print('%.2f iou_thr: %.3f best possible recall, %.2f anchors > thr' % (thr, bpr, aat))
724
- print('n=%g, img_size=%s, IoU_all=%.3f/%.3f-mean/best, IoU>thr=%.3f-mean: ' %
725
- (n, img_size, iou.mean(), max_iou.mean(), iou[iou > thr].mean()), end='')
726
  for i, x in enumerate(k):
727
  print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg
728
  return k
729
 
730
- def fitness(k): # mutation fitness
731
- iou = wh_iou(wh, torch.Tensor(k)) # iou
732
- max_iou = iou.max(1)[0]
733
- return (max_iou * (max_iou > thr).float()).mean() # product
734
-
735
- # def fitness_ratio(k): # mutation fitness
736
- # # wh(5316,2), k(9,2)
737
- # r = wh[:, None] / k[None]
738
- # x = torch.max(r, 1. / r).max(2)[0]
739
- # m = x.min(1)[0]
740
- # return 1. / (m * (m < 5).float()).mean() # product
741
 
742
  # Get label wh
743
- wh = []
744
- with open(path) as f:
745
- data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
746
- dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
747
- nr = 1 if img_size[0] == img_size[1] else 3 # number augmentation repetitions
748
- for s, l in zip(dataset.shapes, dataset.labels):
749
- # wh.append(l[:, 3:5] * (s / s.max())) # image normalized to letterbox normalized wh
750
- wh.append(l[:, 3:5] * s) # image normalized to pixels
751
- wh = np.concatenate(wh, 0).repeat(nr, axis=0) # augment 3x
752
- # wh *= np.random.uniform(img_size[0], img_size[1], size=(wh.shape[0], 1)) # normalized to pixels (multi-scale)
753
- wh = wh[(wh > 2.0).all(1)] # remove below threshold boxes (< 2 pixels wh)
754
 
755
  # Kmeans calculation
756
  from scipy.cluster.vq import kmeans
@@ -758,10 +752,10 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=(640, 640), thr=0.20
758
  s = wh.std(0) # sigmas for whitening
759
  k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
760
  k *= s
761
- wh = torch.Tensor(wh)
762
  k = print_results(k)
763
 
764
- # # Plot
765
  # k, d = [None] * 20, [None] * 20
766
  # for i in tqdm(range(1, 21)):
767
  # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
@@ -777,7 +771,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=(640, 640), thr=0.20
777
  # Evolve
778
  npr = np.random
779
  f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
780
- for _ in tqdm(range(gen), desc='Evolving anchors'):
781
  v = np.ones(sh)
782
  while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
783
  v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
@@ -785,7 +779,8 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=(640, 640), thr=0.20
785
  fg = fitness(kg)
786
  if fg > f:
787
  f, k = fg, kg.copy()
788
- print_results(k)
 
789
  k = print_results(k)
790
  return k
791
 
 
53
 
54
 
55
  def check_anchors(dataset, model, thr=4.0, imgsz=640):
56
+ # Check anchor fit to data, recompute if necessary
57
+ print('\nAnalyzing anchors... ', end='')
58
  anchors = model.module.model[-1].anchor_grid if hasattr(model, 'module') else model.model[-1].anchor_grid
59
  shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
60
  wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
61
  ratio = wh[:, None] / anchors.view(-1, 2).cpu()[None] # ratio
62
  m = torch.max(ratio, 1. / ratio).max(2)[0] # max ratio
63
  bpr = (m.min(1)[0] < thr).float().mean() # best possible recall
64
+ # mr = (m < thr).float().mean() # match ratio
65
+
66
+ print('Best Possible Recall (BPR) = %.3f' % bpr, end='')
67
+ if bpr < 0.99: # threshold to recompute
68
+ print('. Generating new anchors for improved recall, please wait...' % bpr)
69
+ new_anchors = kmean_anchors(dataset, n=9, img_size=640, thr=4.0, gen=1000, verbose=False)
70
+ anchors[:] = torch.tensor(new_anchors).view_as(anchors).type_as(anchors)
71
+ print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
72
+ print('') # newline
73
 
74
 
75
  def check_file(file):
 
694
  shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg')) # copy images
695
 
696
 
697
+ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
698
  """ Creates kmeans-evolved anchors from training dataset
699
 
700
  Arguments:
701
+ path: path to dataset *.yaml, or a loaded dataset
702
  n: number of anchors
703
+ img_size: image size used for training
704
+ thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
705
  gen: generations to evolve anchors using genetic algorithm
706
 
707
  Return:
 
710
  Usage:
711
  from utils.utils import *; _ = kmean_anchors()
712
  """
713
+ thr = 1. / thr
714
+
715
+ def metric(k): # compute metrics
716
+ r = wh[:, None] / k[None]
717
+ x = torch.min(r, 1. / r).min(2)[0] # ratio metric
718
+ # x = wh_iou(wh, torch.tensor(k)) # iou metric
719
+ return x, x.max(1)[0] # x, best_x
720
 
721
+ def fitness(k): # mutation fitness
722
+ _, best = metric(k)
723
+ return (best * (best > thr).float()).mean() # fitness
724
 
725
  def print_results(k):
726
  k = k[np.argsort(k.prod(1))] # sort small to large
727
+ x, best = metric(k)
728
+ bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
729
+ print('thr=%.2f: %.3f best possible recall, %.2f anchors past thr' % (thr, bpr, aat))
730
+ print('n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thr=%.3f-mean: ' %
731
+ (n, img_size, x.mean(), best.mean(), x[x > thr].mean()), end='')
 
 
 
 
 
 
 
 
732
  for i, x in enumerate(k):
733
  print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg
734
  return k
735
 
736
+ if isinstance(path, str): # *.yaml file
737
+ with open(path) as f:
738
+ data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
739
+ from utils.datasets import LoadImagesAndLabels
740
+ dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
741
+ else:
742
+ dataset = path # dataset
 
 
 
 
743
 
744
  # Get label wh
745
+ shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
746
+ wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
747
+ wh = wh[(wh > 2.0).all(1)].numpy() # filter > 2 pixels
 
 
 
 
 
 
 
 
748
 
749
  # Kmeans calculation
750
  from scipy.cluster.vq import kmeans
 
752
  s = wh.std(0) # sigmas for whitening
753
  k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
754
  k *= s
755
+ wh = torch.tensor(wh)
756
  k = print_results(k)
757
 
758
+ # Plot
759
  # k, d = [None] * 20, [None] * 20
760
  # for i in tqdm(range(1, 21)):
761
  # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
 
771
  # Evolve
772
  npr = np.random
773
  f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
774
+ for _ in tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm:'):
775
  v = np.ones(sh)
776
  while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
777
  v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
 
779
  fg = fitness(kg)
780
  if fg > f:
781
  f, k = fg, kg.copy()
782
+ if verbose:
783
+ print_results(k)
784
  k = print_results(k)
785
  return k
786