π [Fix] BoxMatcher for filter outsided bbox
Browse files
yolo/utils/bounding_box_utils.py
CHANGED
@@ -212,19 +212,20 @@ class BoxMatcher:
|
|
212 |
topk_masks = topk_targets > 0
|
213 |
return topk_targets, topk_masks
|
214 |
|
215 |
-
def filter_duplicates(self,
|
216 |
"""
|
217 |
Filter the maximum suitability target index of each anchor.
|
218 |
|
219 |
Args:
|
220 |
-
|
221 |
|
222 |
Returns:
|
223 |
unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
|
224 |
"""
|
225 |
duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
|
226 |
-
max_idx = F.one_hot(
|
227 |
topk_mask = torch.where(duplicates, max_idx, topk_mask)
|
|
|
228 |
unique_indices = topk_mask.argmax(dim=1)
|
229 |
return unique_indices[..., None], topk_mask.sum(1), topk_mask
|
230 |
|
@@ -278,7 +279,7 @@ class BoxMatcher:
|
|
278 |
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
|
279 |
|
280 |
# delete one anchor pred assign to mutliple gts
|
281 |
-
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)
|
282 |
|
283 |
align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
|
284 |
align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
|
|
|
212 |
topk_masks = topk_targets > 0
|
213 |
return topk_targets, topk_masks
|
214 |
|
215 |
+
def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor, grid_mask: Tensor):
|
216 |
"""
|
217 |
Filter the maximum suitability target index of each anchor.
|
218 |
|
219 |
Args:
|
220 |
+
iou_mat [batch x targets x anchors]: The suitability for each targets-anchors
|
221 |
|
222 |
Returns:
|
223 |
unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
|
224 |
"""
|
225 |
duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
|
226 |
+
max_idx = F.one_hot(iou_mat.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
|
227 |
topk_mask = torch.where(duplicates, max_idx, topk_mask)
|
228 |
+
topk_mask &= grid_mask
|
229 |
unique_indices = topk_mask.argmax(dim=1)
|
230 |
return unique_indices[..., None], topk_mask.sum(1), topk_mask
|
231 |
|
|
|
279 |
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
|
280 |
|
281 |
# delete one anchor pred assign to mutliple gts
|
282 |
+
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask, grid_mask)
|
283 |
|
284 |
align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
|
285 |
align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
|