🚑️ [Fix] #22 mentioned error, make nms run on cpu
Browse files
yolo/utils/bounding_box_utils.py
CHANGED
@@ -292,12 +292,13 @@ class BoxMatcher:
|
|
292 |
|
293 |
|
294 |
def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
|
|
|
295 |
cls_dist, bbox = predicts.split([80, 4], dim=-1)
|
296 |
|
297 |
# filter class by confidence
|
298 |
cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
|
299 |
valid_mask = cls_val > nms_cfg.min_confidence
|
300 |
-
valid_cls = cls_idx[valid_mask]
|
301 |
valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
|
302 |
|
303 |
batch_idx, *_ = torch.where(valid_mask)
|
|
|
292 |
|
293 |
|
294 |
def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
|
295 |
+
# TODO change function to class or set 80 to class_num instead of a number
|
296 |
cls_dist, bbox = predicts.split([80, 4], dim=-1)
|
297 |
|
298 |
# filter class by confidence
|
299 |
cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
|
300 |
valid_mask = cls_val > nms_cfg.min_confidence
|
301 |
+
valid_cls = cls_idx[valid_mask].float()
|
302 |
valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
|
303 |
|
304 |
batch_idx, *_ = torch.where(valid_mask)
|