LuxExistentia committed
Commit 2b7bb94
1 Parent(s): aeaf5b8

Upload 3 files


First Version of Gender Classification demo

app.py ADDED
@@ -0,0 +1,36 @@
+ from custom_torch_module.deploy_utils import Onnx_deploy_model
+ import gradio as gr
+ import time
+ from PIL import Image
+
+ model_path = "deploying model/" + "vit_xsmall_patch16_clip_224(trainble_0.15) (Acc 98.44%, Loss 0.168152).onnx"
+ input_size = [1, 3, 224, 224]
+ img_size = input_size[-1]
+
+ title = "Gender Vision mini"
+ description = "A ViT (xsmall, CLIP) based model fine-tuned on a custom dataset (around 800 training images and 200 test images). Accuracy: around 98.4% on the custom test set. Optimized with ONNX (around 1.7x faster than the PyTorch version on CPU)."
+ article = "Built through a series of fine-tuning experiments. Remember: this model can be wrong."
+
+ def predict(img):
+     start_time = time.time()
+     output = onnx_model.run(img, return_prob=True)
+     end_time = time.time()
+     elapsed_time = end_time - start_time
+
+     pred_label_and_probs = {"Men": output[0], "Women": output[1]}
+
+     return pred_label_and_probs, elapsed_time
+
+ onnx_model = Onnx_deploy_model(model_path=model_path, img_size=img_size)
+
+ # Create the Gradio demo
+ demo = gr.Interface(fn=predict,
+                     inputs=gr.Image(type="pil"),
+                     outputs=[gr.Label(num_top_classes=2, label="Predictions"),
+                              gr.Number(label="Prediction time (s)")],
+                     title=title,
+                     description=description,
+                     article=article)
+
+ # Launch the demo
+ demo.launch()
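The app above feeds the PIL image from the Gradio widget straight into Onnx_deploy_model.run, so the same pipeline can be exercised without launching the UI. A minimal sketch, assuming the files from this commit are on the Python path and that "sample.jpg" is a hypothetical local test image:

from PIL import Image
from custom_torch_module.deploy_utils import Onnx_deploy_model

# Load the ONNX model added in this commit (path copied from app.py).
model_path = "deploying model/" + "vit_xsmall_patch16_clip_224(trainble_0.15) (Acc 98.44%, Loss 0.168152).onnx"
model = Onnx_deploy_model(model_path=model_path, img_size=224)

# "sample.jpg" is an assumed placeholder image, not part of this commit.
img = Image.open("sample.jpg").convert("RGB")
probs = model.run(img, return_prob=True)  # softmax probabilities, shape (2,)
print({"Men": float(probs[0]), "Women": float(probs[1])})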
custom_torch_module/deploy_utils.py ADDED
@@ -0,0 +1,71 @@
+ import torch
+ import timm
+ import numpy as np
+ import onnx
+ import onnxruntime
+ from PIL import Image
+
+ def export_onnx(model, weight_path, export_path, input_size: list, device="cpu"):
+     """
+     Save the model with its weights as an ONNX file.
+     """
+     torch.set_default_device(device)
+
+     weights = torch.load(f=weight_path)
+     model.load_state_dict(weights)
+     model.eval()
+
+     example_input = torch.empty(input_size)
+
+     # Convert the model to ONNX
+     torch.onnx.export(model,
+                       example_input,
+                       export_path,
+                       export_params=True,
+                       do_constant_folding=True,
+                       input_names=['input'],
+                       output_names=['output'],
+                       dynamic_axes={'input': {0: 'batch_size'},
+                                     'output': {0: 'batch_size'}})
+ print("[info] The model has succesfull exported.")
31
+ print(f"[info] File Path : {export_path}")
32
+
+ class Onnx_deploy_model():
+     def __init__(self, model_path, img_size):
+         onnx_model = onnx.load(model_path)
+         onnx.checker.check_model(onnx_model)
+
+         self.ort_session = onnxruntime.InferenceSession(model_path)
+         self.transform = build_transform(img_size)
+
+     def run(self, x, return_prob=True):
+         """
+         input  : image (PIL or NumPy)
+         output : probabilities (if return_prob=True) or raw logits
+         """
+         # img = Image.open(x).convert("RGB")
+         x = self.transform(x).unsqueeze(dim=0)
+         ort_inputs = {self.ort_session.get_inputs()[0].name: to_numpy(x)}
+         ort_outputs = self.ort_session.run(None, ort_inputs)
+         # onnxruntime returns a list of outputs; the first entry holds the class logits
+         logits = ort_outputs[0]
+
+         if return_prob:
+             logits = softmax(logits)
+
+         return logits.squeeze()
+
+ def to_numpy(tensor):
+     return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+ def softmax(x):
+     max_num = np.max(x)
+     exp_a = np.exp(x - max_num)  # subtract the max to prevent overflow
+     sum_exp_a = np.sum(exp_a)
+     y = exp_a / sum_exp_a
+
+     return y
+
+ def build_transform(input_size, interpolation="bicubic"):
+     return timm.data.create_transform(input_size=input_size, interpolation=interpolation, is_training=False)
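The export_onnx helper above is presumably how the .onnx file added below was produced. A rough sketch of such an export, assuming the classifier is timm's vit_xsmall_patch16_clip_224 with a 2-class head (the ONNX file name suggests this) and assumed checkpoint and output paths, could look like:

import timm
from custom_torch_module.deploy_utils import export_onnx

# Model name is taken from the ONNX file name in this commit; the checkpoint
# path and output path below are assumptions for illustration only.
model = timm.create_model("vit_xsmall_patch16_clip_224", num_classes=2)
export_onnx(model,
            weight_path="checkpoints/gender_vit_xsmall.pth",
            export_path="deploying model/gender_vit_xsmall.onnx",
            input_size=[1, 3, 224, 224])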
deploying model/vit_xsmall_patch16_clip_224(trainble_0.15) (Acc 98.44%, Loss 0.168152).onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c82878b43b74b203e07fe8506c5d7f977ce5559a51893187d1f4efe79f837675
+ size 32699557