henry000 committed
Commit 7a45a10 · Parent: 46c9d1c

🐛 [Fix] onnx deploying bug
yolo/utils/bounding_box_utils.py CHANGED

@@ -275,10 +275,6 @@ class Vec2Box:
             logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
             self.strides = self.create_auto_anchor(model, image_size)
 
-        # TODO: this is a exception of onnx, remove it when onnx device if fixed
-        if not isinstance(model, YOLO):
-            device = torch.device("cpu")
-
         anchor_grid, scaler = generate_anchors(image_size, self.strides)
         self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
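The block removed above pinned the anchors to the CPU whenever the model was not a native YOLO module, since ONNX Runtime returns its outputs as NumPy arrays in host memory. The loader change in deploy_utils.py below makes that workaround unnecessary by moving each output onto the requested device, so the anchor grid can stay on device. A minimal sketch of that pattern (the output shape is a hypothetical stand-in for one prediction head):

import numpy as np
import torch

# Target device for the whole pipeline.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ONNX Runtime always hands back NumPy arrays on the host;
# this shape is a hypothetical stand-in for one prediction head.
predict = np.zeros((1, 64, 80, 80), dtype=np.float32)

# torch.from_numpy() produces a CPU tensor; .to(device) moves it to the
# requested device, which is what the fixed onnx_forward does per output.
layer_output = torch.from_numpy(predict).to(device)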
 
yolo/utils/deploy_utils.py CHANGED

@@ -27,38 +27,41 @@ class FastModelLoader:
 
     def load_model(self, device):
         if self.compiler == "onnx":
-            return self._load_onnx_model()
+            return self._load_onnx_model(device)
         elif self.compiler == "trt":
             return self._load_trt_model().to(device)
         elif self.compiler == "deploy":
             self.cfg.model.model.auxiliary = {}
             return create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).to(device)
 
-    def _load_onnx_model(self):
+    def _load_onnx_model(self, device):
         from onnxruntime import InferenceSession
 
         def onnx_forward(self: InferenceSession, x: Tensor):
             x = {self.get_inputs()[0].name: x.cpu().numpy()}
             model_outputs, layer_output = [], []
             for idx, predict in enumerate(self.run(None, x)):
-                layer_output.append(torch.from_numpy(predict))
+                layer_output.append(torch.from_numpy(predict).to(device))
                 if idx % 3 == 2:
                     model_outputs.append(layer_output)
                     layer_output = []
             return {"Main": model_outputs}
 
         InferenceSession.__call__ = onnx_forward
+
+        if device == "cpu":
+            providers = ["CPUExecutionProvider"]
+        else:
+            providers = ["CUDAExecutionProvider"]
         try:
-            ort_session = InferenceSession(self.model_path)
+            ort_session = InferenceSession(self.model_path, providers=providers)
             logger.info("🚀 Using ONNX as MODEL frameworks!")
         except Exception as e:
             logger.warning(f"🈳 Error loading ONNX model: {e}")
-            ort_session = self._create_onnx_model()
-            # TODO: Update if GPU onnx unavailable change to cpu
-            self.cfg.device = "cpu"
+            ort_session = self._create_onnx_model(providers)
         return ort_session
 
-    def _create_onnx_model(self):
+    def _create_onnx_model(self, providers):
         from onnxruntime import InferenceSession
         from torch.onnx import export
 
@@ -73,7 +76,7 @@ class FastModelLoader:
             dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
         )
         logger.info(f"📥 ONNX model saved to {self.model_path}")
-        return InferenceSession(self.model_path)
+        return InferenceSession(self.model_path, providers=providers)
 
     def _load_trt_model(self):
         from torch2trt import TRTModule
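For reference, a minimal sketch of the provider selection the patched loader now performs, assuming onnxruntime is installed; the model path and input shape are hypothetical stand-ins for self.model_path and a real batch:

import numpy as np
from onnxruntime import InferenceSession

device = "cpu"  # any other value selects the CUDA provider, as in the loader
providers = ["CPUExecutionProvider"] if device == "cpu" else ["CUDAExecutionProvider"]

# Hypothetical model path; the real loader uses self.model_path.
session = InferenceSession("v9-c.onnx", providers=providers)

# Run a dummy batch through the session, keyed by the model's input name.
dummy = np.zeros((1, 3, 640, 640), dtype=np.float32)
outputs = session.run(None, {session.get_inputs()[0].name: dummy})

Calling onnxruntime.get_available_providers() is a quick way to confirm that CUDAExecutionProvider is actually present before requesting it.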