from pathlib import Path
import torch
from torch import Tensor
from yolo.config.config import Config
from yolo.model.yolo import create_model
from yolo.utils.logger import logger
class FastModelLoader:
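    """Build or load the model behind an optional fast-inference backend.

    Depending on ``cfg.task.fast_inference``, the model is served through ONNX
    Runtime, TensorRT (torch2trt), or a deploy-stripped PyTorch model; unsupported
    settings fall back to the original PyTorch model.
    """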
def __init__(self, cfg: Config):
self.cfg = cfg
self.compiler = cfg.task.fast_inference
self.class_num = cfg.dataset.class_num
self._validate_compiler()
        if cfg.weight is True:  # `weight: True` means "use the default checkpoint for this model"
cfg.weight = Path("weights") / f"{cfg.model.name}.pt"
self.model_path = f"{Path(cfg.weight).stem}.{self.compiler}"
def _validate_compiler(self):
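        """Reset ``self.compiler`` to None when the requested backend cannot be used."""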
if self.compiler not in ["onnx", "trt", "deploy"]:
logger.warning(f":warning: Compiler '{self.compiler}' is not supported. Using original model.")
self.compiler = None
if self.cfg.device == "mps" and self.compiler == "trt":
logger.warning(":red_apple: TensorRT does not support MPS devices. Using original model.")
self.compiler = None
def load_model(self, device):
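        """Return a model for the fast-inference backend selected in the config."""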
if self.compiler == "onnx":
return self._load_onnx_model(device)
elif self.compiler == "trt":
return self._load_trt_model().to(device)
elif self.compiler == "deploy":
self.cfg.model.model.auxiliary = {}
return create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).to(device)
def _load_onnx_model(self, device):
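        """Load an ONNX Runtime session for the model, exporting one if none exists yet."""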
from onnxruntime import InferenceSession
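        # Patched onto InferenceSession.__call__ below so the ORT session can be called
        # like the PyTorch model and return the same output structure.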
def onnx_forward(self: InferenceSession, x: Tensor):
x = {self.get_inputs()[0].name: x.cpu().numpy()}
model_outputs, layer_output = [], []
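            # The flat ONNX output list is regrouped three tensors at a time, one group per detection scale.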
for idx, predict in enumerate(self.run(None, x)):
layer_output.append(torch.from_numpy(predict).to(device))
if idx % 3 == 2:
model_outputs.append(layer_output)
layer_output = []
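            # Six groups likely means the auxiliary branch was exported too; keep only the main predictions.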
if len(model_outputs) == 6:
model_outputs = model_outputs[:3]
return {"Main": model_outputs}
InferenceSession.__call__ = onnx_forward
if device == "cpu":
providers = ["CPUExecutionProvider"]
else:
providers = ["CUDAExecutionProvider"]
try:
ort_session = InferenceSession(self.model_path, providers=providers)
logger.info(":rocket: Using ONNX as MODEL frameworks!")
except Exception as e:
logger.warning(f"🈳 Error loading ONNX model: {e}")
ort_session = self._create_onnx_model(providers)
return ort_session
def _create_onnx_model(self, providers):
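        """Export the PyTorch model to ONNX and return an InferenceSession over the saved file."""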
from onnxruntime import InferenceSession
from torch.onnx import export
model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
dummy_input = torch.ones((1, 3, *self.cfg.image_size))
export(
model,
dummy_input,
self.model_path,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)
logger.info(f":inbox_tray: ONNX model saved to {self.model_path}")
return InferenceSession(self.model_path, providers=providers)
def _load_trt_model(self):
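        """Load a cached TensorRT module, building one if no saved weights are found."""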
from torch2trt import TRTModule
try:
model_trt = TRTModule()
model_trt.load_state_dict(torch.load(self.model_path))
logger.info(":rocket: Using TensorRT as MODEL frameworks!")
except FileNotFoundError:
logger.warning(f"🈳 No found model weight at {self.model_path}")
model_trt = self._create_trt_model()
return model_trt
def _create_trt_model(self):
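        """Convert the PyTorch model with torch2trt and save its state dict to ``self.model_path``."""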
from torch2trt import torch2trt
model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
dummy_input = torch.ones((1, 3, *self.cfg.image_size)).cuda()
logger.info(f"♻️ Creating TensorRT model")
model_trt = torch2trt(model.cuda(), [dummy_input])
torch.save(model_trt.state_dict(), self.model_path)
logger.info(f":inbox_tray: TensorRT model saved to {self.model_path}")
return model_trt
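

# Example usage (hypothetical sketch): `cfg` is assumed to be a fully populated Config
# (task.fast_inference, dataset.class_num, weight, model, image_size set by the project's
# normal config loading); the calls below simply mirror the methods defined above.
#
#     loader = FastModelLoader(cfg)
#     model = loader.load_model(device="cuda")
#     predictions = model(images)  # images: Tensor shaped (batch, 3, H, W)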