Spaces:
Runtime error
Runtime error
Commit
·
008b175
1
Parent(s):
5db8bb0
Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
|
|
2 |
import os
|
3 |
import torch
|
4 |
|
5 |
-
from model import
|
6 |
from timeit import default_timer as timer
|
7 |
from typing import Tuple, Dict
|
8 |
|
@@ -12,14 +12,14 @@ with open("class_names.txt", "r") as f:
|
|
12 |
|
13 |
|
14 |
# Create model
|
15 |
-
|
16 |
-
num_classes=
|
17 |
)
|
18 |
|
19 |
# Load saved weights
|
20 |
-
|
21 |
torch.load(
|
22 |
-
f="
|
23 |
map_location=torch.device("cpu"),
|
24 |
)
|
25 |
)
|
@@ -31,13 +31,13 @@ def predict(img) -> Tuple[Dict, float]:
|
|
31 |
start_time = timer()
|
32 |
|
33 |
# Transform the target image and add a batch dimension
|
34 |
-
img =
|
35 |
|
36 |
# Put model into evaluation mode and turn on inference mode
|
37 |
-
|
38 |
with torch.inference_mode():
|
39 |
# Pass the transformed image through the model and turn the prediction logits into prediction probabilities
|
40 |
-
pred_probs = torch.softmax(
|
41 |
|
42 |
# Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
|
43 |
pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
|
@@ -52,7 +52,7 @@ def predict(img) -> Tuple[Dict, float]:
|
|
52 |
##GRADIO APP
|
53 |
# Create title, description and article strings
|
54 |
title = "FoodVision🍔🍟🍦"
|
55 |
-
description = "
|
56 |
article = "Created by [Rohit](https://github.com/ItsNotRohit02)."
|
57 |
|
58 |
# Create examples list from "examples/" directory
|
|
|
2 |
import os
|
3 |
import torch
|
4 |
|
5 |
+
from model import create_ViT
|
6 |
from timeit import default_timer as timer
|
7 |
from typing import Tuple, Dict
|
8 |
|
|
|
12 |
|
13 |
|
14 |
# Create model
|
15 |
+
ViT_model, ViT_transforms = create_ViT(
|
16 |
+
num_classes=126,
|
17 |
)
|
18 |
|
19 |
# Load saved weights
|
20 |
+
ViT_model.load_state_dict(
|
21 |
torch.load(
|
22 |
+
f="ViT.pth",
|
23 |
map_location=torch.device("cpu"),
|
24 |
)
|
25 |
)
|
|
|
31 |
start_time = timer()
|
32 |
|
33 |
# Transform the target image and add a batch dimension
|
34 |
+
img = ViT_transforms(img).unsqueeze(0)
|
35 |
|
36 |
# Put model into evaluation mode and turn on inference mode
|
37 |
+
ViT_model.eval()
|
38 |
with torch.inference_mode():
|
39 |
# Pass the transformed image through the model and turn the prediction logits into prediction probabilities
|
40 |
+
pred_probs = torch.softmax(ViT_model(img), dim=1)
|
41 |
|
42 |
# Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
|
43 |
pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
|
|
|
52 |
##GRADIO APP
|
53 |
# Create title, description and article strings
|
54 |
title = "FoodVision🍔🍟🍦"
|
55 |
+
description = "A Vision Transformer feature extractor computer vision model to classify images of food into 126 different classes."
|
56 |
article = "Created by [Rohit](https://github.com/ItsNotRohit02)."
|
57 |
|
58 |
# Create examples list from "examples/" directory
|