henry000 commited on
Commit
72c155f
·
1 Parent(s): 97211aa

🚑️ [Fix] loss function mask bugs, use bool

Browse files
Files changed (1) hide show
  1. yolo/utils/bounding_box_utils.py +1 -1
yolo/utils/bounding_box_utils.py CHANGED
@@ -293,7 +293,7 @@ class BoxMatcher:
293
  normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
294
  align_cls = align_cls * normalize_term * valid_mask[:, :, None]
295
  anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
296
- return anchor_matched_targets, valid_mask
297
 
298
 
299
  class Vec2Box:
 
293
  normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
294
  align_cls = align_cls * normalize_term * valid_mask[:, :, None]
295
  anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
296
+ return anchor_matched_targets, valid_mask.bool()
297
 
298
 
299
  class Vec2Box: