Wayne Mao
commited on
Commit
·
de8701c
1
Parent(s):
665328e
chore(model) optimize dynamic_k_matching with int (#861)
Browse files
yolox/models/yolo_head.py
CHANGED
@@ -607,26 +607,27 @@ class YOLOXHead(nn.Module):
|
|
607 |
def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
|
608 |
# Dynamic K
|
609 |
# ---------------------------------------------------------------
|
610 |
-
matching_matrix = torch.zeros_like(cost)
|
611 |
|
612 |
ious_in_boxes_matrix = pair_wise_ious
|
613 |
n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
|
614 |
topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
|
615 |
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
|
|
|
616 |
for gt_idx in range(num_gt):
|
617 |
_, pos_idx = torch.topk(
|
618 |
-
cost[gt_idx], k=dynamic_ks[gt_idx]
|
619 |
)
|
620 |
-
matching_matrix[gt_idx][pos_idx] = 1
|
621 |
|
622 |
del topk_ious, dynamic_ks, pos_idx
|
623 |
|
624 |
anchor_matching_gt = matching_matrix.sum(0)
|
625 |
if (anchor_matching_gt > 1).sum() > 0:
|
626 |
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
|
627 |
-
matching_matrix[:, anchor_matching_gt > 1] *= 0
|
628 |
-
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
|
629 |
-
fg_mask_inboxes = matching_matrix.sum(0) > 0
|
630 |
num_fg = fg_mask_inboxes.sum().item()
|
631 |
|
632 |
fg_mask[fg_mask.clone()] = fg_mask_inboxes
|
|
|
607 |
def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
|
608 |
# Dynamic K
|
609 |
# ---------------------------------------------------------------
|
610 |
+
matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
|
611 |
|
612 |
ious_in_boxes_matrix = pair_wise_ious
|
613 |
n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
|
614 |
topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
|
615 |
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
|
616 |
+
dynamic_ks = dynamic_ks.tolist()
|
617 |
for gt_idx in range(num_gt):
|
618 |
_, pos_idx = torch.topk(
|
619 |
+
cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
|
620 |
)
|
621 |
+
matching_matrix[gt_idx][pos_idx] = 1
|
622 |
|
623 |
del topk_ious, dynamic_ks, pos_idx
|
624 |
|
625 |
anchor_matching_gt = matching_matrix.sum(0)
|
626 |
if (anchor_matching_gt > 1).sum() > 0:
|
627 |
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
|
628 |
+
matching_matrix[:, anchor_matching_gt > 1] *= 0
|
629 |
+
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
|
630 |
+
fg_mask_inboxes = matching_matrix.sum(0) > 0
|
631 |
num_fg = fg_mask_inboxes.sum().item()
|
632 |
|
633 |
fg_mask[fg_mask.clone()] = fg_mask_inboxes
|