File size: 1,672 Bytes
8243d5f
 
 
 
 
 
 
 
8e15e86
 
 
 
8243d5f
 
 
 
 
 
1f4e988
8243d5f
1f4e988
8243d5f
 
 
d76175b
 
 
 
 
 
 
 
 
8e15e86
 
d76175b
8e15e86
 
3f8df5b
8243d5f
 
8e15e86
8243d5f
8e15e86
8243d5f
8e15e86
 
8243d5f
 
8e15e86
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import mxnet as mx
import matplotlib.pyplot as plt
import numpy as np
from collections import namedtuple
from mxnet.gluon.data.vision import transforms
import os
import gradio as gr

from PIL import Image
import imageio
import onnxruntime as ort

# Fetch the sample image, the ImageNet class names and the ONNX model.
# mx.test_utils.download skips files that already exist locally and returns
# the local path, so restarts do not re-download or create duplicate copies.
mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')

synset_path = mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
# One class name per line; index i in `labels` is the class for output index i.
with open(synset_path, 'r', encoding='utf-8') as f:
    labels = [line.rstrip() for line in f]

# Previously fetched with os.system("wget ..."): that required a shell and wget,
# silently ignored a failed download, and on restart wget saved a second copy
# as "shufflenet-v2-10.onnx.1" instead of reusing the existing file.
model_path = mx.test_utils.download(
    'https://github.com/AK391/models/raw/main/vision/classification/shufflenet/model/shufflenet-v2-10.onnx'
)

ort_session = ort.InferenceSession(model_path)

    
# Per-channel RGB mean/std (standard ImageNet values) used to normalize inputs.
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
_RESIZE_SHORT_SIDE = 256
_CROP_SIZE = 224
_TOP_K = 5


def _preprocess(image):
    """Convert a PIL image into a normalized (1, 3, 224, 224) float32 NCHW batch."""
    # Expand grayscale / drop alpha so there are always exactly 3 channels.
    image = image.convert("RGB")
    width, height = image.size
    # Scale so the shorter side becomes 256 px, keeping the aspect ratio.
    scale = _RESIZE_SHORT_SIDE / min(width, height)
    new_width, new_height = round(width * scale), round(height * scale)
    image = image.resize((new_width, new_height), Image.BILINEAR)
    # Take the central 224x224 window.
    left = (new_width - _CROP_SIZE) // 2
    top = (new_height - _CROP_SIZE) // 2
    image = image.crop((left, top, left + _CROP_SIZE, top + _CROP_SIZE))
    pixels = np.asarray(image, dtype=np.float32) / 255.0  # (H, W, 3) in [0, 1]
    pixels = (pixels - _MEAN) / _STD
    # HWC -> CHW, then add the batch axis.
    return np.ascontiguousarray(pixels.transpose(2, 0, 1)[np.newaxis], dtype=np.float32)


def _softmax(scores):
    """Return a numerically stable softmax of a 1-D score vector."""
    exps = np.exp(scores - np.max(scores))
    return exps / exps.sum()


def predict(path):
    """Classify the image at *path* and return its top-5 labels.

    Args:
        path: Filesystem path of the image (Gradio passes it with
            ``type='filepath'``).

    Returns:
        dict mapping class name (from ``labels``) to softmax probability for
        the ``_TOP_K`` most likely classes, in descending order of probability.
    """
    # Close the file handle once the pixels have been decoded.
    with Image.open(path) as image:
        batch = _preprocess(image)
    # Query the model's input name instead of hard-coding it.
    input_name = ort_session.get_inputs()[0].name
    scores = ort_session.run(None, {input_name: batch})[0].flatten()
    # The model zoo post-processing applies softmax to the raw class scores.
    probs = _softmax(scores)
    # argsort is ascending: reverse it so the most likely classes come first.
    top = np.argsort(probs)[::-1][:_TOP_K]
    return {labels[i]: float(probs[i]) for i in top}
       

# UI metadata. The model loaded by this script is shufflenet-v2-10.onnx, so the
# title and description describe ShuffleNet V2 (they previously said GoogLeNet).
title = "ShuffleNet-v2"
description = (
    "ShuffleNet is a deep convolutional network for image classification. "
    "ShuffleNet V2 is an improved architecture designed for a favorable "
    "speed/accuracy trade-off on resource-constrained devices."
)

# NOTE(review): catonnx.jpg is not downloaded by this script; presumably it is
# committed next to it — confirm.
examples = [['catonnx.jpg']]
demo = gr.Interface(
    predict,
    gr.inputs.Image(type='filepath'),
    "label",
    title=title,
    description=description,
    examples=examples,
)
demo.launch(enable_queue=True, debug=True)