File size: 2,203 Bytes
2b7bb94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import timm
import numpy as np
import onnx
import onnxruntime
from PIL import Image

def export_onnx(model, weight_path, export_path, input_size:list, device="cpu"):
    """
    Save model with weights as onnx file
    """
    torch.set_default_device(device)
    
    weights = torch.load(f=weight_path)
    model.load_state_dict(weights)
    model.eval()
    
    example_input = torch.empty(input_size)
    
    # 모델 변환
    torch.onnx.export(model,
                      example_input,
                      export_path,
                      export_params=True,
                      do_constant_folding=True,
                      input_names = ['input'],
                      output_names = ['output'],
                      dynamic_axes={'input' : {0 : 'batch_size'},
                                    'output' : {0 : 'batch_size'}})
    print("[info] The model has succesfull exported.")
    print(f"[info] File Path : {export_path}")

class Onnx_deploy_model():
    def __init__(self, model_path, img_size):
        onnx_model = onnx.load(model_path)
        onnx.checker.check_model(onnx_model)
        
        self.ort_session = onnxruntime.InferenceSession(model_path)
        self.transform = build_transform(img_size)
        
    def run(self, x, return_prob=True):
        """
        input : Image(PIL or Numpy)
        output : prob or logits
        """
        # img = Image.open(x).convert("RGB")
        x = self.transform(x).unsqueeze(dim=0)
        ort_inputs = {self.ort_session.get_inputs()[0].name: to_numpy(x)}
        ort_outputs = self.ort_session.run(None, ort_inputs)
        
        if return_prob:
            ort_outputs = softmax(ort_outputs)
        
        return ort_outputs.squeeze()

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

def softmax(x):
    max_num = np.max(x)   
    exp_a = np.exp(x - max_num) # to prevent OverFlow
    sum_exp_a = np.sum(exp_a)
    y = exp_a / sum_exp_a

    return y

def build_transform(input_size,interpolation="bicubic"):
    return timm.data.create_transform(input_size=input_size, interpolation=interpolation, is_training=False)