henry000 commited on
Commit
f2d4184
Β·
1 Parent(s): e78c98b

πŸ› [Update] NMS, enable multiclass result

Browse files
Files changed (1) hide show
  1. yolo/utils/bounding_box_utils.py +5 -9
yolo/utils/bounding_box_utils.py CHANGED
@@ -411,15 +411,11 @@ def create_converter(model_version: str = "v9-c", *args, **kwargs) -> Union[Anc2
411
  def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Optional[Tensor] = None):
412
  cls_dist = cls_dist.sigmoid() * (1 if confidence is None else confidence)
413
 
414
- # filter class by confidence
415
- cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
416
- valid_mask = cls_val > nms_cfg.min_confidence
417
- valid_cls = cls_idx[valid_mask].float()
418
- valid_con = cls_val[valid_mask].float()
419
- valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
420
-
421
- batch_idx, *_ = torch.where(valid_mask)
422
- nms_idx = batched_nms(valid_box, valid_con, batch_idx, nms_cfg.min_iou)
423
  predicts_nms = []
424
  for idx in range(cls_dist.size(0)):
425
  instance_idx = nms_idx[idx == batch_idx[nms_idx]]
 
411
  def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Optional[Tensor] = None):
412
  cls_dist = cls_dist.sigmoid() * (1 if confidence is None else confidence)
413
 
414
+ batch_idx, valid_grid, valid_cls = torch.where(cls_dist > nms_cfg.min_confidence)
415
+ valid_con = cls_dist[batch_idx, valid_grid, valid_cls]
416
+ valid_box = bbox[batch_idx, valid_grid]
417
+
418
+ nms_idx = batched_nms(valid_box, valid_con, batch_idx + valid_cls * bbox.size(0), nms_cfg.min_iou)
 
 
 
 
419
  predicts_nms = []
420
  for idx in range(cls_dist.size(0)):
421
  instance_idx = nms_idx[idx == batch_idx[nms_idx]]