TensorRT 7 export fix (#6235)
Browse files
export.py
CHANGED
@@ -174,7 +174,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
|
|
174 |
check_requirements(('tensorrt',))
|
175 |
import tensorrt as trt
|
176 |
|
177 |
-
if trt.__version__[0] == 7: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
|
178 |
grid = model.model[-1].anchor_grid
|
179 |
model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
|
180 |
export_onnx(model, im, file, 12, train, False, simplify) # opset 12
|
|
|
174 |
check_requirements(('tensorrt',))
|
175 |
import tensorrt as trt
|
176 |
|
177 |
+
if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
|
178 |
grid = model.model[-1].anchor_grid
|
179 |
model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
|
180 |
export_onnx(model, im, file, 12, train, False, simplify) # opset 12
|