import os

import torch
from loguru import logger
from torch import Tensor

from yolo.config.config import Config
from yolo.model.yolo import create_model


class FastModelLoader:
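    """Loads the YOLO model through a fast inference backend (ONNX Runtime or TensorRT) when one is requested."""
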
    def __init__(self, cfg: Config, device):
        self.cfg = cfg
        self.device = device
        self.compiler = cfg.task.fast_inference
        self._validate_compiler()
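        # Derive the compiled-model path from the weight path, e.g. "weights/v9-c.pt" -> "weights/v9-c.onnx";
        # only meaningful when a supported compiler is selected.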
        self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"

    def _validate_compiler(self):
        if self.compiler not in ["onnx", "trt", "deploy"]:
            logger.warning(f"⚠️ Compiler '{self.compiler}' is not supported. Using original model.")
            self.compiler = None
        if self.cfg.device == "mps" and self.compiler == "trt":
            logger.warning("🍎 TensorRT does not support MPS devices. Using original model.")
            self.compiler = None

    def load_model(self):
        if self.compiler == "onnx":
            return self._load_onnx_model()
        elif self.compiler == "trt":
            return self._load_trt_model()
        elif self.compiler == "deploy":
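            # "deploy" keeps the PyTorch model but strips the auxiliary (training-only) branch.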
            self.cfg.model.model.auxiliary = {}
        return create_model(
            self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight, device=self.device
        )

    def _load_onnx_model(self):
        from onnxruntime import InferenceSession

        def onnx_forward(self: InferenceSession, x: Tensor):
            x = {self.get_inputs()[0].name: x.cpu().numpy()}
            model_outputs, layer_output = [], []
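            # Each detection head emits three tensors; regroup the flat ONNX output
            # list into per-head triples to mirror the PyTorch model's output structure.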
            for idx, predict in enumerate(self.run(None, x)):
                layer_output.append(torch.from_numpy(predict))
                if idx % 3 == 2:
                    model_outputs.append(layer_output)
                    layer_output = []
            return {"Main": model_outputs}

        # Monkey-patch __call__ so the ONNX session can be invoked like a torch module.
        InferenceSession.__call__ = onnx_forward
        try:
            ort_session = InferenceSession(self.model_path)
            logger.info("πŸš€ Using ONNX as MODEL frameworks!")
        except Exception as e:
            logger.warning(f"🈳 Error loading ONNX model: {e}")
            ort_session = self._create_onnx_model()
        # TODO: Update if GPU onnx unavailable change to cpu
        self.cfg.device = "cpu"
        return ort_session

    def _create_onnx_model(self):
        from onnxruntime import InferenceSession
        from torch.onnx import export

        model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
        dummy_input = torch.ones((1, 3, *self.cfg.image_size))
        export(
            model,
            dummy_input,
            self.model_path,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )
        logger.info(f"πŸ“₯ ONNX model saved to {self.model_path}")
        return InferenceSession(self.model_path)

    def _load_trt_model(self):
        from torch2trt import TRTModule

        try:
            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(self.model_path))
            logger.info("πŸš€ Using TensorRT as MODEL frameworks!")
        except FileNotFoundError:
            logger.warning(f"🈳 No found model weight at {self.model_path}")
            model_trt = self._create_trt_model()
        return model_trt

    def _create_trt_model(self):
        from torch2trt import torch2trt

        model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
        dummy_input = torch.ones((1, 3, *self.cfg.image_size))
        logger.info("♻️ Creating TensorRT model")
        # torch2trt traces the module on the GPU, so both model and example input must be on CUDA.
        model_trt = torch2trt(model.cuda(), [dummy_input.cuda()])
        torch.save(model_trt.state_dict(), self.model_path)
        logger.info(f"πŸ“₯ TensorRT model saved to {self.model_path}")
        return model_trt
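

# Usage sketch (illustrative only, not part of the original module): assumes `cfg` is a
# populated `Config` whose `task.fast_inference` field names the backend ("onnx", "trt",
# or "deploy"); any other value falls back to the plain PyTorch model.
#
#   loader = FastModelLoader(cfg, device=torch.device(cfg.device))
#   model = loader.load_model()
#   # On the ONNX path, `model` is an InferenceSession whose patched __call__ returns
#   # the same {"Main": [...]} structure as the PyTorch model.
#   predictions = model(torch.ones((1, 3, *cfg.image_size)))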