henry000 commited on
Commit
c0e778f
·
1 Parent(s): 7dd7b62

🚑️ [Fix] #22 mentioned error, make nms run on cpu

Browse files
Files changed (1) hide show
  1. yolo/utils/bounding_box_utils.py +2 -1
yolo/utils/bounding_box_utils.py CHANGED
@@ -292,12 +292,13 @@ class BoxMatcher:
292
 
293
 
294
  def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
 
295
  cls_dist, bbox = predicts.split([80, 4], dim=-1)
296
 
297
  # filter class by confidence
298
  cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
299
  valid_mask = cls_val > nms_cfg.min_confidence
300
- valid_cls = cls_idx[valid_mask]
301
  valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
302
 
303
  batch_idx, *_ = torch.where(valid_mask)
 
292
 
293
 
294
  def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
295
+ # TODO change function to class or set 80 to class_num instead of a number
296
  cls_dist, bbox = predicts.split([80, 4], dim=-1)
297
 
298
  # filter class by confidence
299
  cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
300
  valid_mask = cls_val > nms_cfg.min_confidence
301
+ valid_cls = cls_idx[valid_mask].float()
302
  valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
303
 
304
  batch_idx, *_ = torch.where(valid_mask)