import os

import torch
from loguru import logger
from torch import Tensor

from yolo.config.config import Config
from yolo.model.yolo import create_model


class FastModelLoader:
    """Load or build an accelerated model (ONNX Runtime or TensorRT), as selected by `cfg.task.fast_inference`."""

    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.compiler = cfg.task.fast_inference
        self._validate_compiler()
        # `cfg.weight` may be the boolean True, meaning "use the default weight for this model".
        if cfg.weight is True:
            cfg.weight = os.path.join("weights", f"{cfg.model.name}.pt")
        self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"

    def _validate_compiler(self):
        if self.compiler not in ["onnx", "trt", "deploy"]:
            logger.warning(f"⚠️ Compiler '{self.compiler}' is not supported. Using original model.")
            self.compiler = None
        if self.cfg.device == "mps" and self.compiler == "trt":
            logger.warning("🍎 TensorRT does not support MPS devices. Using original model.")
            self.compiler = None

    def load_model(self, device):
        if self.compiler == "onnx":
            return self._load_onnx_model()
        elif self.compiler == "trt":
            return self._load_trt_model().to(device)
        elif self.compiler == "deploy":
            # "deploy" strips the training-only auxiliary heads, then falls
            # through to the plain PyTorch model below.
            self.cfg.model.model.auxiliary = {}
        # Fallback (also reached when the compiler was rejected): the original model.
        return create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).to(device)

    def _load_onnx_model(self):
        from onnxruntime import InferenceSession

        def onnx_forward(self: InferenceSession, x: Tensor):
            x = {self.get_inputs()[0].name: x.cpu().numpy()}
            model_outputs, layer_output = [], []
            # The exported graph returns a flat list of arrays, three per
            # detection layer; regroup them into one list per layer.
            for idx, predict in enumerate(self.run(None, x)):
                layer_output.append(torch.from_numpy(predict))
                if idx % 3 == 2:
                    model_outputs.append(layer_output)
                    layer_output = []
            return {"Main": model_outputs}

        # Monkey-patch the session so it can be called like a torch module.
        InferenceSession.__call__ = onnx_forward

        try:
            ort_session = InferenceSession(self.model_path)
            logger.info("🚀 Using ONNX as MODEL framework!")
        except Exception as e:
            logger.warning(f"🈳 Error loading ONNX model: {e}")
            ort_session = self._create_onnx_model()

        # ONNX Runtime consumes NumPy arrays, so inputs are staged on the CPU.
        self.cfg.device = "cpu"
        return ort_session
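
    # Minimal usage sketch for the patched session (illustrative only; assumes
    # `loader = FastModelLoader(cfg)` and an image batch `x` of shape
    # (N, 3, *cfg.image_size)):
    #
    #   session = loader._load_onnx_model()
    #   outputs = session(x)  # {"Main": [...]}, one three-tensor list per layer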

    def _create_onnx_model(self):
        from onnxruntime import InferenceSession
        from torch.onnx import export

        model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
        dummy_input = torch.ones((1, 3, *self.cfg.image_size))
        export(
            model,
            dummy_input,
            self.model_path,
            input_names=["input"],
            output_names=["output"],
            # Keep the batch dimension dynamic so the session accepts any batch size.
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )
        logger.info(f"📥 ONNX model saved to {self.model_path}")
        return InferenceSession(self.model_path)
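
    # Sketch of exercising the dynamic batch axis through onnxruntime's
    # standard API (illustrative only; `np` is numpy, `loader` as above):
    #
    #   session = loader._create_onnx_model()
    #   preds = session.run(None, {"input": np.zeros((8, 3, *cfg.image_size), np.float32)})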

    def _load_trt_model(self):
        from torch2trt import TRTModule

        try:
            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(self.model_path))
            logger.info("🚀 Using TensorRT as MODEL framework!")
        except FileNotFoundError:
            logger.warning(f"🈳 No model weight found at {self.model_path}")
            model_trt = self._create_trt_model()
        return model_trt

    def _create_trt_model(self):
        from torch2trt import torch2trt

        model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
        # TensorRT conversion traces the model on GPU with a fixed input shape.
        dummy_input = torch.ones((1, 3, *self.cfg.image_size)).cuda()
        logger.info("♻️ Creating TensorRT model")
        model_trt = torch2trt(model.cuda(), [dummy_input])
        torch.save(model_trt.state_dict(), self.model_path)
        logger.info(f"📥 TensorRT model saved to {self.model_path}")
        return model_trt
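

# Typical usage (a sketch; assumes a fully populated Config with
# `task.fast_inference` set to "onnx", "trt", or "deploy"):
#
#   loader = FastModelLoader(cfg)
#   model = loader.load_model(device=cfg.device)
#   predictions = model(images)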