henry000 committed
Commit c653e7e · 1 Parent(s): 6eb7d56

⚗️ [Add] Fast Inference {ONNX, TensorRT}, unstable!
.gitignore CHANGED
@@ -118,6 +118,8 @@ runs
 # Datasets and model checkpoints
 *.pth
 *.pt
+*.trt
+*.onnx
 
 # Image files
 *.png
yolo/config/config.py CHANGED
@@ -115,6 +115,7 @@ class InferenceConfig:
     task: str
     source: Union[str, int]
     nms: NMSConfig
+    fast_inference: Optional[str]  # "onnx", "trt", or None for the original model
 
 
 @dataclass
yolo/config/task/inference.yaml CHANGED
@@ -1,5 +1,6 @@
 task: inference
 source: demo/images/inference/image.png
+fast_inference: # "onnx", "trt", or leave empty for the original model
 data:
   batch_size: 16
   shuffle: False
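Leaving the value blank is intentional: YAML maps an empty value to None, which keeps the original PyTorch model. A minimal sketch of selecting a backend programmatically, assuming the project's OmegaConf/Hydra config stack (the equivalent Hydra CLI override would presumably be task.fast_inference=onnx):

    from omegaconf import OmegaConf

    # Load the task config and pick a backend; leave the field unset (None)
    # to run the plain PyTorch model instead.
    task_cfg = OmegaConf.load("yolo/config/task/inference.yaml")
    task_cfg.fast_inference = "onnx"  # or "trt"
    print(OmegaConf.to_yaml(task_cfg))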
yolo/lazy.py CHANGED
@@ -11,6 +11,7 @@ from yolo.config.config import Config
 from yolo.model.yolo import create_model
 from yolo.tools.data_loader import create_dataloader
 from yolo.tools.solver import ModelTester, ModelTrainer
+from yolo.utils.deploy_utils import FastModelLoader
 from yolo.utils.logging_utils import custom_logger, validate_log_directory
 
 
@@ -20,7 +21,11 @@ def main(cfg: Config):
     save_path = validate_log_directory(cfg, cfg.name)
     dataloader = create_dataloader(cfg)
     device = torch.device(cfg.device)
-    model = create_model(cfg).to(device)
+    if cfg.task.fast_inference:
+        model = FastModelLoader(cfg).load_model()
+        device = torch.device(cfg.device)  # re-read: the ONNX loader may fall back to CPU
+    else:
+        model = create_model(cfg).to(device)
 
     if cfg.task.task == "train":
         trainer = ModelTrainer(cfg, model, save_path, device)
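The second device assignment is deliberate: the ONNX path rewrites cfg.device to "cpu" inside load_model() (see _load_onnx_model below), so the device must be re-read after loading. A small sketch of how the empty YAML value drives this branch, assuming OmegaConf semantics:

    from omegaconf import OmegaConf

    task = OmegaConf.create("fast_inference:")  # empty value, as in inference.yaml
    assert task.fast_inference is None          # falsy -> create_model() branch

    task.fast_inference = "trt"                 # set in YAML or via a CLI override
    assert bool(task.fast_inference)            # truthy -> FastModelLoader branch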
yolo/utils/deploy_utils.py ADDED
@@ -0,0 +1,89 @@
+import torch
+from loguru import logger
+from torch import Tensor
+
+from yolo.config.config import Config
+from yolo.model.yolo import create_model
+
+
+class FastModelLoader:
+    def __init__(self, cfg: Config):
+        self.cfg = cfg
+        self.compiler = self.cfg.task.fast_inference
+        if self.compiler not in ["onnx", "trt"]:
+            logger.warning(f"⚠️ {self.compiler} is not supported (is it misspelled?); using the original model")
+            self.compiler = None
+        if self.cfg.device == "mps" and self.compiler == "trt":
+            logger.warning("🍎 TensorRT does not support MPS devices; using the original model")
+            self.compiler = None
+        if self.compiler:  # guard: "<stem>." + None would raise a TypeError
+            self.weight = cfg.weight.split(".")[0] + "." + self.compiler
+
+    def load_model(self):
+        if self.compiler == "onnx":
+            logger.info("🚀 Trying to use ONNX")
+            return self._load_onnx_model()
+        elif self.compiler == "trt":
+            logger.info("🚀 Trying to use TensorRT")
+            return self._load_trt_model()
+        else:
+            return create_model(self.cfg)
+
+    def _load_onnx_model(self):
+        from onnxruntime import InferenceSession
+
+        def onnx_forward(self: InferenceSession, x: Tensor):
+            x = {self.get_inputs()[0].name: x.cpu().numpy()}
+            x = [torch.from_numpy(y) for y in self.run(None, x)]
+            return [x]
+
+        # Make the session callable like an nn.Module
+        InferenceSession.__call__ = onnx_forward
+
+        try:
+            ort_session = InferenceSession(self.weight, providers=["CPUExecutionProvider"])
+        except Exception as e:
+            logger.warning(f"🈳 Error loading ONNX model: {e}")
+            ort_session = self._create_onnx_weight()
+        # TODO: use a GPU execution provider when available instead of forcing CPU
+        self.cfg.device = "cpu"
+        return ort_session
+
+    def _create_onnx_weight(self):
+        from onnxruntime import InferenceSession
+        from torch.onnx import export
+
+        model = create_model(self.cfg).eval().cuda()
+        dummy_input = torch.ones((1, 3, *self.cfg.image_size)).cuda()
+        export(
+            model,
+            dummy_input,
+            self.weight,
+            input_names=["input"],
+            output_names=["output"],
+            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
+        )
+        logger.info(f"📥 ONNX model saved to {self.weight}")
+        return InferenceSession(self.weight, providers=["CPUExecutionProvider"])
+
+    def _load_trt_model(self):
+        from torch2trt import TRTModule
+
+        try:
+            model_trt = TRTModule()
+            model_trt.load_state_dict(torch.load(self.weight))
+        except FileNotFoundError:
+            logger.warning(f"🈳 Model weight not found at {self.weight}")
+            model_trt = self._create_trt_weight()
+        return model_trt
+
+    def _create_trt_weight(self):
+        from torch2trt import torch2trt
+
+        model = create_model(self.cfg).eval().cuda()
+        dummy_input = torch.ones((1, 3, *self.cfg.image_size)).cuda()
+        logger.info("♻️ Creating TensorRT model")
+        model_trt = torch2trt(model, [dummy_input])
+        torch.save(model_trt.state_dict(), self.weight)
+        logger.info(f"📥 TensorRT model saved to {self.weight}")
+        return model_trt
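For reference, a hedged usage sketch of the new loader. The config stub below is hypothetical (a real Config is built through Hydra), but the field names follow the diff above; all three possible return values are callable, the ONNX session via the monkey-patched InferenceSession.__call__:

    import torch
    from omegaconf import OmegaConf

    from yolo.utils.deploy_utils import FastModelLoader

    # Hypothetical minimal config stub; field names follow the diff above.
    cfg = OmegaConf.create(
        {
            "weight": "weights/v9-c.pt",  # loader derives weights/v9-c.onnx or .trt
            "device": "cuda",
            "image_size": [640, 640],
            "task": {"fast_inference": "onnx"},
        }
    )

    model = FastModelLoader(cfg).load_model()   # InferenceSession, TRTModule, or plain model
    dummy = torch.ones((1, 3, *cfg.image_size))
    outputs = model(dummy)
    print(cfg.device)  # may now be "cpu": the ONNX path forces a CPU fallback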