glenn-jocher commited on
Commit
53bfcbe
·
unverified ·
1 Parent(s): cd540d8

Update AP calculation (#4260)

Browse files

* Update AP calculation

* Cleanup

* Remove original

Files changed (1) hide show
  1. val.py +21 -20
val.py CHANGED
@@ -50,26 +50,27 @@ def save_one_json(predn, jdict, path, class_map):
50
  'score': round(p[4], 5)})
51
 
52
 
53
- def process_batch(predictions, labels, iouv):
54
- # Evaluate 1 batch of predictions
55
- correct = torch.zeros(predictions.shape[0], len(iouv), dtype=torch.bool, device=iouv.device)
56
- detected = [] # label indices
57
- tcls, pcls = labels[:, 0], predictions[:, 5]
58
- nl = labels.shape[0] # number of labels
59
- for cls in torch.unique(tcls):
60
- ti = (cls == tcls).nonzero().view(-1) # label indices
61
- pi = (cls == pcls).nonzero().view(-1) # prediction indices
62
- if pi.shape[0]: # find detections
63
- ious, i = box_iou(predictions[pi, 0:4], labels[ti, 1:5]).max(1) # best ious, indices
64
- detected_set = set()
65
- for j in (ious > iouv[0]).nonzero():
66
- d = ti[i[j]] # detected label
67
- if d.item() not in detected_set:
68
- detected_set.add(d.item())
69
- detected.append(d) # append detections
70
- correct[pi[j]] = ious[j] > iouv # iou_thres is 1xn
71
- if len(detected) == nl: # all labels already located in image
72
- break
 
73
  return correct
74
 
75
 
 
50
  'score': round(p[4], 5)})
51
 
52
 
53
+ def process_batch(detections, labels, iouv):
54
+ """
55
+ Return correct predictions matrix. Both sets of boxes are in (x1, y1, x2, y2) format.
56
+ Arguments:
57
+ detections (Array[N, 6]), x1, y1, x2, y2, conf, class
58
+ labels (Array[M, 5]), class, x1, y1, x2, y2
59
+ Returns:
60
+ correct (Array[N, 10]), for 10 IoU levels
61
+ """
62
+ correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
63
+ iou = box_iou(labels[:, 1:], detections[:, :4])
64
+ x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5])) # IoU above threshold and classes match
65
+ if x[0].shape[0]:
66
+ matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detection, iou]
67
+ if x[0].shape[0] > 1:
68
+ matches = matches[matches[:, 2].argsort()[::-1]]
69
+ matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
70
+ # matches = matches[matches[:, 2].argsort()[::-1]]
71
+ matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
72
+ matches = torch.Tensor(matches).to(iouv.device)
73
+ correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
74
  return correct
75
 
76