Add `tensorrt>=7.0.0` checks (#6193)
Browse files* Add `tensorrt>=7.0.0` checks
* Update export.py
* Update common.py
* Update export.py
- export.py +6 -6
- models/common.py +1 -1
export.py
CHANGED
@@ -61,8 +61,8 @@ from models.experimental import attempt_load
|
|
61 |
from models.yolo import Detect
|
62 |
from utils.activations import SiLU
|
63 |
from utils.datasets import LoadImages
|
64 |
-
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements,
|
65 |
-
url2file)
|
66 |
from utils.torch_utils import select_device
|
67 |
|
68 |
|
@@ -174,14 +174,14 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
|
|
174 |
check_requirements(('tensorrt',))
|
175 |
import tensorrt as trt
|
176 |
|
177 |
-
|
178 |
-
if opset == 12: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
|
179 |
grid = model.model[-1].anchor_grid
|
180 |
model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
|
181 |
-
export_onnx(model, im, file,
|
182 |
model.model[-1].anchor_grid = grid
|
183 |
else: # TensorRT >= 8
|
184 |
-
|
|
|
185 |
onnx = file.with_suffix('.onnx')
|
186 |
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
|
187 |
|
|
|
61 |
from models.yolo import Detect
|
62 |
from utils.activations import SiLU
|
63 |
from utils.datasets import LoadImages
|
64 |
+
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, colorstr,
|
65 |
+
file_size, print_args, url2file)
|
66 |
from utils.torch_utils import select_device
|
67 |
|
68 |
|
|
|
174 |
check_requirements(('tensorrt',))
|
175 |
import tensorrt as trt
|
176 |
|
177 |
+
if trt.__version__[0] == 7: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
|
|
|
178 |
grid = model.model[-1].anchor_grid
|
179 |
model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
|
180 |
+
export_onnx(model, im, file, 12, train, False, simplify) # opset 12
|
181 |
model.model[-1].anchor_grid = grid
|
182 |
else: # TensorRT >= 8
|
183 |
+
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=8.0.0
|
184 |
+
export_onnx(model, im, file, 13, train, False, simplify) # opset 13
|
185 |
onnx = file.with_suffix('.onnx')
|
186 |
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
|
187 |
|
models/common.py
CHANGED
@@ -337,7 +337,7 @@ class DetectMultiBackend(nn.Module):
|
|
337 |
elif engine: # TensorRT
|
338 |
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
339 |
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
340 |
-
check_version(trt.__version__, '
|
341 |
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
342 |
logger = trt.Logger(trt.Logger.INFO)
|
343 |
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|
|
|
337 |
elif engine: # TensorRT
|
338 |
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
339 |
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
340 |
+
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
|
341 |
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
342 |
logger = trt.Logger(trt.Logger.INFO)
|
343 |
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|