import os

import torch
from loguru import logger
from torch import Tensor

from yolo.config.config import Config
from yolo.model.yolo import create_model


class FastModelLoader:
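    """Wrap model creation with optional accelerated inference backends.

    Depending on ``cfg.task.fast_inference`` this loads (or builds and caches)
    an ONNX Runtime session or a TensorRT engine, and otherwise falls back to
    the plain PyTorch model. Typical use: ``FastModelLoader(cfg).load_model(device)``.
    """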
    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.compiler = cfg.task.fast_inference
        self._validate_compiler()
        if cfg.weight is True:  # a bare `True` means "use the default checkpoint path"
            cfg.weight = os.path.join("weights", f"{cfg.model.name}.pt")
        self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"

    def _validate_compiler(self):
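        """Disable the requested backend (set it to None) if it is unknown or incompatible with the device."""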
        if self.compiler not in ["onnx", "trt", "deploy"]:
            logger.warning(f"⚠️ Compiler '{self.compiler}' is not supported. Using original model.")
            self.compiler = None
        if self.cfg.device == "mps" and self.compiler == "trt":
            logger.warning("🍎 TensorRT does not support MPS devices. Using original model.")
            self.compiler = None

    def load_model(self, device):
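        """Return an inference-ready model for the configured backend."""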
        if self.compiler == "onnx":
            return self._load_onnx_model()
        elif self.compiler == "trt":
            return self._load_trt_model().to(device)
        elif self.compiler == "deploy":
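            # "deploy" keeps the PyTorch model but strips the auxiliary training branch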
            self.cfg.model.model.auxiliary = {}
        return create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).to(device)

    def _load_onnx_model(self):
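        """Load the cached ONNX model, exporting a new one if loading fails."""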
        from onnxruntime import InferenceSession

        def onnx_forward(self: InferenceSession, x: Tensor):
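            # ONNX Runtime expects a dict of numpy arrays keyed by input name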
            x = {self.get_inputs()[0].name: x.cpu().numpy()}
            model_outputs, layer_output = [], []
            for idx, predict in enumerate(self.run(None, x)):
                layer_output.append(torch.from_numpy(predict))
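                # the session returns a flat list; regroup every three tensors into one head output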
                if idx % 3 == 2:
                    model_outputs.append(layer_output)
                    layer_output = []
            return {"Main": model_outputs}

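        # Monkey-patch InferenceSession so it can be called like a torch model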
        InferenceSession.__call__ = onnx_forward
        try:
            ort_session = InferenceSession(self.model_path)
            logger.info("πŸš€ Using ONNX as MODEL frameworks!")
        except Exception as e:
            logger.warning(f"🈳 Error loading ONNX model: {e}")
            ort_session = self._create_onnx_model()
        # TODO: fall back to CPU only when no GPU execution provider is available
        self.cfg.device = "cpu"
        return ort_session

    def _create_onnx_model(self):
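        """Export the PyTorch model to ONNX and return an InferenceSession over the new file."""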
        from onnxruntime import InferenceSession
        from torch.onnx import export

        model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
        dummy_input = torch.ones((1, 3, *self.cfg.image_size))
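        # mark the batch axis as dynamic so the exported graph accepts any batch size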
        export(
            model,
            dummy_input,
            self.model_path,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )
        logger.info(f"πŸ“₯ ONNX model saved to {self.model_path}")
        return InferenceSession(self.model_path)

    def _load_trt_model(self):
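        """Load the cached TensorRT engine, building one if the file is missing."""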
        from torch2trt import TRTModule

        try:
            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(self.model_path))
            logger.info("πŸš€ Using TensorRT as MODEL frameworks!")
        except FileNotFoundError:
            logger.warning(f"🈳 No found model weight at {self.model_path}")
            model_trt = self._create_trt_model()
        return model_trt

    def _create_trt_model(self):
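        """Convert the PyTorch model to a TensorRT engine (requires CUDA) and cache it to disk."""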
        from torch2trt import torch2trt

        model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
        dummy_input = torch.ones((1, 3, *self.cfg.image_size)).cuda()
        logger.info(f"♻️ Creating TensorRT model")
        model_trt = torch2trt(model.cuda(), [dummy_input])
        torch.save(model_trt.state_dict(), self.model_path)
        logger.info(f"πŸ“₯ TensorRT model saved to {self.model_path}")
        return model_trt