File size: 2,097 Bytes
da26c6d
 
 
9e38cc3
da26c6d
008b175
da26c6d
 
 
 
 
 
 
 
 
7cbc13f
da26c6d
 
7cbc13f
da26c6d
87f7375
da26c6d
 
 
 
8458aec
da26c6d
 
 
7cbc13f
 
 
 
 
 
9d8d906
da26c6d
9d8d906
7cbc13f
9d8d906
 
 
da26c6d
8458aec
 
da26c6d
9d8d906
da26c6d
 
 
 
 
 
 
5db8bb0
82d6ea6
da26c6d
 
 
 
 
 
 
 
 
 
87f7375
da26c6d
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
import os
import torch
from torchvision import datasets, transforms

from model import create_ViT
from timeit import default_timer as timer
from typing import Tuple, Dict

# Setup class names: one name per line of class_names.txt. predict() pairs
# class_names[i] with model output i, so line order must match the model's
# output order.
with open("class_names.txt", "r", encoding="utf-8") as f:
    # rstrip() the whole file first so a trailing blank line does not turn into
    # an extra "" class (which would make the class list longer than the model
    # output). Interior lines are kept as-is to preserve index alignment.
    class_names = [food_name.strip() for food_name in f.read().rstrip().splitlines()]


# Build the ViT architecture, then restore its trained parameters from ViT.pth.
# map_location remaps every stored tensor onto the CPU, so the checkpoint loads
# on machines without a GPU regardless of the device it was saved from.
# NOTE(review): torch.load unpickles the file; ViT.pth ships with the app, but on
# torch >= 1.13 consider weights_only=True for a plain state_dict.
model = create_ViT()
model.load_state_dict(torch.load("ViT.pth", map_location=torch.device("cpu")))


# Image preprocessing pipeline, built once at import time instead of on every
# request: resize to 224x224, convert to a tensor, and normalise with the
# commonly used ImageNet mean/std (presumably what the model was trained
# with — confirm against the training code).
PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify a food image and measure how long the prediction took.

    Args:
        img: A PIL image, as delivered by ``gr.Image(type="pil")``.

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)``. ``pred_labels_and_probs``
        maps each class name to its softmax probability. ``pred_time`` is the
        wall-clock duration of the call in seconds, rounded to 5 decimal places.
    """
    start_time = timer()

    # Normalize expects exactly 3 channels; force RGB so RGBA (e.g. PNGs with
    # alpha), greyscale or palette images don't fail during preprocessing.
    image = PREPROCESS(img.convert("RGB")).unsqueeze(0)  # add batch dimension

    # eval() switches off training-only behaviour (e.g. dropout). It is cheap and
    # idempotent, so calling it here keeps predict() correct on its own.
    model.eval()
    with torch.no_grad():
        # NOTE(review): assumes create_ViT() returns a model whose output exposes
        # `.logits` (Hugging Face style) — confirm against model.py.
        logits = model(image).logits
        probs = torch.softmax(logits, dim=1)[0].tolist()

    # Pair each class name with the probability at the same index. zip stops at
    # the shorter sequence instead of raising IndexError on a length mismatch.
    pred_labels_and_probs = dict(zip(class_names, probs))

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)

    return pred_labels_and_probs, pred_time


##GRADIO APP
# Create title, description and article strings
title = "FoodVision🍔🍟🍦"
# The class count comes from the loaded class list, so the description cannot
# drift out of sync with class_names.txt.
description = (
    "A Vision Transformer feature extractor computer vision model to classify "
    f"images of food into {len(class_names)} different classes."
)
article = "Created by [Rohit](https://github.com/ItsNotRohit02)."

# Create examples list from the "examples/" directory. os.listdir() returns
# entries in arbitrary order and includes hidden files (e.g. .DS_Store,
# .gitkeep), so keep only visible regular files and sort them for a stable UI.
example_list = [
    [os.path.join("examples", example)]
    for example in sorted(os.listdir("examples"))
    if not example.startswith(".")
    and os.path.isfile(os.path.join("examples", example))
]

# Create Gradio interface: one image in; top-5 label probabilities and the
# prediction time (both returned by predict) out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=5, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the app!
demo.launch()