File size: 1,693 Bytes
1c25c76
8243d5f
1c25c76
 
 
8243d5f
 
 
af56d97
559c82e
af56d97
1c25c76
8e15e86
8243d5f
 
 
 
20e3a92
8243d5f
1c25c76
8243d5f
1c25c76
 
20e3a92
1c25c76
 
 
608aad9
8243d5f
559c82e
 
 
 
 
 
 
282c6b9
 
1c25c76
608aad9
559c82e
 
 
1c25c76
 
 
8243d5f
1c25c76
 
8243d5f
8e15e86
8243d5f
1c25c76
 
8243d5f
1c25c76
608aad9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import onnx
import numpy as np
import onnxruntime as ort
from PIL import Image
import cv2
import os
import gradio as gr

import mxnet
from torchvision import transforms

def _download(url, filename=None):
    """Download *url* into the working directory unless it is already there.

    Replaces the original ``os.system("wget ...")`` calls, which fetched the
    files again on every start (wget then keeps the old copy and writes
    ``name.1``, ``name.2``, ...) and ignored the command's exit status, so a
    failed download only surfaced later as a confusing open/load error.

    Args:
        url: Address of the file to fetch.
        filename: Local file name; defaults to the last path segment of *url*.

    Returns:
        The local file name.

    Raises:
        urllib.error.URLError: if the download fails (nothing is written under
            *filename* in that case).
    """
    import urllib.request  # local import keeps the script's import header untouched

    if filename is None:
        filename = url.rsplit("/", 1)[-1]
    if not os.path.exists(filename):
        # Download under a temporary name first so an interrupted transfer is
        # never mistaken for a complete file on the next start.
        partial = filename + ".part"
        urllib.request.urlretrieve(url, partial)
        os.replace(partial, filename)
    return filename


# Class names, one per line; predict() looks up labels[i] for model output
# position i.
_download("https://s3.amazonaws.com/onnx-model-zoo/synset.txt")
with open("synset.txt", "r", encoding="utf-8") as f:
    labels = [line.rstrip() for line in f]

model_path = "shufflenet-v2-10.onnx"
_download(
    "https://github.com/AK391/models/raw/main/vision/classification/shufflenet/model/shufflenet-v2-10.onnx",
    model_path,
)
# Sample image used by the Gradio `examples` list.
_download("https://s3.amazonaws.com/model-server/inputs/kitten.jpg")

model = onnx.load(model_path)
session = ort.InferenceSession(model.SerializeToString())

# torchvision eval-style transform: resize the shorter side to 256 px, take the
# centre 224x224 crop, convert to a float CHW tensor in [0, 1], then normalise
# each channel with the usual ImageNet mean/std.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
        

    

def predict(img):
    """Classify a PIL image and return its five most likely labels.

    Args:
        img: PIL image supplied by the Gradio ``Image(type='pil')`` input.

    Returns:
        dict mapping label name -> softmax probability for the five
        highest-scoring classes, highest first; this is the format Gradio's
        ``"label"`` output renders.
    """
    # ToTensor keeps the image's channel count; RGBA or greyscale uploads would
    # not match the 3-channel Normalize step, so force RGB first.
    input_tensor = preprocess(img.convert("RGB"))
    batch = input_tensor.unsqueeze(0).numpy()  # add the leading batch dimension
    ort_inputs = {session.get_inputs()[0].name: batch}
    scores = np.squeeze(session.run(None, ort_inputs)[0])

    # The model presumably emits unnormalised scores (the ONNX model-zoo
    # reference post-processing applies softmax); convert them to
    # probabilities, shifting by the max for numerical stability.
    exp = np.exp(scores - np.max(scores))
    probs = exp / exp.sum()

    # argsort is ascending, so reverse it to take the highest-scoring classes.
    # Each entry is already an index into probs/labels -- do not re-index it.
    top5 = np.argsort(probs)[::-1][:5]
    return {labels[i]: float(probs[i]) for i in top5}
       

# Page metadata shown by the Gradio UI.
title = "ShuffleNet-v2"
description = (
    "ShuffleNet is a deep convolutional network for image classification. "
    "ShuffleNetV2 is an improved architecture that is the state-of-the-art in "
    "terms of speed and accuracy tradeoff used for image classification."
)

# Clickable sample input; kitten.jpg is downloaded at start-up.
examples = [["kitten.jpg"]]

# Wire predict() to a PIL image input and a "label" output, then serve it.
demo = gr.Interface(
    predict,
    gr.inputs.Image(type="pil"),
    "label",
    title=title,
    description=description,
    examples=examples,
)
demo.launch(enable_queue=True, debug=True)