# Inception_v2 / app.py
# Author: akhaliq (HF Staff)
# Last commit: a739b51 "Update app.py" (1.75 kB)
import os
import urllib.request

import cv2
import gradio as gr
import numpy as np
import onnx
import onnxruntime as ort
from PIL import Image
# Fetch the class-label list only when it is not already on disk, so a
# restart does not re-download it (and urlretrieve raises on failure instead
# of failing silently). The file is written under a temporary name and
# renamed only after a complete transfer, so an interrupted download never
# leaves a truncated synset.txt that later runs would trust.
if not os.path.exists('synset.txt'):
    urllib.request.urlretrieve(
        "https://s3.amazonaws.com/onnx-model-zoo/synset.txt",
        'synset.txt.part',
    )
    os.replace('synset.txt.part', 'synset.txt')

# One label per line; predict() looks labels up by the model's output index.
with open('synset.txt', 'r', encoding='utf-8') as f:
    labels = [line.rstrip() for line in f]
model_path = 'inception-v2-9.onnx'

# Fetch the ONNX weights and the example image once; files already on disk
# are reused across restarts. Each file is written under a temporary name
# and renamed only after a complete download, so a partial file is never
# mistaken for a finished one.
for _url, _dest in (
    (
        "https://github.com/AK391/models/raw/main/vision/classification/"
        "inception_and_googlenet/inception_v2/model/inception-v2-9.onnx",
        model_path,
    ),
    ("https://s3.amazonaws.com/model-server/inputs/kitten.jpg", 'kitten.jpg'),
):
    if not os.path.exists(_dest):
        urllib.request.urlretrieve(_url, _dest + '.part')
        os.replace(_dest + '.part', _dest)
del _url, _dest  # loop temporaries; keep the module namespace clean

model = onnx.load(model_path)
# Select the execution provider explicitly: onnxruntime >= 1.9 raises when
# `providers` is omitted on builds that expose more than the CPU provider.
session = ort.InferenceSession(
    model.SerializeToString(), providers=['CPUExecutionProvider']
)
def get_image(path):
    """Read the image file at *path* and return its pixels as an RGB array.

    Whatever the file's native mode, it is converted to three-channel RGB
    before being copied into a NumPy array. The file is closed before
    returning.
    """
    with Image.open(path) as handle:
        rgb_pixels = np.array(handle.convert('RGB'))
    return rgb_pixels
def preprocess(img):
    """Turn an RGB image array into a batched, channels-first float32 input.

    Pipeline: scale pixel values by 1/255, resize to 256x256, keep the
    central 224x224 window, normalize each channel with the standard
    ImageNet mean/std, move the channel axis first, cast to float32 and
    prepend a batch axis (result shape (1, C, 224, 224)).

    Assumes *img* is an (H, W, 3) array such as get_image() returns.
    """
    crop = 224
    scaled = cv2.resize(img / 255., (256, 256))
    top = (scaled.shape[0] - crop) // 2
    left = (scaled.shape[1] - crop) // 2
    patch = scaled[top:top + crop, left:left + crop, :]

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalized = (patch - mean) / std

    chw = normalized.transpose(2, 0, 1).astype(np.float32)
    return chw[np.newaxis, ...]
def predict(path):
    """Classify the image at *path* and report the single best class.

    Args:
        path: Filesystem path to an image (the Gradio input component is
            created with ``type='filepath'``).

    Returns:
        A one-entry ``{label: score}`` dict for Gradio's ``"label"`` output,
        where ``label`` is the synset entry of the highest-scoring class.
    """
    batch = preprocess(get_image(path))
    feed = {session.get_inputs()[0].name: batch}
    # First model output; squeeze drops the batch (and other size-1) axes.
    scores = np.squeeze(session.run(None, feed)[0])
    # Only the top class is reported, so argmax suffices; sorting all
    # scores just to read one index is unnecessary work.
    best = int(np.argmax(scores))
    # NOTE(review): the displayed score is the raw output scaled by 0.1.
    # The reason is not visible here; if the model already emits
    # probabilities this under-reports confidence tenfold - confirm before
    # changing it.
    return {labels[best]: float(scores[best] * 0.1)}
title = "Inception v2"
description = "Inception v2 is a deep convolutional network for classification."
# Example inputs shown under the UI; kitten.jpg is fetched at startup.
examples = [['kitten.jpg']]

# Wire predict() to an image-path input and a label output, then start the
# server. The Interface is bound to a name so it can be reused or inspected.
demo = gr.Interface(
    predict,
    gr.inputs.Image(type='filepath'),
    "label",
    title=title,
    description=description,
    examples=examples,
)
demo.launch(enable_queue=True, debug=True)