Abdul
commited on
🔨 fix: BoxMatcher.__call__ now returns all zero anchor matched targets and all False valid mask, if input target has zero annotations in it. (#88)
Browse files
yolo/utils/bounding_box_utils.py
CHANGED
@@ -222,12 +222,37 @@ class BoxMatcher:
|
|
222 |
return unique_indices[..., None]
|
223 |
|
224 |
def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
|
225 |
-
"""
|
226 |
-
1. For each anchor prediction, find the highest suitability targets
|
227 |
-
2.
|
228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
"""
|
230 |
predict_cls, predict_bbox = predict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
|
232 |
target_cls = target_cls.long().clamp(0)
|
233 |
|
@@ -261,8 +286,8 @@ class BoxMatcher:
|
|
261 |
normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
|
262 |
normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
|
263 |
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
|
264 |
-
|
265 |
-
return
|
266 |
|
267 |
|
268 |
class Vec2Box:
|
|
|
222 |
return unique_indices[..., None]
|
223 |
|
224 |
def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
|
225 |
+
"""Matches each target to the most suitable anchor.
|
226 |
+
1. For each anchor prediction, find the highest suitability targets.
|
227 |
+
2. Match target to the best anchor.
|
228 |
+
3. Noramlize the class probilities of targets.
|
229 |
+
|
230 |
+
Args:
|
231 |
+
target: The ground truth class and bounding box information
|
232 |
+
as tensor of size [batch x targets x 5].
|
233 |
+
predict: Tuple of predicted class and bounding box tensors.
|
234 |
+
Class tensor is of size [batch x anchors x class]
|
235 |
+
Bounding box tensor is of size [batch x anchors x 4].
|
236 |
+
|
237 |
+
Returns:
|
238 |
+
anchor_matched_targets: Tensor of size [batch x anchors x (class + 4)].
|
239 |
+
A tensor assigning each target/gt to the best fitting anchor.
|
240 |
+
The class probabilities are normalized.
|
241 |
+
valid_mask: Bool tensor of shape [batch x anchors].
|
242 |
+
True if a anchor has a target/gt assigned to it.
|
243 |
"""
|
244 |
predict_cls, predict_bbox = predict
|
245 |
+
|
246 |
+
# return if target has no gt information.
|
247 |
+
n_targets = target.shape[1]
|
248 |
+
if n_targets == 0:
|
249 |
+
device = predict_bbox.device
|
250 |
+
align_cls = torch.zeros_like(predict_cls, device=device)
|
251 |
+
align_bbox = torch.zeros_like(predict_bbox, device=device)
|
252 |
+
valid_mask = torch.zeros(predict_cls.shape[:2], dtype=bool, device=device)
|
253 |
+
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
|
254 |
+
return anchor_matched_targets, valid_mask
|
255 |
+
|
256 |
target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
|
257 |
target_cls = target_cls.long().clamp(0)
|
258 |
|
|
|
286 |
normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
|
287 |
normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
|
288 |
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
|
289 |
+
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
|
290 |
+
return anchor_matched_targets, valid_mask
|
291 |
|
292 |
|
293 |
class Vec2Box:
|