π [Update] NMS, enable multiclass result
Browse files
yolo/utils/bounding_box_utils.py
CHANGED
@@ -411,15 +411,11 @@ def create_converter(model_version: str = "v9-c", *args, **kwargs) -> Union[Anc2
|
|
411 |
def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Optional[Tensor] = None):
|
412 |
cls_dist = cls_dist.sigmoid() * (1 if confidence is None else confidence)
|
413 |
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
|
420 |
-
|
421 |
-
batch_idx, *_ = torch.where(valid_mask)
|
422 |
-
nms_idx = batched_nms(valid_box, valid_con, batch_idx, nms_cfg.min_iou)
|
423 |
predicts_nms = []
|
424 |
for idx in range(cls_dist.size(0)):
|
425 |
instance_idx = nms_idx[idx == batch_idx[nms_idx]]
|
|
|
411 |
def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Optional[Tensor] = None):
|
412 |
cls_dist = cls_dist.sigmoid() * (1 if confidence is None else confidence)
|
413 |
|
414 |
+
batch_idx, valid_grid, valid_cls = torch.where(cls_dist > nms_cfg.min_confidence)
|
415 |
+
valid_con = cls_dist[batch_idx, valid_grid, valid_cls]
|
416 |
+
valid_box = bbox[batch_idx, valid_grid]
|
417 |
+
|
418 |
+
nms_idx = batched_nms(valid_box, valid_con, batch_idx + valid_cls * bbox.size(0), nms_cfg.min_iou)
|
|
|
|
|
|
|
|
|
419 |
predicts_nms = []
|
420 |
for idx in range(cls_dist.size(0)):
|
421 |
instance_idx = nms_idx[idx == batch_idx[nms_idx]]
|