Spaces:
Runtime error
Runtime error
File size: 2,097 Bytes
da26c6d 9e38cc3 da26c6d 008b175 da26c6d 7cbc13f da26c6d 7cbc13f da26c6d 87f7375 da26c6d 8458aec da26c6d 7cbc13f 9d8d906 da26c6d 9d8d906 7cbc13f 9d8d906 da26c6d 8458aec da26c6d 9d8d906 da26c6d 5db8bb0 82d6ea6 da26c6d 87f7375 da26c6d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import gradio as gr
import os
import torch
from torchvision import datasets, transforms
from model import create_ViT
from timeit import default_timer as timer
from typing import Tuple, Dict
# Setup class names: one class name per line of class_names.txt. The list order
# matters -- predict() labels model output index i with class_names[i].
# The encoding is explicit so non-ASCII names decode the same on every host,
# independent of the locale's default encoding.
with open("class_names.txt", "r", encoding="utf-8") as f:
    class_names = [food_name.strip() for food_name in f]
# Instantiate the ViT architecture provided by model.create_ViT.
model = create_ViT()

# Restore the saved weights from ViT.pth. map_location remaps every stored
# tensor onto the CPU, so the checkpoint loads even on hosts without a GPU.
# NOTE: torch.load unpickles the file -- only ever point it at a trusted checkpoint.
_checkpoint = torch.load(
    f="ViT.pth",
    map_location=torch.device("cpu"),
)
model.load_state_dict(_checkpoint)
del _checkpoint  # weights are now copied into the model; drop the extra reference
def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify a food image and report how long the prediction took.

    Args:
        img: The image passed in by the ``gr.Image(type="pil")`` input, i.e. a
            PIL image (any mode; it is converted to RGB before preprocessing).

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)``: a dict mapping every
        name in ``class_names`` to its softmax probability, and the wall-clock
        duration of this call in seconds, rounded to 5 decimal places.

    Raises:
        ValueError: If no image was supplied (``img`` is ``None``).
    """
    if img is None:
        raise ValueError("No image provided; please upload an image.")

    start_time = timer()

    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Uploaded images may be RGBA (PNG), palette or grayscale; ToTensor would then
    # give 4 or 1 channels and the 3-channel Normalize above would fail. Force RGB.
    image = preprocess(img.convert("RGB")).unsqueeze(0)  # Add batch dimension

    # Make predictions
    model.eval()
    with torch.no_grad():
        outputs = model(image)
    # Models that return an output object expose the scores as `.logits`; a plain
    # nn.Module classifier returns the logits tensor itself. Accept both.
    logits = getattr(outputs, "logits", outputs)
    predicted_probs = torch.softmax(logits, dim=1)

    # Map each class name to its probability (index i of the output <-> class_names[i]).
    probs = predicted_probs[0].tolist()
    pred_labels_and_probs = {name: float(probs[i]) for i, name in enumerate(class_names)}

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time
##GRADIO APP
# Create title, description and article strings
title = "FoodVision🍔🍟🍦"
# Derive the class count from class_names so the text cannot drift from the
# label file the app actually uses.
description = (
    "A Vision Transformer feature extractor computer vision model to classify "
    f"images of food into {len(class_names)} different classes."
)
article = "Created by [Rohit](https://github.com/ItsNotRohit02)."

# Create examples list from the "examples/" directory. os.listdir returns entries
# in arbitrary order, so sort them to keep the example gallery stable across restarts.
example_list = [
    [os.path.join("examples", example)]
    for example in sorted(os.listdir("examples"))
]

# Create Gradio interface: one PIL image in; top-5 labels and the timing out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=5, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the app!
demo.launch()
|