GoogleNet / app.py
akhaliq's picture
akhaliq HF Staff
Update app.py
3d83121
raw
history blame
2.66 kB
import mxnet as mx
import matplotlib.pyplot as plt
import numpy as np
from collections import namedtuple
from mxnet.gluon.data.vision import transforms
from mxnet.contrib.onnx.onnx2mx.import_model import import_model
import os
import gradio as gr
from PIL import Image
import imageio
def get_image(path):
    '''
    Read the image file at `path` and hand back its pixel data.

    The file is decoded through imageio with pilmode='RGB', i.e. the
    decoder is asked to deliver three-channel RGB pixels.
    '''
    return imageio.imread(path, pilmode='RGB')
# Pre-processing function for ImageNet models using numpy
def preprocess(img):
    '''
    Turn a loaded image array into the tensor fed to the network.

    Steps: resize to 224x224, convert to float32, subtract a fixed offset
    from each of the first three channels, reverse the order of those
    channels, then reorder the axes from (H, W, C) to (1, C, H, W).
    '''
    resized = Image.fromarray(img).resize((224, 224))
    pixels = np.array(resized).astype(np.float32)
    # Per-channel offsets, applied in the channel order of the loaded image
    for channel, offset in enumerate((123.68, 116.779, 103.939)):
        pixels[:, :, channel] -= offset
    # Swap the first and third channels (the fancy-indexed right side is a copy)
    pixels[:, :, [0, 1, 2]] = pixels[:, :, [2, 1, 0]]
    # Channels first, then a leading batch axis of size one
    channels_first = np.transpose(pixels, (2, 0, 1))
    return channels_first[np.newaxis, ...]
# Fetch a sample image and the class-name list into the working directory.
mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
# One label per line; predict() looks labels up by line index.
with open('synset.txt', 'r') as f:
    labels = [line.rstrip() for line in f]
# Download the ONNX model with the same helper used for the other assets
# rather than shelling out to wget: no external binary is needed and the
# result of the download is no longer silently discarded.
mx.test_utils.download(
    'https://github.com/onnx/models/raw/main/vision/classification/'
    'inception_and_googlenet/googlenet/model/googlenet-9.onnx'
)
# Convert the ONNX file into an MXNet symbol and its parameter dictionaries.
sym, arg_params, aux_params = import_model('googlenet-9.onnx')
# Minimal batch container: exposes the `data` field handed to mod.forward().
Batch = namedtuple('Batch', ['data'])
def predict(path):
    '''
    Classify the image stored at `path`.

    Loads and preprocesses the image, runs a forward pass through the
    module-level `mod`, and applies softmax to the first output.

    Returns a dict mapping the five highest-scoring entries of `labels`
    to their softmax score as a Python float, highest score first.
    '''
    img = preprocess(get_image(path))
    # preprocess() yields a numpy array; wrap it as an NDArray because the
    # batch passed to Module.forward is documented to carry NDArrays.
    mod.forward(Batch([mx.nd.array(img)]))
    # Softmax converts the raw network output into probabilities
    scores = np.squeeze(mx.ndarray.softmax(mod.get_outputs()[0]).asnumpy())
    # Indices of the five largest scores, in descending order
    top_five = np.argsort(scores)[::-1][:5]
    return {labels[i]: float(scores[i]) for i in top_five}
# Pick the compute context: CPU when no GPU is listed, otherwise the first GPU
ctx = mx.cpu() if len(mx.test_utils.list_gpus()) == 0 else mx.gpu(0)
# Build the inference module around the imported symbol
mod = mx.mod.Module(
    symbol=sym,
    context=ctx,
    data_names=['data_0'],
    label_names=None,
)
# Bind for inference only, with a single 3x224x224 input named 'data_0'
mod.bind(
    for_training=False,
    data_shapes=[('data_0', (1, 3, 224, 224))],
    label_shapes=mod._label_shapes,
)
# Load the imported parameters, tolerating missing and extra entries
mod.set_params(
    arg_params,
    aux_params,
    allow_missing=True,
    allow_extra=True,
)
# Text shown in the web UI
title = "GoogleNet"
description = "GoogLeNet is the name of a convolutional neural network for classification, which competed in the ImageNet Large Scale Visual Recognition Challenge in 2014."
# Example input offered in the UI
examples = [['catonnx.jpg']]
# Wire predict() to an image-file input and a label output, then start the app
gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type='filepath'),
    outputs="label",
    title=title,
    description=description,
    examples=examples,
).launch(enable_queue=True)