ItsNotRohit commited on
Commit
008b175
·
1 Parent(s): 5db8bb0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import os
3
  import torch
4
 
5
- from model import create_effnetv2
6
  from timeit import default_timer as timer
7
  from typing import Tuple, Dict
8
 
@@ -12,14 +12,14 @@ with open("class_names.txt", "r") as f:
12
 
13
 
14
  # Create model
15
- effnetv2, effnetv2_transforms = create_effnetv2(
16
- num_classes=101,
17
  )
18
 
19
  # Load saved weights
20
- effnetv2.load_state_dict(
21
  torch.load(
22
- f="effnet_v2.pth",
23
  map_location=torch.device("cpu"),
24
  )
25
  )
@@ -31,13 +31,13 @@ def predict(img) -> Tuple[Dict, float]:
31
  start_time = timer()
32
 
33
  # Transform the target image and add a batch dimension
34
- img = effnetv2_transforms(img).unsqueeze(0)
35
 
36
  # Put model into evaluation mode and turn on inference mode
37
- effnetv2.eval()
38
  with torch.inference_mode():
39
  # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
40
- pred_probs = torch.softmax(effnetv2(img), dim=1)
41
 
42
  # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
43
  pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
@@ -52,7 +52,7 @@ def predict(img) -> Tuple[Dict, float]:
52
  ##GRADIO APP
53
  # Create title, description and article strings
54
  title = "FoodVision🍔🍟🍦"
55
- description = "An EfficientNetV2 feature extractor computer vision model to classify images of food into 101 different classes."
56
  article = "Created by [Rohit](https://github.com/ItsNotRohit02)."
57
 
58
  # Create examples list from "examples/" directory
 
2
  import os
3
  import torch
4
 
5
+ from model import create_ViT
6
  from timeit import default_timer as timer
7
  from typing import Tuple, Dict
8
 
 
12
 
13
 
14
  # Create model
15
+ ViT_model, ViT_transforms = create_ViT(
16
+ num_classes=126,
17
  )
18
 
19
  # Load saved weights
20
+ ViT_model.load_state_dict(
21
  torch.load(
22
+ f="ViT.pth",
23
  map_location=torch.device("cpu"),
24
  )
25
  )
 
31
  start_time = timer()
32
 
33
  # Transform the target image and add a batch dimension
34
+ img = ViT_transforms(img).unsqueeze(0)
35
 
36
  # Put model into evaluation mode and turn on inference mode
37
+ ViT_model.eval()
38
  with torch.inference_mode():
39
  # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
40
+ pred_probs = torch.softmax(ViT_model(img), dim=1)
41
 
42
  # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
43
  pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
 
52
  ##GRADIO APP
53
  # Create title, description and article strings
54
  title = "FoodVision🍔🍟🍦"
55
+ description = "A Vision Transformer feature extractor computer vision model to classify images of food into 126 different classes."
56
  article = "Created by [Rohit](https://github.com/ItsNotRohit02)."
57
 
58
  # Create examples list from "examples/" directory