🐛 [Fix] onnx deploying bug
yolo/utils/bounding_box_utils.py CHANGED
@@ -275,10 +275,6 @@ class Vec2Box:
         logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
         self.strides = self.create_auto_anchor(model, image_size)
 
-        # TODO: this is a exception of onnx, remove it when onnx device if fixed
-        if not isinstance(model, YOLO):
-            device = torch.device("cpu")
-
         anchor_grid, scaler = generate_anchors(image_size, self.strides)
         self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
 
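The branch removed above had pinned `device` to the CPU whenever the model was not a native `YOLO` instance, i.e. exactly the ONNX path; with it gone, the anchor tensors simply follow whatever device the caller passes in. A minimal sketch of that pattern, assuming a hypothetical `resolve_device` helper and stand-in tensors in place of the real `generate_anchors` output:

import torch

# Hypothetical helper (not part of this patch): honor the requested device,
# falling back to CPU when CUDA is unavailable, e.g. a CPU-only deployment.
def resolve_device(requested: str = "cuda") -> torch.device:
    if requested.startswith("cuda") and not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(requested)

device = resolve_device()
anchor_grid = torch.zeros(8400, 2).to(device)  # stand-in for generate_anchors output
scaler = torch.ones(8400).to(device)
assert anchor_grid.device == scaler.device  # both tensors track the resolved device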
yolo/utils/deploy_utils.py CHANGED
@@ -27,38 +27,41 @@ class FastModelLoader:
 
     def load_model(self, device):
         if self.compiler == "onnx":
-            return self._load_onnx_model()
+            return self._load_onnx_model(device)
         elif self.compiler == "trt":
             return self._load_trt_model().to(device)
         elif self.compiler == "deploy":
             self.cfg.model.model.auxiliary = {}
             return create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).to(device)
 
-    def _load_onnx_model(self):
+    def _load_onnx_model(self, device):
         from onnxruntime import InferenceSession
 
         def onnx_forward(self: InferenceSession, x: Tensor):
             x = {self.get_inputs()[0].name: x.cpu().numpy()}
             model_outputs, layer_output = [], []
             for idx, predict in enumerate(self.run(None, x)):
-                layer_output.append(torch.from_numpy(predict))
+                layer_output.append(torch.from_numpy(predict).to(device))
                 if idx % 3 == 2:
                     model_outputs.append(layer_output)
                     layer_output = []
             return {"Main": model_outputs}
 
         InferenceSession.__call__ = onnx_forward
+
+        if device == "cpu":
+            providers = ["CPUExecutionProvider"]
+        else:
+            providers = ["CUDAExecutionProvider"]
         try:
-            ort_session = InferenceSession(self.model_path)
+            ort_session = InferenceSession(self.model_path, providers=providers)
             logger.info("🚀 Using ONNX as MODEL frameworks!")
         except Exception as e:
             logger.warning(f"😳 Error loading ONNX model: {e}")
-            ort_session = self._create_onnx_model()
-            # TODO: Update if GPU onnx unavailable change to cpu
-            self.cfg.device = "cpu"
+            ort_session = self._create_onnx_model(providers)
         return ort_session
 
-    def _create_onnx_model(self):
+    def _create_onnx_model(self, providers):
         from onnxruntime import InferenceSession
         from torch.onnx import export
 
@@ -73,7 +76,7 @@ class FastModelLoader:
             dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
         )
         logger.info(f"💥 ONNX model saved to {self.model_path}")
-        return InferenceSession(self.model_path)
+        return InferenceSession(self.model_path, providers=providers)
 
     def _load_trt_model(self):
         from torch2trt import TRTModule
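The provider selection above assumes `CUDAExecutionProvider` is present whenever `device` is not "cpu"; on an onnxruntime build without CUDA support, session creation raises, and the fallback `_create_onnx_model(providers)` inherits the same provider list. A slightly more defensive sketch (the `make_session` helper and the model path are illustrative, not part of this patch) queries `get_available_providers()` first:

from onnxruntime import InferenceSession, get_available_providers

def make_session(model_path: str, device: str) -> InferenceSession:
    # Request CUDA only when this onnxruntime build actually ships it;
    # otherwise, or when device == "cpu", use the CPU provider alone.
    if device != "cpu" and "CUDAExecutionProvider" in get_available_providers():
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
    return InferenceSession(model_path, providers=providers)

# session = make_session("model.onnx", device="cuda")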