rafale77 commited on
Commit
9776e70
·
unverified ·
1 Parent(s): 0c01afc

torch.ops.torchvision.nms (#860)

Browse files

Don't load the entire torchvision library just for nms when the function is already in the torch library.

Files changed (1) hide show
  1. utils/general.py +1 -2
utils/general.py CHANGED
@@ -17,7 +17,6 @@ import matplotlib.pyplot as plt
17
  import numpy as np
18
  import torch
19
  import torch.nn as nn
20
- import torchvision
21
  import yaml
22
  from scipy.cluster.vq import kmeans
23
  from scipy.signal import butter, filtfilt
@@ -651,7 +650,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
651
  # Batched NMS
652
  c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
653
  boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
654
- i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
655
  if i.shape[0] > max_det: # limit detections
656
  i = i[:max_det]
657
  if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
 
17
  import numpy as np
18
  import torch
19
  import torch.nn as nn
 
20
  import yaml
21
  from scipy.cluster.vq import kmeans
22
  from scipy.signal import butter, filtfilt
 
650
  # Batched NMS
651
  c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
652
  boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
653
+ i = torch.ops.torchvision.nms(boxes, scores, iou_thres)
654
  if i.shape[0] > max_det: # limit detections
655
  i = i[:max_det]
656
  if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)