File size: 1,961 Bytes
1c25c76
8243d5f
1c25c76
 
 
8243d5f
 
 
1c25c76
8e15e86
8243d5f
 
 
 
20e3a92
8243d5f
1c25c76
8243d5f
1c25c76
 
20e3a92
1c25c76
 
 
 
 
 
 
8243d5f
1c25c76
282c6b9
 
 
 
 
 
 
 
 
 
 
 
 
1c25c76
282c6b9
1c25c76
8243d5f
1c25c76
 
 
 
 
 
8243d5f
1c25c76
 
8243d5f
8e15e86
8243d5f
1c25c76
 
8243d5f
1c25c76
8e15e86
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import urllib.request

import cv2
import gradio as gr
import numpy as np
import onnx
import onnxruntime as ort
from PIL import Image

def _download(url, filename=None):
    """Download *url* into the current directory unless it is already there.

    Replaces shelling out to ``wget``. It raises on network/HTTP errors
    instead of silently continuing. It also skips files that already exist,
    whereas plain ``wget`` would save a duplicate ``name.1`` on every restart.

    Args:
        url: Location to fetch.
        filename: Local file name; defaults to the last path segment of *url*.

    Returns:
        The local filename.
    """
    if filename is None:
        filename = url.rsplit('/', 1)[-1]
    if not os.path.exists(filename):
        partial = filename + '.part'
        # Write to a temporary name first, so an interrupted download never
        # leaves a truncated file that later runs would treat as complete.
        urllib.request.urlretrieve(url, partial)
        os.replace(partial, filename)
    return filename


_download("https://s3.amazonaws.com/onnx-model-zoo/synset.txt")

# One class label per line; the line index is the model's output class index.
with open('synset.txt', 'r', encoding='utf-8') as f:
    labels = [line.rstrip() for line in f]

_download("https://github.com/AK391/models/raw/main/vision/classification/shufflenet/model/shufflenet-v2-10.onnx")

# Sample image used by the Gradio ``examples`` list.
_download("https://s3.amazonaws.com/model-server/inputs/kitten.jpg")

model_path = 'shufflenet-v2-10.onnx'
model = onnx.load(model_path)
# Since onnxruntime 1.9, builds with more than the CPU provider require
# ``providers`` to be passed explicitly; listing every available provider
# keeps CPU-only builds working as before.
session = ort.InferenceSession(
    model.SerializeToString(),
    providers=ort.get_available_providers(),
)

def get_image(path):
    """Open the image file at *path* and return its pixels as a numpy array.

    The image is converted to RGB via PIL before the conversion to an
    array. The PIL file handle is closed before the array is returned.
    """
    with Image.open(path) as pil_image:
        rgb_image = pil_image.convert('RGB')
        pixels = np.array(rgb_image)
    return pixels
    
def preprocess(img, resize_size=224, crop_size=224,
               mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    '''
    Convert an RGB image array into the input tensor for the ONNX model.

    The previous implementation used MXNet gluon ``transforms``, which this
    file never imports. It also produced an MXNet NDArray, which
    onnxruntime cannot consume. This version does the same steps with
    numpy/cv2:

    1. Resize so the shorter side is ``resize_size`` pixels, keeping the
       aspect ratio.
    2. Center-crop a ``crop_size`` square.
    3. Scale pixel values to [0, 1].
    4. Normalise each channel with ``mean``/``std``.
    5. Reorder HWC -> CHW and add a batch axis.

    Args:
        img: ``H x W x 3`` RGB image array, as returned by ``get_image``.
            Presumably ``uint8`` in [0, 255]; the /255 scaling assumes this.
        resize_size: Target length of the shorter image side after resizing.
        crop_size: Side length of the square center crop.
        mean: Per-channel (R, G, B) mean subtracted after scaling.
        std: Per-channel (R, G, B) standard deviation divided out.

    Returns:
        ``numpy.ndarray`` of shape ``(1, 3, crop_size, crop_size)`` and dtype
        ``float32``. With the defaults this is ``(1, 3, 224, 224)``, matching
        the original Resize(224)/CenterCrop(224) intent. That it matches the
        model's declared input is an assumption; confirm via
        ``session.get_inputs()``.

    Raises:
        ValueError: if ``resize_size`` is smaller than ``crop_size``, since the
            crop would not fit inside the resized image.
    '''
    if resize_size < crop_size:
        raise ValueError(
            f"resize_size ({resize_size}) must be >= crop_size ({crop_size})")

    img = np.asarray(img)
    height, width = img.shape[:2]

    # Aspect-preserving resize of the shorter side; max() guards against
    # float rounding leaving a side one pixel short of the crop.
    scale = resize_size / min(height, width)
    new_width = max(crop_size, int(round(width * scale)))
    new_height = max(crop_size, int(round(height * scale)))
    # cv2.resize takes dsize as (width, height).
    img = cv2.resize(img, (new_width, new_height),
                     interpolation=cv2.INTER_LINEAR)

    top = (new_height - crop_size) // 2
    left = (new_width - crop_size) // 2
    img = img[top:top + crop_size, left:left + crop_size]

    img = img.astype(np.float32) / 255.0
    img = (img - np.asarray(mean, dtype=np.float32)) / np.asarray(std, dtype=np.float32)

    # HWC -> CHW, then batchify to NCHW.
    img = np.transpose(img, (2, 0, 1))[np.newaxis, ...]
    return np.ascontiguousarray(img, dtype=np.float32)
    

def predict(path, top_k=5):
    """Classify the image at *path* and return its ``top_k`` most likely labels.

    The previous implementation returned the five *least* likely classes.
    ``np.argsort`` is ascending, and the loop double-indexed ``labels[a[i]]``
    even though ``i`` was already a class index. It also returned raw model
    scores rather than confidences.

    Args:
        path: Filesystem path of the image. The Gradio input component is
            created with ``type='filepath'``.
        top_k: Number of highest-probability classes to return.

    Returns:
        dict mapping label string -> softmax probability (float) for the
        ``top_k`` most likely classes, in descending order of probability.
        This is the format a Gradio ``"label"`` output displays.
    """
    img = preprocess(get_image(path))
    ort_inputs = {session.get_inputs()[0].name: img}
    scores = np.squeeze(session.run(None, ort_inputs)[0]).astype(np.float64)

    # NOTE(review): assumes the model emits unnormalised logits, so softmax
    # turns them into confidences. Softmax is monotonic, so the ranking is
    # the same either way. Subtracting the max keeps exp() from overflowing.
    exp_scores = np.exp(scores - scores.max())
    probs = exp_scores / exp_scores.sum()

    # argsort is ascending; reverse so the most likely classes come first.
    top_indices = np.argsort(probs)[::-1][:top_k]
    return {labels[i]: float(probs[i]) for i in top_indices}
       

# Text shown at the top of the Gradio demo page.
title = "ShuffleNet-v2"
description = (
    "ShuffleNet is a deep convolutional network for image classification. "
    "ShuffleNetV2 is an improved architecture that is the state-of-the-art "
    "in terms of speed and accuracy tradeoff used for image classification."
)

# One-click example input; the file is fetched during start-up above.
examples = [['kitten.jpg']]

# NOTE(review): ``gr.inputs.*`` and ``launch(enable_queue=...)`` are legacy
# Gradio APIs (removed in Gradio 4, where ``gr.Image`` and ``demo.queue()``
# replace them). Confirm against the pinned gradio version before upgrading.
demo = gr.Interface(
    predict,
    gr.inputs.Image(type='filepath'),
    "label",
    title=title,
    description=description,
    examples=examples,
)
demo.launch(enable_queue=True, debug=True)