🩹 [Fix] BoxMatcher, change eps and filter function
Browse files
yolo/utils/bounding_box_utils.py
CHANGED
@@ -14,7 +14,7 @@ from yolo.utils.logger import logger
|
|
14 |
|
15 |
def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
|
16 |
metrics = metrics.lower()
|
17 |
-
EPS = 1e-
|
18 |
dtype = bbox1.dtype
|
19 |
bbox1 = bbox1.to(torch.float32)
|
20 |
bbox2 = bbox2.to(torch.float32)
|
@@ -210,7 +210,7 @@ class BoxMatcher:
|
|
210 |
topk_masks = topk_targets > 0
|
211 |
return topk_targets, topk_masks
|
212 |
|
213 |
-
def filter_duplicates(self, target_matrix: Tensor):
|
214 |
"""
|
215 |
Filter the maximum suitability target index of each anchor.
|
216 |
|
@@ -220,9 +220,11 @@ class BoxMatcher:
|
|
220 |
Returns:
|
221 |
unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
|
222 |
"""
|
223 |
-
|
224 |
-
|
225 |
-
|
|
|
|
|
226 |
|
227 |
def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
|
228 |
"""
|
@@ -249,16 +251,15 @@ class BoxMatcher:
|
|
249 |
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
|
250 |
|
251 |
# delete one anchor pred assign to mutliple gts
|
252 |
-
unique_indices = self.filter_duplicates(
|
253 |
-
|
254 |
-
# TODO: do we need grid_mask? Filter the valid groud truth
|
255 |
-
valid_mask = (grid_mask.sum(dim=-2) * topk_mask.sum(dim=-2)).bool()
|
256 |
|
257 |
align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
|
258 |
align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
|
259 |
align_cls = F.one_hot(align_cls, self.class_num)
|
260 |
|
261 |
# normalize class ditribution
|
|
|
|
|
262 |
max_target = target_matrix.amax(dim=-1, keepdim=True)
|
263 |
max_iou = iou_mat.amax(dim=-1, keepdim=True)
|
264 |
normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
|
|
|
14 |
|
15 |
def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
|
16 |
metrics = metrics.lower()
|
17 |
+
EPS = 1e-7
|
18 |
dtype = bbox1.dtype
|
19 |
bbox1 = bbox1.to(torch.float32)
|
20 |
bbox2 = bbox2.to(torch.float32)
|
|
|
210 |
topk_masks = topk_targets > 0
|
211 |
return topk_targets, topk_masks
|
212 |
|
213 |
+
def filter_duplicates(self, target_matrix: Tensor, topk_mask: Tensor):
|
214 |
"""
|
215 |
Filter the maximum suitability target index of each anchor.
|
216 |
|
|
|
220 |
Returns:
|
221 |
unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
|
222 |
"""
|
223 |
+
duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
|
224 |
+
max_idx = F.one_hot(target_matrix.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
|
225 |
+
topk_mask = torch.where(duplicates, max_idx, topk_mask)
|
226 |
+
unique_indices = topk_mask.argmax(dim=1)
|
227 |
+
return unique_indices[..., None], topk_mask.sum(1), topk_mask
|
228 |
|
229 |
def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
|
230 |
"""
|
|
|
251 |
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
|
252 |
|
253 |
# delete one anchor pred assign to mutliple gts
|
254 |
+
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)
|
|
|
|
|
|
|
255 |
|
256 |
align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
|
257 |
align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
|
258 |
align_cls = F.one_hot(align_cls, self.class_num)
|
259 |
|
260 |
# normalize class ditribution
|
261 |
+
iou_mat *= topk_mask
|
262 |
+
target_matrix *= topk_mask
|
263 |
max_target = target_matrix.amax(dim=-1, keepdim=True)
|
264 |
max_iou = iou_mat.amax(dim=-1, keepdim=True)
|
265 |
normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
|