from pathlib import Path

import torch
from torch import Tensor

from yolo.config.config import Config
from yolo.model.yolo import create_model
from yolo.utils.logger import logger


class FastModelLoader:
    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.compiler = cfg.task.fast_inference
        self.class_num = cfg.dataset.class_num
        self._validate_compiler()
        # `weight: True` in the config means "use the default weights path".
        if cfg.weight is True:
            cfg.weight = Path("weights") / f"{cfg.model.name}.pt"
        # Compiled artifact is cached in the working directory as "<weight stem>.<compiler>".
        self.model_path = f"{Path(cfg.weight).stem}.{self.compiler}"

    def _validate_compiler(self):
        if self.compiler not in ["onnx", "trt", "deploy"]:
            logger.warning(f":warning: Compiler '{self.compiler}' is not supported. Using original model.")
            self.compiler = None
        if self.cfg.device == "mps" and self.compiler == "trt":
            logger.warning(":red_apple: TensorRT does not support MPS devices. Using original model.")
            self.compiler = None

    def load_model(self, device):
        if self.compiler == "onnx":
            return self._load_onnx_model(device)
        elif self.compiler == "trt":
            return self._load_trt_model().to(device)
        elif self.compiler == "deploy":
            # Drop the auxiliary branch before rebuilding the model for deployment.
            self.cfg.model.model.auxiliary = {}
        # Fall back to the original PyTorch model (also used for "deploy" and when no compiler is set).
        return create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).to(device)

    def _load_onnx_model(self, device):
        from onnxruntime import InferenceSession

        def onnx_forward(self: InferenceSession, x: Tensor):
            x = {self.get_inputs()[0].name: x.cpu().numpy()}
            model_outputs, layer_output = [], []
            for idx, predict in enumerate(self.run(None, x)):
                layer_output.append(torch.from_numpy(predict).to(device))
                # The exported graph emits three tensors per detection head; regroup them.
                if idx % 3 == 2:
                    model_outputs.append(layer_output)
                    layer_output = []
            # If auxiliary heads were exported as well, keep only the main branch.
            if len(model_outputs) == 6:
                model_outputs = model_outputs[:3]
            return {"Main": model_outputs}

        # Patch InferenceSession so the session can be called like a torch module.
        InferenceSession.__call__ = onnx_forward

        if device == "cpu":
            providers = ["CPUExecutionProvider"]
        else:
            providers = ["CUDAExecutionProvider"]
        try:
            ort_session = InferenceSession(self.model_path, providers=providers)
            logger.info(":rocket: Using ONNX as MODEL framework!")
        except Exception as e:
            logger.warning(f"🈳 Error loading ONNX model: {e}")
            ort_session = self._create_onnx_model(providers)
        return ort_session

    def _create_onnx_model(self, providers):
        from onnxruntime import InferenceSession
        from torch.onnx import export

        model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
        dummy_input = torch.ones((1, 3, *self.cfg.image_size))
        export(
            model,
            dummy_input,
            self.model_path,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )
        logger.info(f":inbox_tray: ONNX model saved to {self.model_path}")
        return InferenceSession(self.model_path, providers=providers)

    def _load_trt_model(self):
        from torch2trt import TRTModule

        try:
            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(self.model_path))
            logger.info(":rocket: Using TensorRT as MODEL framework!")
        except FileNotFoundError:
            logger.warning(f"🈳 No model weight found at {self.model_path}")
            model_trt = self._create_trt_model()
        return model_trt

    def _create_trt_model(self):
        from torch2trt import torch2trt

        model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
        dummy_input = torch.ones((1, 3, *self.cfg.image_size)).cuda()
        logger.info("♻️ Creating TensorRT model")
        model_trt = torch2trt(model.cuda(), [dummy_input])
        torch.save(model_trt.state_dict(), self.model_path)
        logger.info(f":inbox_tray: TensorRT model saved to {self.model_path}")
        return model_trt
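

# Usage sketch (illustrative addition, not part of the original module): shows how
# the loader is typically driven once a populated Config is available. The `cfg`
# object is assumed to carry the fields this class already reads above
# (task.fast_inference, dataset.class_num, device, image_size, weight, model);
# building such a config (e.g. via Hydra) is outside the scope of this file, and
# the helper name `_example_usage` is hypothetical.
def _example_usage(cfg: Config):
    loader = FastModelLoader(cfg)
    model = loader.load_model(device=cfg.device)
    # Each backend is wrapped or patched so it can be called like the original
    # torch module on a (batch, 3, H, W) tensor.
    dummy = torch.zeros((1, 3, *cfg.image_size), device=cfg.device)
    with torch.no_grad():
        return model(dummy)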