henry000 commited on
Commit
aee6ac8
Β·
1 Parent(s): e0c8580

πŸ› [Fix] #38 local mAP bugs, one pd match < 1 gt

Browse files
Files changed (1) hide show
  1. yolo/utils/bounding_box_utils.py +9 -11
yolo/utils/bounding_box_utils.py CHANGED
@@ -412,22 +412,22 @@ def calculate_map(predictions, ground_truths, iou_thresholds=arange(0.5, 1, 0.05
412
  ious = calculate_iou(predictions[:, 1:-1], ground_truths[:, 1:]) # [n_preds, n_gts]
413
 
414
  for threshold in iou_thresholds:
415
- tp = torch.zeros(n_preds, device=device)
416
- fp = torch.zeros(n_preds, device=device)
417
 
418
- max_iou, max_indices = torch.max(ious, dim=1)
419
  above_threshold = max_iou >= threshold
420
  matched_classes = predictions[:, 0] == ground_truths[max_indices, 0]
421
- tp[above_threshold & matched_classes] = 1
422
- fp[above_threshold & ~matched_classes] = 1
423
- fp[max_iou < threshold] = 1
 
 
424
 
425
  _, indices = torch.sort(predictions[:, 1], descending=True)
426
  tp = tp[indices]
427
- fp = fp[indices]
428
 
429
  tp_cumsum = torch.cumsum(tp, dim=0)
430
- fp_cumsum = torch.cumsum(fp, dim=0)
431
 
432
  precision = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-6)
433
  recall = tp_cumsum / (n_gts + 1e-6)
@@ -438,9 +438,7 @@ def calculate_map(predictions, ground_truths, iou_thresholds=arange(0.5, 1, 0.05
438
  precision, _ = torch.cummax(precision.flip(0), dim=0)
439
  precision = precision.flip(0)
440
 
441
- indices = (recall[1:] != recall[:-1]).nonzero(as_tuple=True)[0]
442
- ap = torch.sum((recall[indices + 1] - recall[indices]) * precision[indices + 1])
443
-
444
  aps.append(ap)
445
 
446
  mAP = {
 
412
  ious = calculate_iou(predictions[:, 1:-1], ground_truths[:, 1:]) # [n_preds, n_gts]
413
 
414
  for threshold in iou_thresholds:
415
+ tp = torch.zeros(n_preds, device=device, dtype=bool)
 
416
 
417
+ max_iou, max_indices = ious.max(dim=1)
418
  above_threshold = max_iou >= threshold
419
  matched_classes = predictions[:, 0] == ground_truths[max_indices, 0]
420
+ max_match = torch.zeros_like(ious)
421
+ max_match[arange(n_preds), max_indices] = max_iou
422
+ if max_match.size(0):
423
+ tp[max_match.argmax(dim=0)] = True
424
+ tp[~above_threshold | ~matched_classes] = False
425
 
426
  _, indices = torch.sort(predictions[:, 1], descending=True)
427
  tp = tp[indices]
 
428
 
429
  tp_cumsum = torch.cumsum(tp, dim=0)
430
+ fp_cumsum = torch.cumsum(~tp, dim=0)
431
 
432
  precision = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-6)
433
  recall = tp_cumsum / (n_gts + 1e-6)
 
438
  precision, _ = torch.cummax(precision.flip(0), dim=0)
439
  precision = precision.flip(0)
440
 
441
+ ap = torch.trapezoid(precision, recall)
 
 
442
  aps.append(ap)
443
 
444
  mAP = {