Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
import torch | |
from model import create_effnet_b2 | |
from timeit import default_timer as timer | |
from typing import Tuple, Dict | |
# Output labels, listed in the order of the model's output logits.
class_names = ['pizza', 'steak', 'sushi']

# Build the EfficientNetB2 model and its matching preprocessing transforms;
# the classifier head is sized from the label list above (3 classes).
effnetb2, effnetb2_transforms = create_effnet_b2(num_classes=len(class_names))

# Restore the trained weights. map_location forces every tensor onto the CPU,
# so the checkpoint loads even on hosts without a GPU.
effnetb2.load_state_dict(
    torch.load(
        'pretrained_effnetb2_feature_extractor.pth',
        map_location=torch.device('cpu'),
    )
)
# Predict function
def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify one image with EfficientNetB2 and time the prediction.

    Args:
        img: The image supplied by the ``gr.Image(type='pil')`` input. It is
            passed unchanged to ``effnetb2_transforms``.

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)``:
        ``pred_labels_and_probs`` maps each name in ``class_names`` to its
        softmax probability; ``pred_time`` is the elapsed time in seconds
        measured with ``timeit.default_timer``.
    """
    start_time = timer()

    # Preprocess the image and add a leading batch dimension of size 1.
    transformed_image = effnetb2_transforms(img).unsqueeze(0)

    # Put the model into eval mode and run a single forward pass without
    # autograd tracking.
    effnetb2.eval()
    with torch.inference_mode():
        pred_logits = effnetb2(transformed_image)
        pred_probs = torch.softmax(pred_logits, dim=1)

    # Pair each class label with the probability for the only batch item.
    # Uses the module-level ``class_names`` (the previous code referred to the
    # undefined name ``effnet_class_names`` and raised NameError).
    pred_labels_and_probs = {
        name: float(prob) for name, prob in zip(class_names, pred_probs[0])
    }

    pred_time = timer() - start_time
    return pred_labels_and_probs, pred_time
# Gradio app
# (gradio is already imported as ``gr`` at the top of the file.)

# Title and description shown in the web UI.
title = 'FoodVision Mini'
description = 'An EfficientNetB2 feature extractor to classify food as pizza, steak, and sushi'

# Example inputs: every regular file in ./examples. The listing is sorted
# because os.listdir returns entries in arbitrary order. If the directory is
# missing, the app starts without examples instead of crashing with
# FileNotFoundError.
examples_dir = 'examples'
if os.path.isdir(examples_dir):
    example_list = [
        [os.path.join(examples_dir, example)]
        for example in sorted(os.listdir(examples_dir))
        if os.path.isfile(os.path.join(examples_dir, example))
    ]
else:
    example_list = []

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Label(num_top_classes=3, label='Predictions'),
        gr.Number(label='Prediction time (s)'),
    ],
    # Pass None rather than an empty list when there are no examples.
    examples=example_list or None,
    title=title,
    description=description,
)

# NOTE(review): on Hugging Face Spaces Gradio is believed to ignore
# share=True (logging a warning); it only creates a public link on local runs.
demo.launch(debug=False, share=True)