Ge Zheng committed
Commit 9a33b8a · 1 Parent(s): 0588424

fix(models): no-candidate anchor issue for tiny objects during label assign (#1589)

Files changed (1)
  1. yolox/models/yolo_head.py +43 -109
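In short, the patch drops the old requirement in get_in_boxes_info that a candidate anchor's center fall inside the ground-truth box, and keeps only a fixed center-radius test (now 1.5 strides instead of 2.5) in the new get_geometry_constraint. Because that test depends only on the object's center and the stride of the level, not on the box size, even a box only a few pixels wide keeps candidate anchors. The standalone sketch below re-runs the new constraint on made-up values (the grid size, stride, and 2x2-pixel box are invented for illustration, not taken from the repository) to show that the candidate set is non-empty:

# Illustrative only: re-running the geometry constraint from this commit on toy values.
import torch

stride = 8
hsize = wsize = 4                                    # a tiny 4x4 feature map
yv, xv = torch.meshgrid(torch.arange(hsize), torch.arange(wsize), indexing="ij")
x_shifts = xv.flatten().float()                      # grid offsets, as in YOLOXHead.forward
y_shifts = yv.flatten().float()
expanded_strides = torch.full_like(x_shifts, float(stride))

gt_bboxes = torch.tensor([[13.0, 19.0, 2.0, 2.0]])   # one 2x2-pixel GT box (cx, cy, w, h)

x_centers = ((x_shifts + 0.5) * expanded_strides).unsqueeze(0)
y_centers = ((y_shifts + 0.5) * expanded_strides).unsqueeze(0)

center_radius = 1.5
center_dist = expanded_strides.unsqueeze(0) * center_radius
c_l = x_centers - (gt_bboxes[:, 0:1] - center_dist)
c_r = (gt_bboxes[:, 0:1] + center_dist) - x_centers
c_t = y_centers - (gt_bboxes[:, 1:2] - center_dist)
c_b = (gt_bboxes[:, 1:2] + center_dist) - y_centers
is_in_centers = torch.stack([c_l, c_t, c_r, c_b], 2).min(dim=-1).values > 0.0

anchor_filter = is_in_centers.sum(dim=0) > 0
print(int(anchor_filter.sum()))                      # > 0: the tiny box still has candidates

The same anchor_filter / geometry_relation pair then feeds the assignment cost, where anchors outside the center range are penalized with float(1e6) in place of the old 100000.0 * (~is_in_boxes_and_center) term.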
yolox/models/yolo_head.py CHANGED

@@ -32,7 +32,6 @@ class YOLOXHead(nn.Module):
         """
         super().__init__()
 
-        self.n_anchors = 1
         self.num_classes = num_classes
         self.decode_in_inference = True  # for deploy, set to False
 
@@ -97,7 +96,7 @@ class YOLOXHead(nn.Module):
             self.cls_preds.append(
                 nn.Conv2d(
                     in_channels=int(256 * width),
-                    out_channels=self.n_anchors * self.num_classes,
+                    out_channels=self.num_classes,
                     kernel_size=1,
                     stride=1,
                     padding=0,
@@ -115,7 +114,7 @@ class YOLOXHead(nn.Module):
             self.obj_preds.append(
                 nn.Conv2d(
                     in_channels=int(256 * width),
-                    out_channels=self.n_anchors * 1,
+                    out_channels=1,
                     kernel_size=1,
                     stride=1,
                     padding=0,
@@ -131,12 +130,12 @@ class YOLOXHead(nn.Module):
 
     def initialize_biases(self, prior_prob):
         for conv in self.cls_preds:
-            b = conv.bias.view(self.n_anchors, -1)
+            b = conv.bias.view(1, -1)
             b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
             conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
 
         for conv in self.obj_preds:
-            b = conv.bias.view(self.n_anchors, -1)
+            b = conv.bias.view(1, -1)
             b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
             conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
 
@@ -177,7 +176,7 @@ class YOLOXHead(nn.Module):
                     batch_size = reg_output.shape[0]
                     hsize, wsize = reg_output.shape[-2:]
                     reg_output = reg_output.view(
-                        batch_size, self.n_anchors, 4, hsize, wsize
+                        batch_size, 1, 4, hsize, wsize
                     )
                     reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape(
                         batch_size, -1, 4
@@ -224,9 +223,9 @@ class YOLOXHead(nn.Module):
             grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
             self.grids[k] = grid
 
-        output = output.view(batch_size, self.n_anchors, n_ch, hsize, wsize)
+        output = output.view(batch_size, 1, n_ch, hsize, wsize)
         output = output.permute(0, 1, 3, 4, 2).reshape(
-            batch_size, self.n_anchors * hsize * wsize, -1
+            batch_size, hsize * wsize, -1
         )
         grid = grid.view(1, -1, 2)
         output[..., :2] = (output[..., :2] + grid) * stride
@@ -265,7 +264,7 @@ class YOLOXHead(nn.Module):
         dtype,
     ):
         bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]
-        obj_preds = outputs[:, :, 4].unsqueeze(-1)  # [batch, n_anchors_all, 1]
+        obj_preds = outputs[:, :, 4:5]  # [batch, n_anchors_all, 1]
         cls_preds = outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]
 
         # calculate targets
@@ -311,7 +310,6 @@ class YOLOXHead(nn.Module):
                     ) = self.get_assignments(  # noqa
                         batch_idx,
                         num_gt,
-                        total_num_anchors,
                         gt_bboxes_per_image,
                         gt_classes,
                         bboxes_preds_per_image,
@@ -319,10 +317,7 @@ class YOLOXHead(nn.Module):
                         x_shifts,
                         y_shifts,
                         cls_preds,
-                        bbox_preds,
                         obj_preds,
-                        labels,
-                        imgs,
                     )
                 except RuntimeError as e:
                     # TODO: the string might change, consider a better way
@@ -344,7 +339,6 @@ class YOLOXHead(nn.Module):
                     ) = self.get_assignments(  # noqa
                         batch_idx,
                         num_gt,
-                        total_num_anchors,
                         gt_bboxes_per_image,
                         gt_classes,
                         bboxes_preds_per_image,
@@ -352,10 +346,7 @@ class YOLOXHead(nn.Module):
                         x_shifts,
                         y_shifts,
                         cls_preds,
-                        bbox_preds,
                         obj_preds,
-                        labels,
-                        imgs,
                         "cpu",
                     )
 
@@ -433,7 +424,6 @@ class YOLOXHead(nn.Module):
         self,
         batch_idx,
        num_gt,
-        total_num_anchors,
        gt_bboxes_per_image,
        gt_classes,
        bboxes_preds_per_image,
@@ -441,15 +431,12 @@ class YOLOXHead(nn.Module):
        x_shifts,
        y_shifts,
        cls_preds,
-        bbox_preds,
        obj_preds,
-        labels,
-        imgs,
        mode="gpu",
    ):
 
        if mode == "cpu":
-            print("------------CPU Mode for This Batch-------------")
+            print("-----------Using CPU for the Current Batch-------------")
            gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
            bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
            gt_classes = gt_classes.cpu().float()
@@ -457,13 +444,11 @@ class YOLOXHead(nn.Module):
            x_shifts = x_shifts.cpu()
            y_shifts = y_shifts.cpu()
 
-        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
+        fg_mask, geometry_relation = self.get_geometry_constraint(
            gt_bboxes_per_image,
            expanded_strides,
            x_shifts,
            y_shifts,
-            total_num_anchors,
-            num_gt,
        )
 
        bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
@@ -480,8 +465,6 @@ class YOLOXHead(nn.Module):
        gt_cls_per_image = (
            F.one_hot(gt_classes.to(torch.int64), self.num_classes)
            .float()
-            .unsqueeze(1)
-            .repeat(1, num_in_boxes_anchor, 1)
        )
        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
 
@@ -490,18 +473,19 @@ class YOLOXHead(nn.Module):
 
        with torch.cuda.amp.autocast(enabled=False):
            cls_preds_ = (
-                cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
-                * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
-            )
+                cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_()
+            ).sqrt()
            pair_wise_cls_loss = F.binary_cross_entropy(
-                cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
+                cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1),
+                gt_cls_per_image.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1),
+                reduction="none"
            ).sum(-1)
        del cls_preds_
 
        cost = (
            pair_wise_cls_loss
            + 3.0 * pair_wise_ious_loss
-            + 100000.0 * (~is_in_boxes_and_center)
+            + float(1e6) * (~geometry_relation)
        )
 
        (
@@ -509,7 +493,7 @@ class YOLOXHead(nn.Module):
            gt_matched_classes,
            pred_ious_this_matching,
            matched_gt_inds,
-        ) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
+        ) = self.simota_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
 
        if mode == "cpu":
@@ -526,74 +510,29 @@ class YOLOXHead(nn.Module):
            num_fg,
        )
 
-    def get_in_boxes_info(
+    def get_geometry_constraint(
        self,
        gt_bboxes_per_image,
        expanded_strides,
        x_shifts,
        y_shifts,
-        total_num_anchors,
-        num_gt,
    ):
+        """
+        Calculate whether the center of an object is located in a fixed range of
+        an anchor. This is used to avert inappropriate matching. It can also reduce
+        the number of candidate anchors so that the GPU memory is saved.
+        """
        expanded_strides_per_image = expanded_strides[0]
-        x_shifts_per_image = x_shifts[0] * expanded_strides_per_image
-        y_shifts_per_image = y_shifts[0] * expanded_strides_per_image
-        x_centers_per_image = (
-            (x_shifts_per_image + 0.5 * expanded_strides_per_image)
-            .unsqueeze(0)
-            .repeat(num_gt, 1)
-        )  # [n_anchor] -> [n_gt, n_anchor]
-        y_centers_per_image = (
-            (y_shifts_per_image + 0.5 * expanded_strides_per_image)
-            .unsqueeze(0)
-            .repeat(num_gt, 1)
-        )
-
-        gt_bboxes_per_image_l = (
-            (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
-        gt_bboxes_per_image_r = (
-            (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
-        gt_bboxes_per_image_t = (
-            (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
-        gt_bboxes_per_image_b = (
-            (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
+        x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)
+        y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)
 
-        b_l = x_centers_per_image - gt_bboxes_per_image_l
-        b_r = gt_bboxes_per_image_r - x_centers_per_image
-        b_t = y_centers_per_image - gt_bboxes_per_image_t
-        b_b = gt_bboxes_per_image_b - y_centers_per_image
-        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
-
-        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
-        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
        # in fixed center
-
-        center_radius = 2.5
-
-        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) - center_radius * expanded_strides_per_image.unsqueeze(0)
-        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) + center_radius * expanded_strides_per_image.unsqueeze(0)
-        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) - center_radius * expanded_strides_per_image.unsqueeze(0)
-        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) + center_radius * expanded_strides_per_image.unsqueeze(0)
+        center_radius = 1.5
+        center_dist = expanded_strides_per_image.unsqueeze(0) * center_radius
+        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0:1]) - center_dist
+        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0:1]) + center_dist
+        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1:2]) - center_dist
+        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1:2]) + center_dist
 
        c_l = x_centers_per_image - gt_bboxes_per_image_l
        c_r = gt_bboxes_per_image_r - x_centers_per_image
@@ -601,26 +540,19 @@ class YOLOXHead(nn.Module):
        c_b = gt_bboxes_per_image_b - y_centers_per_image
        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
        is_in_centers = center_deltas.min(dim=-1).values > 0.0
-        is_in_centers_all = is_in_centers.sum(dim=0) > 0
+        anchor_filter = is_in_centers.sum(dim=0) > 0
+        geometry_relation = is_in_centers[:, anchor_filter]
 
-        # in boxes and in centers
-        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
-
-        is_in_boxes_and_center = (
-            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
-        )
-        return is_in_boxes_anchor, is_in_boxes_and_center
+        return anchor_filter, geometry_relation
 
-    def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
+    def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
        # Dynamic K
        # ---------------------------------------------------------------
        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
 
-        ious_in_boxes_matrix = pair_wise_ious
-        n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
-        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
+        n_candidate_k = min(10, pair_wise_ious.size(1))
+        topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
-        dynamic_ks = dynamic_ks.tolist()
        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(
                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
@@ -630,11 +562,13 @@ class YOLOXHead(nn.Module):
        del topk_ious, dynamic_ks, pos_idx
 
        anchor_matching_gt = matching_matrix.sum(0)
+        # deal with the case that one anchor matches multiple ground-truths
        if anchor_matching_gt.max() > 1:
-            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
-            matching_matrix[:, anchor_matching_gt > 1] *= 0
-            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
-        fg_mask_inboxes = matching_matrix.sum(0) > 0
+            multiple_match_mask = anchor_matching_gt > 1
+            _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)
+            matching_matrix[:, multiple_match_mask] *= 0
+            matching_matrix[cost_argmin, multiple_match_mask] = 1
+        fg_mask_inboxes = anchor_matching_gt > 0
        num_fg = fg_mask_inboxes.sum().item()
 
        fg_mask[fg_mask.clone()] = fg_mask_inboxes
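For reference, the renamed simota_matching resolves conflicts the same way dynamic_k_matching did: when one anchor is selected by several ground truths, only the ground truth with the lowest cost keeps it. A toy run of just that step, with invented cost values and an invented initial matching matrix, looks like this:

# Illustrative only: the conflict-resolution step on a made-up 2-GT x 3-anchor case.
import torch

cost = torch.tensor([[0.2, 0.8, 0.5],
                     [0.6, 0.1, 0.4]])
matching_matrix = torch.tensor([[1, 0, 1],
                                [0, 1, 1]], dtype=torch.uint8)  # anchor 2 matched twice

anchor_matching_gt = matching_matrix.sum(0)
if anchor_matching_gt.max() > 1:
    multiple_match_mask = anchor_matching_gt > 1
    _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)
    matching_matrix[:, multiple_match_mask] *= 0
    matching_matrix[cost_argmin, multiple_match_mask] = 1
fg_mask_inboxes = anchor_matching_gt > 0    # all three anchors stay foreground here

print(matching_matrix)
# tensor([[1, 0, 0],
#         [0, 1, 1]], dtype=torch.uint8)    anchor 2 now belongs to GT 1 only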