# ShuffleNet-v2 / app.py
# Author: akhaliq (HF Staff)
# Last commit: "Update app.py" (608aad9)
import os
import urllib.request

import onnx
import numpy as np
import onnxruntime as ort
from PIL import Image
import cv2
import gradio as gr
import mxnet
from torchvision import transforms
def _download(url, filename=None):
    """Fetch *url* into the working directory unless *filename* already exists.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Local file name; defaults to the last path segment of *url*.

    Returns:
        The local file name.

    Raises:
        urllib.error.URLError: if the download fails. (The previous unchecked
            ``os.system("wget ...")`` ignored failures, so they only surfaced
            later as a missing-file error.)
    """
    if filename is None:
        filename = url.rsplit("/", 1)[-1]
    if not os.path.exists(filename):
        # Download to a temporary name and rename on success, so an interrupted
        # transfer never leaves a truncated file that would be skipped next run.
        partial = filename + ".part"
        urllib.request.urlretrieve(url, partial)
        os.replace(partial, filename)
    return filename


# Class names, one per line. predict() maps output index i to labels[i]
# (presumably the 1000 ImageNet synsets, in model output order).
with open(_download("https://s3.amazonaws.com/onnx-model-zoo/synset.txt"), "r", encoding="utf-8") as f:
    labels = [line.rstrip() for line in f]

model_path = 'shufflenet-v2-10.onnx'
_download(
    "https://github.com/AK391/models/raw/main/vision/classification/shufflenet/model/shufflenet-v2-10.onnx",
    model_path,
)
# Sample image offered as a clickable example in the demo UI.
_download("https://s3.amazonaws.com/model-server/inputs/kitten.jpg")

model = onnx.load(model_path)
session = ort.InferenceSession(model.SerializeToString())
# Per-channel normalisation statistics (the standard ImageNet mean/std).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Input pipeline applied to every PIL image before inference:
# resize to 256, centre-crop 224x224, convert to a tensor, then normalise.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
def predict(img):
    """Classify *img* and return the five most likely labels.

    Args:
        img: PIL image, as delivered by the Gradio ``Image(type='pil')`` input.

    Returns:
        dict mapping label text -> probability for the top-5 classes, ordered
        from most to least likely (the format Gradio's "label" output expects).
    """
    top_k = 5

    # Add a leading batch dimension of 1 before handing the array to ONNX Runtime.
    batch = preprocess(img).unsqueeze(0)
    ort_inputs = {session.get_inputs()[0].name: batch.cpu().detach().numpy()}
    scores = np.squeeze(session.run(None, ort_inputs)[0])

    # The model emits unnormalised scores (the ONNX model zoo postprocessing
    # applies softmax); convert them to probabilities. Subtracting the max first
    # avoids overflow in exp and does not change the result.
    exp = np.exp(scores - np.max(scores))
    probs = exp / exp.sum()

    # argsort is ascending: reverse it so the most likely classes come first.
    # The sorted array already holds class indices, so index labels/probs with
    # them directly (the old code indexed the argsort result a second time).
    top = np.argsort(probs)[::-1][:top_k]
    return {labels[i]: float(probs[i]) for i in top}
# Text shown above the demo.
title = "ShuffleNet-v2"
description = "ShuffleNet is a deep convolutional network for image classification. ShuffleNetV2 is an improved architecture that is the state-of-the-art in terms of speed and accuracy tradeoff used for image classification."
# One row per example; each row has one value per input component.
# kitten.jpg is fetched into the working directory at startup.
examples = [["kitten.jpg"]]

demo = gr.Interface(
    predict,
    gr.inputs.Image(type="pil"),  # predict() receives a PIL image
    "label",  # renders the {label: score} dict predict() returns
    title=title,
    description=description,
    examples=examples,
)
demo.launch(enable_queue=True, debug=True)