🚑️ [Fix] force CUDA usage when creating ONNX and TensorRT models
Browse files
yolo/utils/deploy_utils.py
CHANGED
@@ -51,8 +51,8 @@ class FastModelLoader:
|
|
51 |
from onnxruntime import InferenceSession
|
52 |
from torch.onnx import export
|
53 |
|
54 |
-
model = create_model(self.cfg).eval()
|
55 |
-
dummy_input = torch.ones((1, 3, *self.cfg.image_size))
|
56 |
export(
|
57 |
model,
|
58 |
dummy_input,
|
@@ -80,8 +80,8 @@ class FastModelLoader:
|
|
80 |
def _create_trt_weight(self):
|
81 |
from torch2trt import torch2trt
|
82 |
|
83 |
-
model = create_model(self.cfg).eval()
|
84 |
-
dummy_input = torch.ones((1, 3, *self.cfg.image_size))
|
85 |
logger.info(f"♻️ Creating TensorRT model")
|
86 |
model_trt = torch2trt(model, [dummy_input])
|
87 |
torch.save(model_trt.state_dict(), self.weight)
|
|
|
51 |
from onnxruntime import InferenceSession
|
52 |
from torch.onnx import export
|
53 |
|
54 |
+
model = create_model(self.cfg).eval()
|
55 |
+
dummy_input = torch.ones((1, 3, *self.cfg.image_size))
|
56 |
export(
|
57 |
model,
|
58 |
dummy_input,
|
|
|
80 |
def _create_trt_weight(self):
|
81 |
from torch2trt import torch2trt
|
82 |
|
83 |
+
model = create_model(self.cfg).eval()
|
84 |
+
dummy_input = torch.ones((1, 3, *self.cfg.image_size))
|
85 |
logger.info(f"♻️ Creating TensorRT model")
|
86 |
model_trt = torch2trt(model, [dummy_input])
|
87 |
torch.save(model_trt.state_dict(), self.weight)
|