henry000 commited on
Commit
fd5413f
·
1 Parent(s): 604c897

🩹 [Fix] BoxMatcher, change eps and filter function

Browse files
Files changed (1) hide show
  1. yolo/utils/bounding_box_utils.py +10 -9
yolo/utils/bounding_box_utils.py CHANGED
@@ -14,7 +14,7 @@ from yolo.utils.logger import logger
14
 
15
  def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
16
  metrics = metrics.lower()
17
- EPS = 1e-9
18
  dtype = bbox1.dtype
19
  bbox1 = bbox1.to(torch.float32)
20
  bbox2 = bbox2.to(torch.float32)
@@ -210,7 +210,7 @@ class BoxMatcher:
210
  topk_masks = topk_targets > 0
211
  return topk_targets, topk_masks
212
 
213
- def filter_duplicates(self, target_matrix: Tensor):
214
  """
215
  Filter the maximum suitability target index of each anchor.
216
 
@@ -220,9 +220,11 @@ class BoxMatcher:
220
  Returns:
221
  unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
222
  """
223
- # TODO: add a assert for no target on the image
224
- unique_indices = target_matrix.argmax(dim=1)
225
- return unique_indices[..., None]
 
 
226
 
227
  def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
228
  """
@@ -249,16 +251,15 @@ class BoxMatcher:
249
  topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
250
 
251
  # delete one anchor pred assign to mutliple gts
252
- unique_indices = self.filter_duplicates(topk_targets)
253
-
254
- # TODO: do we need grid_mask? Filter the valid groud truth
255
- valid_mask = (grid_mask.sum(dim=-2) * topk_mask.sum(dim=-2)).bool()
256
 
257
  align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
258
  align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
259
  align_cls = F.one_hot(align_cls, self.class_num)
260
 
261
  # normalize class ditribution
 
 
262
  max_target = target_matrix.amax(dim=-1, keepdim=True)
263
  max_iou = iou_mat.amax(dim=-1, keepdim=True)
264
  normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
 
14
 
15
  def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
16
  metrics = metrics.lower()
17
+ EPS = 1e-7
18
  dtype = bbox1.dtype
19
  bbox1 = bbox1.to(torch.float32)
20
  bbox2 = bbox2.to(torch.float32)
 
210
  topk_masks = topk_targets > 0
211
  return topk_targets, topk_masks
212
 
213
+ def filter_duplicates(self, target_matrix: Tensor, topk_mask: Tensor):
214
  """
215
  Filter the maximum suitability target index of each anchor.
216
 
 
220
  Returns:
221
  unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
222
  """
223
+ duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
224
+ max_idx = F.one_hot(target_matrix.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
225
+ topk_mask = torch.where(duplicates, max_idx, topk_mask)
226
+ unique_indices = topk_mask.argmax(dim=1)
227
+ return unique_indices[..., None], topk_mask.sum(1), topk_mask
228
 
229
  def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
230
  """
 
251
  topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
252
 
253
  # delete one anchor pred assign to mutliple gts
254
+ unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)
 
 
 
255
 
256
  align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
257
  align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
258
  align_cls = F.one_hot(align_cls, self.class_num)
259
 
260
  # normalize class ditribution
261
+ iou_mat *= topk_mask
262
+ target_matrix *= topk_mask
263
  max_target = target_matrix.amax(dim=-1, keepdim=True)
264
  max_iou = iou_mat.amax(dim=-1, keepdim=True)
265
  normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou