π [Fix] #38 local mAP bugs, one pd match < 1 gt
Browse files
yolo/utils/bounding_box_utils.py
CHANGED
@@ -412,22 +412,22 @@ def calculate_map(predictions, ground_truths, iou_thresholds=arange(0.5, 1, 0.05
|
|
412 |
ious = calculate_iou(predictions[:, 1:-1], ground_truths[:, 1:]) # [n_preds, n_gts]
|
413 |
|
414 |
for threshold in iou_thresholds:
|
415 |
-
tp = torch.zeros(n_preds, device=device)
|
416 |
-
fp = torch.zeros(n_preds, device=device)
|
417 |
|
418 |
-
max_iou, max_indices =
|
419 |
above_threshold = max_iou >= threshold
|
420 |
matched_classes = predictions[:, 0] == ground_truths[max_indices, 0]
|
421 |
-
|
422 |
-
|
423 |
-
|
|
|
|
|
424 |
|
425 |
_, indices = torch.sort(predictions[:, 1], descending=True)
|
426 |
tp = tp[indices]
|
427 |
-
fp = fp[indices]
|
428 |
|
429 |
tp_cumsum = torch.cumsum(tp, dim=0)
|
430 |
-
fp_cumsum = torch.cumsum(
|
431 |
|
432 |
precision = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-6)
|
433 |
recall = tp_cumsum / (n_gts + 1e-6)
|
@@ -438,9 +438,7 @@ def calculate_map(predictions, ground_truths, iou_thresholds=arange(0.5, 1, 0.05
|
|
438 |
precision, _ = torch.cummax(precision.flip(0), dim=0)
|
439 |
precision = precision.flip(0)
|
440 |
|
441 |
-
|
442 |
-
ap = torch.sum((recall[indices + 1] - recall[indices]) * precision[indices + 1])
|
443 |
-
|
444 |
aps.append(ap)
|
445 |
|
446 |
mAP = {
|
|
|
412 |
ious = calculate_iou(predictions[:, 1:-1], ground_truths[:, 1:]) # [n_preds, n_gts]
|
413 |
|
414 |
for threshold in iou_thresholds:
|
415 |
+
tp = torch.zeros(n_preds, device=device, dtype=bool)
|
|
|
416 |
|
417 |
+
max_iou, max_indices = ious.max(dim=1)
|
418 |
above_threshold = max_iou >= threshold
|
419 |
matched_classes = predictions[:, 0] == ground_truths[max_indices, 0]
|
420 |
+
max_match = torch.zeros_like(ious)
|
421 |
+
max_match[arange(n_preds), max_indices] = max_iou
|
422 |
+
if max_match.size(0):
|
423 |
+
tp[max_match.argmax(dim=0)] = True
|
424 |
+
tp[~above_threshold | ~matched_classes] = False
|
425 |
|
426 |
_, indices = torch.sort(predictions[:, 1], descending=True)
|
427 |
tp = tp[indices]
|
|
|
428 |
|
429 |
tp_cumsum = torch.cumsum(tp, dim=0)
|
430 |
+
fp_cumsum = torch.cumsum(~tp, dim=0)
|
431 |
|
432 |
precision = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-6)
|
433 |
recall = tp_cumsum / (n_gts + 1e-6)
|
|
|
438 |
precision, _ = torch.cummax(precision.flip(0), dim=0)
|
439 |
precision = precision.flip(0)
|
440 |
|
441 |
+
ap = torch.trapezoid(precision, recall)
|
|
|
|
|
442 |
aps.append(ap)
|
443 |
|
444 |
mAP = {
|