"""FoodVision Mini: a Gradio app that classifies a food photo as pizza, steak or sushi.

Loads a fine-tuned EfficientNet-B2 checkpoint onto the CPU at import time,
builds a Gradio interface around ``predict`` and launches it.
"""
from typing import List, Tuple, Dict
import os
from timeit import default_timer as timer

import PIL
from PIL import Image
import gradio as gr
import torch

from model import create_effnetb2_model

# Class labels; their order matches the model's output probabilities
# (they are paired positionally with zip() in predict()).
class_names = ['pizza', 'steak', 'sushi']

# Example images shown under the interface. Sorted so the gallery order is
# stable across platforms (os.listdir order is arbitrary), and restricted to
# visible regular files so stray entries (e.g. .DS_Store, subfolders) are not
# handed to Gradio as images.
examples = [
    os.path.join('examples', fname)
    for fname in sorted(os.listdir('examples'))
    if not fname.startswith('.')
    and os.path.isfile(os.path.join('examples', fname))
]

# Build the architecture with one output per class, then load the trained
# weights onto the CPU so the app runs without a GPU.
model, preprocess = create_effnetb2_model(num_classes=len(class_names), seed=42)
# NOTE: torch.load unpickles the checkpoint; only ship checkpoints you trust.
model.load_state_dict(
    torch.load('effnetb2_20_percent.pth', map_location=torch.device('cpu'))
)
# Put the model in evaluation mode once at startup rather than on every
# request; nothing below ever switches it back to training mode.
model.eval()


def predict(img: Image.Image) -> Tuple[Dict[str, float], float]:
    """Classify a single image and time the prediction.

    Args:
        img: The uploaded image as a PIL image (the interface uses
            ``gr.Image(type='pil')``).

    Returns:
        A tuple ``(preds, pred_time)``:
        - ``preds`` maps each class name to its softmax probability.
        - ``pred_time`` is the wall-clock seconds spent on preprocessing
          plus inference, rounded to 8 decimal places.
    """
    start_time = timer()

    # Apply the model's own transforms, then add a batch dimension of 1.
    batch = preprocess(img).unsqueeze(dim=0)

    with torch.inference_mode():
        # Drop only the batch axis so the result stays a list even for a
        # single-class model.
        probs = model(batch).softmax(dim=-1).squeeze(dim=0).tolist()

    preds = dict(zip(class_names, probs))
    pred_time = round(timer() - start_time, 8)
    return preds, pred_time


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Label(num_top_classes=len(class_names), label='Prediction probabilities'),
        gr.Number(label='Prediction time (s)'),
    ],
    examples=examples,
    title='FoodVision Mini 🍕🥩🍣',
)

demo.launch()