henry000 commited on
Commit
96da794
Β·
1 Parent(s): c4cd90a

πŸ› [Fix] BoxMatcher for filter outsided bbox

Browse files
Files changed (1) hide show
  1. yolo/utils/bounding_box_utils.py +5 -4
yolo/utils/bounding_box_utils.py CHANGED
@@ -212,19 +212,20 @@ class BoxMatcher:
212
  topk_masks = topk_targets > 0
213
  return topk_targets, topk_masks
214
 
215
- def filter_duplicates(self, target_matrix: Tensor, topk_mask: Tensor):
216
  """
217
  Filter the maximum suitability target index of each anchor.
218
 
219
  Args:
220
- target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
221
 
222
  Returns:
223
  unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
224
  """
225
  duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
226
- max_idx = F.one_hot(target_matrix.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
227
  topk_mask = torch.where(duplicates, max_idx, topk_mask)
 
228
  unique_indices = topk_mask.argmax(dim=1)
229
  return unique_indices[..., None], topk_mask.sum(1), topk_mask
230
 
@@ -278,7 +279,7 @@ class BoxMatcher:
278
  topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
279
 
280
  # delete one anchor pred assign to mutliple gts
281
- unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)
282
 
283
  align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
284
  align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
 
212
  topk_masks = topk_targets > 0
213
  return topk_targets, topk_masks
214
 
215
+ def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor, grid_mask: Tensor):
216
  """
217
  Filter the maximum suitability target index of each anchor.
218
 
219
  Args:
220
+ iou_mat [batch x targets x anchors]: The suitability for each targets-anchors
221
 
222
  Returns:
223
  unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
224
  """
225
  duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
226
+ max_idx = F.one_hot(iou_mat.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
227
  topk_mask = torch.where(duplicates, max_idx, topk_mask)
228
+ topk_mask &= grid_mask
229
  unique_indices = topk_mask.argmax(dim=1)
230
  return unique_indices[..., None], topk_mask.sum(1), topk_mask
231
 
 
279
  topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
280
 
281
  # delete one anchor pred assign to mutliple gts
282
+ unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask, grid_mask)
283
 
284
  align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
285
  align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)