# DenseNet-121 / app.py
# Hugging Face Space by akhaliq (HF Staff).
# Last commit fd8271e: "Update app.py" (1.64 kB).
import mxnet as mx
import matplotlib.pyplot as plt
import numpy as np
from collections import namedtuple
from mxnet.gluon.data.vision import transforms
import os
import gradio as gr
from PIL import Image
import imageio
import onnxruntime as ort
from torchvision import transforms
# Input pipeline for the classifier: resize the short side to 256, center-crop
# 224x224, convert to a float tensor, then normalize with the standard
# ImageNet per-channel mean/std values.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Fetch a sample image and the class-name list. mx.test_utils.download does
# not re-fetch a file that already exists (overwrite defaults to False).
mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
with open('synset.txt', 'r', encoding='utf-8') as f:
    # One label per line; predict() indexes this list by output class index.
    labels = [line.rstrip() for line in f]

# Download the ONNX weights with the same helper instead of shelling out to
# wget:
# - no shell is involved;
# - a failed download raises instead of being silently ignored;
# - restarts reuse the existing file (wget would instead write
#   densenet-9.onnx.1, densenet-9.onnx.2, ... beside it).
MODEL_PATH = mx.test_utils.download(
    'https://github.com/AK391/models/raw/main/vision/classification/densenet-121/model/densenet-9.onnx'
)
ort_session = ort.InferenceSession(MODEL_PATH)
def predict(pil):
    """Classify a PIL image and return its five highest-scoring labels.

    Args:
        pil: Input image as a ``PIL.Image.Image`` (the Gradio input is
            declared with ``type='pil'``).

    Returns:
        dict mapping label string -> model output value (Python float) for the
        top five classes, in descending score order.
    """
    # Normalize() uses 3-channel mean/std, so force RGB: RGBA PNGs, grayscale
    # or palette images would otherwise fail with a channel-count mismatch.
    input_tensor = preprocess(pil.convert("RGB"))
    img_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model
    img_batch_np = img_batch.cpu().detach().numpy()
    outputs = ort_session.run(
        None,
        {"data_0": img_batch_np.astype(np.float32)},
    )
    # Flatten once and use the same 1-D array for both ranking and lookup.
    # This keeps every index valid (and every value a scalar) regardless of
    # any trailing singleton dimensions in the model output.
    scores = outputs[0].flatten()
    top = np.argsort(-scores)[:5]
    # NOTE(review): these are the model's raw output values; whether they are
    # already softmax probabilities is not visible here -- confirm before
    # presenting them as confidences.
    return {labels[i]: float(scores[i]) for i in top}
# Text and sample inputs shown in the Gradio UI.
title = "DenseNet-121"
description = "DenseNet-121 is a convolutional neural network for classification."
# NOTE(review): apple.jpg is not downloaded by this script (only kitten.jpg
# and synset.txt are); presumably it ships with the Space repo -- verify.
examples = [["apple.jpg"]]

# Wire predict() to a single PIL-image input and a "label" output, then serve
# it with request queuing and debug output enabled.
demo = gr.Interface(
    predict,
    gr.inputs.Image(type="pil"),
    "label",
    title=title,
    description=description,
    examples=examples,
)
demo.launch(enable_queue=True, debug=True)