rafale77
commited on
torch.ops.torchvision.nms (#860)
Browse filesDon't load the entire torchvision library just for nms when the function is already in the torch library.
- utils/general.py +1 -2
utils/general.py
CHANGED
@@ -17,7 +17,6 @@ import matplotlib.pyplot as plt
|
|
17 |
import numpy as np
|
18 |
import torch
|
19 |
import torch.nn as nn
|
20 |
-
import torchvision
|
21 |
import yaml
|
22 |
from scipy.cluster.vq import kmeans
|
23 |
from scipy.signal import butter, filtfilt
|
@@ -651,7 +650,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
|
|
651 |
# Batched NMS
|
652 |
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
653 |
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
654 |
-
i =
|
655 |
if i.shape[0] > max_det: # limit detections
|
656 |
i = i[:max_det]
|
657 |
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|
|
|
17 |
import numpy as np
|
18 |
import torch
|
19 |
import torch.nn as nn
|
|
|
20 |
import yaml
|
21 |
from scipy.cluster.vq import kmeans
|
22 |
from scipy.signal import butter, filtfilt
|
|
|
650 |
# Batched NMS
|
651 |
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
652 |
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
653 |
+
i = torch.ops.torchvision.nms(boxes, scores, iou_thres)
|
654 |
if i.shape[0] > max_det: # limit detections
|
655 |
i = i[:max_det]
|
656 |
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|