glenn-jocher commited on
Commit
948bcdd
·
1 Parent(s): 22d6088

--classes bug fix #17

Browse files
Files changed (1) hide show
  1. utils/utils.py +2 -1
utils/utils.py CHANGED
@@ -460,6 +460,7 @@ def build_targets(p, targets, model):
460
 
461
  return tcls, tbox, indices, anch
462
 
 
463
  def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, classes=None, agnostic=False):
464
  """
465
  Performs Non-Maximum Suppression on inference results
@@ -508,7 +509,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
508
 
509
  # Filter by class
510
  if classes:
511
- x = x[(j.view(-1, 1) == torch.tensor(classes, device=j.device)).any(1)]
512
 
513
  # Apply finite constraint
514
  # if not torch.isfinite(x).all():
 
460
 
461
  return tcls, tbox, indices, anch
462
 
463
+
464
  def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, classes=None, agnostic=False):
465
  """
466
  Performs Non-Maximum Suppression on inference results
 
509
 
510
  # Filter by class
511
  if classes:
512
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
513
 
514
  # Apply finite constraint
515
  # if not torch.isfinite(x).all():