Abdul committed on
Commit
e53ff09
·
unverified ·
1 Parent(s): dea5a8a

🔨 fix: BoxMatcher.__call__ now returns all-zero anchor-matched targets and an all-False valid mask if the input target contains no annotations. (#88)

Browse files
Files changed (1) hide show
  1. yolo/utils/bounding_box_utils.py +31 -6
yolo/utils/bounding_box_utils.py CHANGED
@@ -222,12 +222,37 @@ class BoxMatcher:
222
  return unique_indices[..., None]
223
 
224
  def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
225
- """
226
- 1. For each anchor prediction, find the highest suitability targets
227
- 2. Select the targets
228
- 2. Noramlize the class probilities of targets
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  """
230
  predict_cls, predict_bbox = predict
 
 
 
 
 
 
 
 
 
 
 
231
  target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
232
  target_cls = target_cls.long().clamp(0)
233
 
@@ -261,8 +286,8 @@ class BoxMatcher:
261
  normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
262
  normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
263
  align_cls = align_cls * normalize_term * valid_mask[:, :, None]
264
-
265
- return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
266
 
267
 
268
  class Vec2Box:
 
222
  return unique_indices[..., None]
223
 
224
  def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
225
+ """Matches each target to the most suitable anchor.
226
+ 1. For each anchor prediction, find the highest suitability targets.
227
+ 2. Match target to the best anchor.
228
+ 3. Normalize the class probabilities of targets.
229
+
230
+ Args:
231
+ target: The ground truth class and bounding box information
232
+ as tensor of size [batch x targets x 5].
233
+ predict: Tuple of predicted class and bounding box tensors.
234
+ Class tensor is of size [batch x anchors x class]
235
+ Bounding box tensor is of size [batch x anchors x 4].
236
+
237
+ Returns:
238
+ anchor_matched_targets: Tensor of size [batch x anchors x (class + 4)].
239
+ A tensor assigning each target/gt to the best fitting anchor.
240
+ The class probabilities are normalized.
241
+ valid_mask: Bool tensor of shape [batch x anchors].
242
+ True if an anchor has a target/gt assigned to it.
243
  """
244
  predict_cls, predict_bbox = predict
245
+
246
+ # return if target has no gt information.
247
+ n_targets = target.shape[1]
248
+ if n_targets == 0:
249
+ device = predict_bbox.device
250
+ align_cls = torch.zeros_like(predict_cls, device=device)
251
+ align_bbox = torch.zeros_like(predict_bbox, device=device)
252
+ valid_mask = torch.zeros(predict_cls.shape[:2], dtype=bool, device=device)
253
+ anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
254
+ return anchor_matched_targets, valid_mask
255
+
256
  target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
257
  target_cls = target_cls.long().clamp(0)
258
 
 
286
  normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
287
  normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
288
  align_cls = align_cls * normalize_term * valid_mask[:, :, None]
289
+ anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
290
+ return anchor_matched_targets, valid_mask
291
 
292
 
293
  class Vec2Box: