Ge Zheng
committed
Commit · 9a33b8a · 1 Parent(s): 0588424
fix(models): no-candidate anchor issue for tiny objects during label assign (#1589)
yolox/models/yolo_head.py CHANGED (+43 -109)
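In short: the constant `self.n_anchors` (always 1 in this anchor-free head) is removed everywhere, the candidate selection `get_in_boxes_info` is replaced by a simpler `get_geometry_constraint`, and `dynamic_k_matching` is renamed to `simota_matching`. The old selection combined an "anchor center inside the GT box" test with a center-radius test, and for tiny objects this could leave a ground truth with no candidate anchors during label assignment; the new constraint keeps only the fixed-radius center test (radius 2.5 → 1.5 strides), which depends on the distance to the object center rather than on the box size.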
@@ -32,7 +32,6 @@ class YOLOXHead(nn.Module):
         """
         super().__init__()
 
-        self.n_anchors = 1
         self.num_classes = num_classes
         self.decode_in_inference = True  # for deploy, set to False
 
@@ -97,7 +96,7 @@ class YOLOXHead(nn.Module):
             self.cls_preds.append(
                 nn.Conv2d(
                     in_channels=int(256 * width),
-                    out_channels=self.n_anchors * self.num_classes,
+                    out_channels=self.num_classes,
                     kernel_size=1,
                     stride=1,
                     padding=0,
@@ -115,7 +114,7 @@ class YOLOXHead(nn.Module):
             self.obj_preds.append(
                 nn.Conv2d(
                     in_channels=int(256 * width),
-                    out_channels=self.n_anchors * 1,
+                    out_channels=1,
                     kernel_size=1,
                     stride=1,
                     padding=0,
@@ -131,12 +130,12 @@ class YOLOXHead(nn.Module):
 
     def initialize_biases(self, prior_prob):
         for conv in self.cls_preds:
-            b = conv.bias.view(self.n_anchors, -1)
+            b = conv.bias.view(1, -1)
             b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
             conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
 
         for conv in self.obj_preds:
-            b = conv.bias.view(self.n_anchors, -1)
+            b = conv.bias.view(1, -1)
             b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
             conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
 
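For reference, this bias initialization is the standard class-prior trick: after filling, `sigmoid(bias)` equals `prior_prob`, so the untrained head predicts the rare-positive prior instead of 0.5. A minimal standalone check (plain PyTorch; the 80-class head size is only an example):

```python
import math

import torch
import torch.nn as nn

prior_prob = 0.01  # assume ~1% of locations are positive at init
conv = nn.Conv2d(256, 80, kernel_size=1)  # e.g. an 80-class cls head

b = conv.bias.view(1, -1)
b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

print(conv.bias[0].item())                 # about -4.595
print(torch.sigmoid(conv.bias[0]).item())  # about 0.01 == prior_prob
```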
@@ -177,7 +176,7 @@ class YOLOXHead(nn.Module):
                 batch_size = reg_output.shape[0]
                 hsize, wsize = reg_output.shape[-2:]
                 reg_output = reg_output.view(
-                    batch_size, self.n_anchors, 4, hsize, wsize
+                    batch_size, 1, 4, hsize, wsize
                 )
                 reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape(
                     batch_size, -1, 4
@@ -224,9 +223,9 @@ class YOLOXHead(nn.Module):
             grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
             self.grids[k] = grid
 
-        output = output.view(batch_size, self.n_anchors, n_ch, hsize, wsize)
+        output = output.view(batch_size, 1, n_ch, hsize, wsize)
         output = output.permute(0, 1, 3, 4, 2).reshape(
-            batch_size, self.n_anchors * hsize * wsize, -1
+            batch_size, hsize * wsize, -1
         )
         grid = grid.view(1, -1, 2)
         output[..., :2] = (output[..., :2] + grid) * stride
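The decode above now hard-codes the anchor dimension to 1. A shape sketch of the same reshaping and grid offsetting outside the class (toy sizes; `n_ch = num_classes + 5`):

```python
import torch

batch_size, n_ch, hsize, wsize, stride = 2, 85, 20, 20, 32
output = torch.randn(batch_size, n_ch, hsize, wsize)

yv, xv = torch.meshgrid(torch.arange(hsize), torch.arange(wsize), indexing="ij")
grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).float()

output = output.view(batch_size, 1, n_ch, hsize, wsize)
output = output.permute(0, 1, 3, 4, 2).reshape(batch_size, hsize * wsize, -1)
grid = grid.view(1, -1, 2)
output[..., :2] = (output[..., :2] + grid) * stride  # cell offsets -> image coords

print(output.shape)  # torch.Size([2, 400, 85]): one row per grid cell
```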
@@ -265,7 +264,7 @@ class YOLOXHead(nn.Module):
         dtype,
     ):
         bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]
-        obj_preds = outputs[:, :, 4].unsqueeze(-1)  # [batch, n_anchors_all, 1]
+        obj_preds = outputs[:, :, 4:5]  # [batch, n_anchors_all, 1]
         cls_preds = outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]
 
         # calculate targets
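The `4:5` slice keeps the trailing channel dimension, so it is equivalent to indexing with `4` and unsqueezing, just more compact; a quick check:

```python
import torch

outputs = torch.randn(2, 400, 85)
a = outputs[:, :, 4:5]              # [2, 400, 1]
b = outputs[:, :, 4].unsqueeze(-1)  # same values, same shape
assert torch.equal(a, b)
```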
@@ -311,7 +310,6 @@ class YOLOXHead(nn.Module):
                 ) = self.get_assignments(  # noqa
                     batch_idx,
                     num_gt,
-                    total_num_anchors,
                     gt_bboxes_per_image,
                     gt_classes,
                     bboxes_preds_per_image,
@@ -319,10 +317,7 @@ class YOLOXHead(nn.Module):
                     x_shifts,
                     y_shifts,
                     cls_preds,
-                    bbox_preds,
                     obj_preds,
-                    labels,
-                    imgs,
                 )
             except RuntimeError as e:
                 # TODO: the string might change, consider a better way
@@ -344,7 +339,6 @@ class YOLOXHead(nn.Module):
                 ) = self.get_assignments(  # noqa
                     batch_idx,
                     num_gt,
-                    total_num_anchors,
                     gt_bboxes_per_image,
                     gt_classes,
                     bboxes_preds_per_image,
@@ -352,10 +346,7 @@ class YOLOXHead(nn.Module):
                     x_shifts,
                     y_shifts,
                     cls_preds,
-                    bbox_preds,
                     obj_preds,
-                    labels,
-                    imgs,
                     "cpu",
                 )
 
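Both call sites must stay in sync: this second one is the fallback path, re-running `get_assignments` with `"cpu"` after a `RuntimeError` (typically a GPU out-of-memory) on the first attempt, so the removed arguments are dropped in both places.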
@@ -433,7 +424,6 @@ class YOLOXHead(nn.Module):
         self,
         batch_idx,
         num_gt,
-        total_num_anchors,
        gt_bboxes_per_image,
        gt_classes,
        bboxes_preds_per_image,
@@ -441,15 +431,12 @@ class YOLOXHead(nn.Module):
         x_shifts,
         y_shifts,
         cls_preds,
-        bbox_preds,
         obj_preds,
-        labels,
-        imgs,
         mode="gpu",
     ):
 
         if mode == "cpu":
-            print("------------CPU Mode for This Batch-------------")
+            print("-----------Using CPU for the Current Batch-------------")
             gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
             bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
             gt_classes = gt_classes.cpu().float()
@@ -457,13 +444,11 @@ class YOLOXHead(nn.Module):
             x_shifts = x_shifts.cpu()
             y_shifts = y_shifts.cpu()
 
-        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
+        fg_mask, geometry_relation = self.get_geometry_constraint(
             gt_bboxes_per_image,
             expanded_strides,
             x_shifts,
             y_shifts,
-            total_num_anchors,
-            num_gt,
         )
 
         bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
@@ -480,8 +465,6 @@ class YOLOXHead(nn.Module):
         gt_cls_per_image = (
             F.one_hot(gt_classes.to(torch.int64), self.num_classes)
             .float()
-            .unsqueeze(1)
-            .repeat(1, num_in_boxes_anchor, 1)
         )
         pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
 
@@ -490,18 +473,19 @@ class YOLOXHead(nn.Module):
 
         with torch.cuda.amp.autocast(enabled=False):
             cls_preds_ = (
-                cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
-                * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
-            )
+                cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_()
+            ).sqrt()
             pair_wise_cls_loss = F.binary_cross_entropy(
-                cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
+                cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1),
+                gt_cls_per_image.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1),
+                reduction="none"
             ).sum(-1)
         del cls_preds_
 
         cost = (
             pair_wise_cls_loss
             + 3.0 * pair_wise_ious_loss
-            + 100000.0 * (~is_in_boxes_and_center)
+            + float(1e6) * (~geometry_relation)
         )
 
         (
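The rewritten pairwise classification cost yields the same `[num_gt, n_candidates]` matrix as before, but the joint cls-obj score is materialized once per candidate and only broadcast against the ground truths inside the BCE call. A standalone sketch of the computation (toy sizes, names invented for illustration):

```python
import torch
import torch.nn.functional as F

num_gt, n_candidates, n_cls = 3, 50, 80
cls_preds_ = torch.randn(n_candidates, n_cls)  # candidate class logits
obj_preds_ = torch.randn(n_candidates, 1)      # candidate objectness logits
gt_cls_per_image = F.one_hot(torch.randint(n_cls, (num_gt,)), n_cls).float()

# geometric mean of class and objectness probabilities, as in the diff
joint = (cls_preds_.sigmoid() * obj_preds_.sigmoid()).sqrt()  # [50, 80]

pair_wise_cls_loss = F.binary_cross_entropy(
    joint.unsqueeze(0).repeat(num_gt, 1, 1),                   # [3, 50, 80]
    gt_cls_per_image.unsqueeze(1).repeat(1, n_candidates, 1),  # [3, 50, 80]
    reduction="none",
).sum(-1)

print(pair_wise_cls_loss.shape)  # torch.Size([3, 50]): one cost per (gt, candidate)
```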
@@ -509,7 +493,7 @@ class YOLOXHead(nn.Module):
             gt_matched_classes,
             pred_ious_this_matching,
             matched_gt_inds,
-        ) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
+        ) = self.simota_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
         del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
 
         if mode == "cpu":
@@ -526,74 +510,29 @@ class YOLOXHead(nn.Module):
             num_fg,
         )
 
-    def get_in_boxes_info(
+    def get_geometry_constraint(
         self,
         gt_bboxes_per_image,
         expanded_strides,
         x_shifts,
         y_shifts,
-        total_num_anchors,
-        num_gt,
     ):
+        """
+        Calculate whether the center of an object is located in a fixed range of
+        an anchor. This is used to avert inappropriate matching. It can also reduce
+        the number of candidate anchors so that the GPU memory is saved.
+        """
         expanded_strides_per_image = expanded_strides[0]
-        x_shifts_per_image = x_shifts[0] * expanded_strides_per_image
-        y_shifts_per_image = y_shifts[0] * expanded_strides_per_image
-        x_centers_per_image = (
-            (x_shifts_per_image + 0.5 * expanded_strides_per_image)
-            .unsqueeze(0)
-            .repeat(num_gt, 1)
-        )  # [n_anchor] -> [n_gt, n_anchor]
-        y_centers_per_image = (
-            (y_shifts_per_image + 0.5 * expanded_strides_per_image)
-            .unsqueeze(0)
-            .repeat(num_gt, 1)
-        )
-
-        gt_bboxes_per_image_l = (
-            (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
-        gt_bboxes_per_image_r = (
-            (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
-        gt_bboxes_per_image_t = (
-            (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
-        gt_bboxes_per_image_b = (
-            (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
-
-        b_l = x_centers_per_image - gt_bboxes_per_image_l
-        b_r = gt_bboxes_per_image_r - x_centers_per_image
-        b_t = y_centers_per_image - gt_bboxes_per_image_t
-        b_b = gt_bboxes_per_image_b - y_centers_per_image
-        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
-
-        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
-        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
+        x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)
+        y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)
 
         # in fixed center
-        center_radius = 2.5
-
-        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) - center_radius * expanded_strides_per_image.unsqueeze(0)
-        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) + center_radius * expanded_strides_per_image.unsqueeze(0)
-        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) - center_radius * expanded_strides_per_image.unsqueeze(0)
-        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) + center_radius * expanded_strides_per_image.unsqueeze(0)
+        center_radius = 1.5
+        center_dist = expanded_strides_per_image.unsqueeze(0) * center_radius
+        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0:1]) - center_dist
+        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0:1]) + center_dist
+        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1:2]) - center_dist
+        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1:2]) + center_dist
 
         c_l = x_centers_per_image - gt_bboxes_per_image_l
         c_r = gt_bboxes_per_image_r - x_centers_per_image
@@ -601,26 +540,19 @@ class YOLOXHead(nn.Module):
         c_b = gt_bboxes_per_image_b - y_centers_per_image
         center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
         is_in_centers = center_deltas.min(dim=-1).values > 0.0
-        is_in_centers_all = is_in_centers.sum(dim=0) > 0
+        anchor_filter = is_in_centers.sum(dim=0) > 0
+        geometry_relation = is_in_centers[:, anchor_filter]
 
-        # in boxes and in centers
-        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
-
-        is_in_boxes_and_center = (
-            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
-        )
-        return is_in_boxes_anchor, is_in_boxes_and_center
+        return anchor_filter, geometry_relation
 
-    def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
+    def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
         # Dynamic K
         # ---------------------------------------------------------------
         matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
 
-        ious_in_boxes_matrix = pair_wise_ious
-        n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
-        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
+        n_candidate_k = min(10, pair_wise_ious.size(1))
+        topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
         dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
-        dynamic_ks = dynamic_ks.tolist()
         for gt_idx in range(num_gt):
             _, pos_idx = torch.topk(
                 cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
@@ -630,11 +562,13 @@ class YOLOXHead(nn.Module):
         del topk_ious, dynamic_ks, pos_idx
 
         anchor_matching_gt = matching_matrix.sum(0)
+        # deal with the case that one anchor matches multiple ground-truths
         if anchor_matching_gt.max() > 1:
-            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
-            matching_matrix[:, anchor_matching_gt > 1] *= 0
-            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
-        fg_mask_inboxes = matching_matrix.sum(0) > 0
+            multiple_match_mask = anchor_matching_gt > 1
+            _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)
+            matching_matrix[:, multiple_match_mask] *= 0
+            matching_matrix[cost_argmin, multiple_match_mask] = 1
+        fg_mask_inboxes = anchor_matching_gt > 0
         num_fg = fg_mask_inboxes.sum().item()
 
         fg_mask[fg_mask.clone()] = fg_mask_inboxes
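Taken together, the new geometry constraint depends only on the distance between anchor centers and the ground-truth center, never on the box width or height, which is what fixes the no-candidate case for tiny objects. A self-contained sketch under toy values (a single 4x4 stride-8 level and a 2x2-pixel box; the free function mirrors the method above):

```python
import torch

def get_geometry_constraint(gt_bboxes, expanded_strides, x_shifts, y_shifts):
    strides = expanded_strides[0]
    x_centers = ((x_shifts[0] + 0.5) * strides).unsqueeze(0)  # anchor centers, x
    y_centers = ((y_shifts[0] + 0.5) * strides).unsqueeze(0)  # anchor centers, y

    center_radius = 1.5
    center_dist = strides.unsqueeze(0) * center_radius
    l = gt_bboxes[:, 0:1] - center_dist
    r = gt_bboxes[:, 0:1] + center_dist
    t = gt_bboxes[:, 1:2] - center_dist
    b = gt_bboxes[:, 1:2] + center_dist

    deltas = torch.stack(
        [x_centers - l, y_centers - t, r - x_centers, b - y_centers], 2
    )
    is_in_centers = deltas.min(dim=-1).values > 0.0
    anchor_filter = is_in_centers.sum(dim=0) > 0
    return anchor_filter, is_in_centers[:, anchor_filter]

# one 4x4 stride-8 feature level -> 16 anchors
x_shifts = torch.arange(4).repeat(4).view(1, -1).float()
y_shifts = torch.arange(4).repeat_interleave(4).view(1, -1).float()
expanded_strides = torch.full((1, 16), 8.0)

# tiny GT box (cx, cy, w, h) = 2x2 px, far smaller than one stride cell
gt = torch.tensor([[13.0, 13.0, 2.0, 2.0]])
anchor_filter, geometry_relation = get_geometry_constraint(
    gt, expanded_strides, x_shifts, y_shifts
)
print(anchor_filter.sum().item())  # 9 candidates despite the 2x2 box
```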