henry000 committed
Commit a7cf768 · 1 Parent(s): 8942278

♻️ [Refactor] the code in deploy model

Files changed (1):
  1. yolo/utils/deploy_utils.py +24 -23
yolo/utils/deploy_utils.py CHANGED
@@ -1,3 +1,5 @@
+import os
+
 import torch
 from loguru import logger
 from torch import Tensor
@@ -9,24 +11,24 @@ from yolo.model.yolo import create_model
 class FastModelLoader:
     def __init__(self, cfg: Config):
         self.cfg = cfg
-        self.compiler = self.cfg.task.fast_inference
+        self.compiler = cfg.task.fast_inference
+        self._validate_compiler()
+        self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
+
+    def _validate_compiler(self):
         if self.compiler not in ["onnx", "trt"]:
-            logger.warning(f"⚠️ {self.compiler} is not supported, if it is spelled wrong? Select origin model")
+            logger.warning(f"⚠️ Compiler '{self.compiler}' is not supported. Using original model.")
             self.compiler = None
         if self.cfg.device == "mps" and self.compiler == "trt":
-            logger.warning("🍎 TensorRT does not support MPS devices, select origin model")
+            logger.warning("🍎 TensorRT does not support MPS devices. Using original model.")
             self.compiler = None
-        self.weight = cfg.weight.split(".")[0] + "." + self.compiler
 
     def load_model(self):
         if self.compiler == "onnx":
-            logger.info("🚀 Try to use ONNX")
             return self._load_onnx_model()
         elif self.compiler == "trt":
-            logger.info("🚀 Try to use TensorRT")
             return self._load_trt_model()
-        else:
-            return create_model(self.cfg)
+        return create_model(self.cfg)
 
     def _load_onnx_model(self):
         from onnxruntime import InferenceSession
@@ -37,17 +39,17 @@ class FastModelLoader:
             return [x]
 
         InferenceSession.__call__ = onnx_forward
-
         try:
-            ort_session = InferenceSession(self.weight, providers=["CPUExecutionProvider"])
+            ort_session = InferenceSession(self.model_path)
+            logger.info("🚀 Using ONNX as MODEL frameworks!")
         except Exception as e:
             logger.warning(f"🈳 Error loading ONNX model: {e}")
-            ort_session = self._create_onnx_weight()
+            ort_session = self._create_onnx_model()
         # TODO: Update if GPU onnx unavailable change to cpu
         self.cfg.device = "cpu"
         return ort_session
 
-    def _create_onnx_weight(self):
+    def _create_onnx_model(self):
         from onnxruntime import InferenceSession
         from torch.onnx import export
 
@@ -56,34 +58,33 @@
         export(
             model,
             dummy_input,
-            self.weight,
+            self.model_path,
             input_names=["input"],
             output_names=["output"],
             dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
         )
-        logger.info(f"📥 ONNX model saved to {self.weight} ")
-        return InferenceSession(self.weight, providers=["CPUExecutionProvider"])
+        logger.info(f"📥 ONNX model saved to {self.model_path}")
+        return InferenceSession(self.model_path)
 
     def _load_trt_model(self):
         from torch2trt import TRTModule
 
-        model_trt = TRTModule()
-
         try:
             model_trt = TRTModule()
-            model_trt.load_state_dict(torch.load(self.weight))
+            model_trt.load_state_dict(torch.load(self.model_path))
+            logger.info("🚀 Using TensorRT as MODEL frameworks!")
         except FileNotFoundError:
-            logger.warning(f"🈳 No found model weight at {self.weight}")
-            model_trt = self._create_trt_weight()
+            logger.warning(f"🈳 No found model weight at {self.model_path}")
+            model_trt = self._create_trt_model()
         return model_trt
 
-    def _create_trt_weight(self):
+    def _create_trt_model(self):
         from torch2trt import torch2trt
 
         model = create_model(self.cfg).eval()
         dummy_input = torch.ones((1, 3, *self.cfg.image_size))
         logger.info(f"♻️ Creating TensorRT model")
         model_trt = torch2trt(model, [dummy_input])
-        torch.save(model_trt.state_dict(), self.weight)
-        logger.info(f"📥 TensorRT model saved to {self.weight}")
+        torch.save(model_trt.state_dict(), self.model_path)
+        logger.info(f"📥 TensorRT model saved to {self.model_path}")
         return model_trt
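
For context: swapping cfg.weight.split(".")[0] for os.path.splitext(cfg.weight)[0] changes behavior on paths containing extra dots (for example "./weights/v9-c.pt", where split(".")[0] yields an empty string), since splitext only strips the final extension. Below is a minimal usage sketch of the refactored loader under stated assumptions: the config stub and the weight filename are hypothetical stand-ins for the project's real Config object, and only the fields the class actually reads are filled in.

from types import SimpleNamespace

from yolo.utils.deploy_utils import FastModelLoader

# Hypothetical stand-in for the project's Config; only the fields the loader
# reads are stubbed (task.fast_inference, weight, device, image_size).
cfg = SimpleNamespace(
    task=SimpleNamespace(fast_inference="onnx"),  # "onnx", "trt", or anything else -> original model
    weight="weights/v9-c.pt",                     # hypothetical checkpoint path
    device="cpu",
    image_size=(640, 640),
)

loader = FastModelLoader(cfg)
print(loader.model_path)  # weights/v9-c.onnx -- splitext keeps the directory and drops only ".pt"

# Returns an onnxruntime InferenceSession, a torch2trt TRTModule, or the original
# torch model; if the compiled weight is missing it is exported first, which
# requires a complete project config rather than this stub.
model = loader.load_model()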