Adjust NMS time limit warning to batch size (#7156)
Browse files- utils/general.py +7 -4
utils/general.py
CHANGED
@@ -709,6 +709,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
|
|
709 |
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
|
710 |
"""
|
711 |
|
|
|
712 |
nc = prediction.shape[2] - 5 # number of classes
|
713 |
xc = prediction[..., 4] > conf_thres # candidates
|
714 |
|
@@ -719,13 +720,13 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
|
|
719 |
# Settings
|
720 |
min_wh, max_wh = 2, 7680 # (pixels) minimum and maximum box width and height
|
721 |
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
|
722 |
-
time_limit =
|
723 |
redundant = True # require redundant detections
|
724 |
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
725 |
merge = False # use merge-NMS
|
726 |
|
727 |
-
t = time.time()
|
728 |
-
output = [torch.zeros((0, 6), device=prediction.device)] *
|
729 |
for xi, x in enumerate(prediction): # image index, image inference
|
730 |
# Apply constraints
|
731 |
x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
|
@@ -789,7 +790,9 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
|
|
789 |
|
790 |
output[xi] = x[i]
|
791 |
if (time.time() - t) > time_limit:
|
792 |
-
|
|
|
|
|
793 |
break # time limit exceeded
|
794 |
|
795 |
return output
|
|
|
709 |
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
|
710 |
"""
|
711 |
|
712 |
+
bs = prediction.shape[0] # batch size
|
713 |
nc = prediction.shape[2] - 5 # number of classes
|
714 |
xc = prediction[..., 4] > conf_thres # candidates
|
715 |
|
|
|
720 |
# Settings
|
721 |
min_wh, max_wh = 2, 7680 # (pixels) minimum and maximum box width and height
|
722 |
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
|
723 |
+
time_limit = 0.030 * bs # seconds to quit after
|
724 |
redundant = True # require redundant detections
|
725 |
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
726 |
merge = False # use merge-NMS
|
727 |
|
728 |
+
t, warn_time = time.time(), True
|
729 |
+
output = [torch.zeros((0, 6), device=prediction.device)] * bs
|
730 |
for xi, x in enumerate(prediction): # image index, image inference
|
731 |
# Apply constraints
|
732 |
x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
|
|
|
790 |
|
791 |
output[xi] = x[i]
|
792 |
if (time.time() - t) > time_limit:
|
793 |
+
if warn_time:
|
794 |
+
LOGGER.warning(f'WARNING: NMS time limit {time_limit:3f}s exceeded')
|
795 |
+
warn_time = False
|
796 |
break # time limit exceeded
|
797 |
|
798 |
return output
|