akhaliq HF Staff commited on
Commit
4e227ea
·
1 Parent(s): 045f300

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnx
2
+ import numpy as np
3
+ import onnxruntime as ort
4
+ from PIL import Image
5
+ import cv2
6
+ import os
7
+ import gradio as gr
8
+
9
+ os.system("wget https://s3.amazonaws.com/onnx-model-zoo/synset.txt")
10
+
11
+
12
+ with open('synset.txt', 'r') as f:
13
+ labels = [l.rstrip() for l in f]
14
+
15
+ os.system("wget https://github.com/onnx/models/raw/main/vision/classification/mnist/model/mnist-8.onnx")
16
+
17
+ os.system("wget https://s3.amazonaws.com/model-server/inputs/kitten.jpg")
18
+
19
+
20
+
21
+ model_path = 'shufflenet-v2-10.onnx'
22
+ model = onnx.load(model_path)
23
+ session = ort.InferenceSession(model.SerializeToString())
24
+
25
+ def get_image(path):
26
+ with Image.open(path) as img:
27
+ img = np.array(img.convert('RGB'))
28
+ return img
29
+
30
+ def preprocess(img):
31
+ img = img / 255.
32
+ img = cv2.resize(img, (256, 256))
33
+ h, w = img.shape[0], img.shape[1]
34
+ y0 = (h - 224) // 2
35
+ x0 = (w - 224) // 2
36
+ img = img[y0 : y0+224, x0 : x0+224, :]
37
+ img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
38
+ img = np.transpose(img, axes=[2, 0, 1])
39
+ img = img.astype(np.float32)
40
+ img = np.expand_dims(img, axis=0)
41
+ return img
42
+
43
+ def predict(path):
44
+ img = get_image(path)
45
+ img = preprocess(img)
46
+ ort_inputs = {session.get_inputs()[0].name: img}
47
+ preds = session.run(None, ort_inputs)[0]
48
+ preds = np.squeeze(preds)
49
+ a = np.argsort(preds)
50
+ results = {}
51
+ for i in a[0:5]:
52
+ results[labels[a[i]]] = float(preds[a[i]])
53
+ return results
54
+
55
+
56
+ title="ShuffleNet-v2"
57
+ description="ShuffleNet is a deep convolutional network for image classification. ShuffleNetV2 is an improved architecture that is the state-of-the-art in terms of speed and accuracy tradeoff used for image classification."
58
+
59
+ examples=[['kitten.jpg']]
60
+ gr.Interface(predict,gr.inputs.Image(type='filepath'),"label",title=title,description=description,examples=examples).launch(enable_queue=True,debug=True)