# LuxExistentia's picture
# Upload 3 files
# 2b7bb94 verified
# raw
# history blame
# 2.2 kB
import torch
import timm
import numpy as np
import onnx
import onnxruntime
from PIL import Image
def export_onnx(model, weight_path, export_path, input_size:list, device="cpu"):
"""
Save model with weights as onnx file
"""
torch.set_default_device(device)
weights = torch.load(f=weight_path)
model.load_state_dict(weights)
model.eval()
example_input = torch.empty(input_size)
# 모델 변환
torch.onnx.export(model,
example_input,
export_path,
export_params=True,
do_constant_folding=True,
input_names = ['input'],
output_names = ['output'],
dynamic_axes={'input' : {0 : 'batch_size'},
'output' : {0 : 'batch_size'}})
print("[info] The model has succesfull exported.")
print(f"[info] File Path : {export_path}")
class Onnx_deploy_model():
    """
    Wrap an exported ONNX classifier for single-image inference.

    Validates the serialized graph, opens an onnxruntime session, and builds
    the eval-time preprocessing pipeline via the module-level ``build_transform``.
    """
    def __init__(self, model_path, img_size):
        # Check the graph is well-formed before paying for a runtime session.
        onnx_model = onnx.load(model_path)
        onnx.checker.check_model(onnx_model)
        self.ort_session = onnxruntime.InferenceSession(model_path)
        self.transform = build_transform(img_size)

    def run(self, x, return_prob=True):
        """
        input : Image(PIL or Numpy)
        output : prob or logits
        """
        # img = Image.open(x).convert("RGB")
        x = self.transform(x).unsqueeze(dim=0)
        ort_inputs = {self.ort_session.get_inputs()[0].name: to_numpy(x)}
        # InferenceSession.run returns a *list* of output arrays; take the
        # first (only) head. BUG FIX: previously .squeeze() was called on the
        # raw list when return_prob=False, raising AttributeError.
        logits = self.ort_session.run(None, ort_inputs)[0]
        ort_outputs = softmax(logits) if return_prob else logits
        return ort_outputs.squeeze()
def to_numpy(tensor):
    """Return *tensor* as a NumPy array, detaching from autograd when required."""
    if tensor.requires_grad:
        return tensor.detach().cpu().numpy()
    return tensor.cpu().numpy()
def softmax(x, axis=None):
    """
    Numerically stable softmax.

    Parameters
    ----------
    x : array_like
        Input logits.
    axis : int or None
        Axis along which to normalize. ``None`` (the default) reproduces the
        original behavior: softmax over *all* elements of ``x``.

    Returns
    -------
    numpy.ndarray with the same shape as ``x``, summing to 1 over ``axis``.
    """
    x = np.asarray(x)
    if axis is None:
        max_num = np.max(x)          # subtract the global max to prevent overflow
    else:
        max_num = np.max(x, axis=axis, keepdims=True)
    exp_a = np.exp(x - max_num)
    if axis is None:
        sum_exp_a = np.sum(exp_a)
    else:
        sum_exp_a = np.sum(exp_a, axis=axis, keepdims=True)
    return exp_a / sum_exp_a
def build_transform(input_size, interpolation="bicubic"):
    """Build timm's standard evaluation-time preprocessing pipeline for *input_size*."""
    return timm.data.create_transform(
        input_size=input_size,
        interpolation=interpolation,
        is_training=False,
    )