KabeerAmjad commited on
Commit
f63495a
1 Parent(s): db29817

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -3
app.py CHANGED
@@ -3,13 +3,25 @@ import torch
3
  from torch import nn
4
  from torchvision import models, transforms
5
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # Load the ResNet50 model
8
  model = models.resnet50(pretrained=False) # Don't load pre-trained weights here
9
  model.fc = nn.Linear(model.fc.in_features, 11) # Adjust the output layer to match your number of classes
10
 
11
- # Load the saved model weights (food_classification_model.pth)
12
- model.load_state_dict(torch.load('food_classification_model.pth')) # Load from the local file
13
  model.eval() # Set the model to evaluation mode
14
 
15
  # Define the same preprocessing used during training
@@ -31,7 +43,14 @@ def classify_image(img):
31
 
32
  # Get the label with the highest probability
33
  top_label = probs.argmax().item() # Get the index of the highest probability
34
- return top_label
 
 
 
 
 
 
 
35
 
36
  # Create the Gradio interface
37
  iface = gr.Interface(
@@ -44,3 +63,4 @@ iface = gr.Interface(
44
 
45
  # Launch the app
46
  iface.launch()
 
 
3
  from torch import nn
4
  from torchvision import models, transforms
5
  from PIL import Image
6
+ import os
7
+
8
+ # Define the model path
9
+ model_path = "food_classification_model.pth"
10
+ huggingface_model_url = "https://huggingface.co/KabeerAmjad/food_classification_model/resolve/main/food_classification_model.pth"
11
+
12
+ # Download the model from Hugging Face if it doesn't exist locally
13
+ if not os.path.exists(model_path):
14
+ import requests
15
+ response = requests.get(huggingface_model_url)
16
+ with open(model_path, "wb") as f:
17
+ f.write(response.content)
18
 
19
  # Load the ResNet50 model
20
  model = models.resnet50(pretrained=False) # Don't load pre-trained weights here
21
  model.fc = nn.Linear(model.fc.in_features, 11) # Adjust the output layer to match your number of classes
22
 
23
+ # Load the saved model weights
24
+ model.load_state_dict(torch.load(model_path))
25
  model.eval() # Set the model to evaluation mode
26
 
27
  # Define the same preprocessing used during training
 
43
 
44
  # Get the label with the highest probability
45
  top_label = probs.argmax().item() # Get the index of the highest probability
46
+
47
+ # Map label index to the actual class name
48
+ label_mapping = {
49
+ 0: "apple_pie", 1: "cheesecake", 2: "chicken_curry", 3: "french_fries",
50
+ 4: "fried_rice", 5: "hamburger", 6: "hot_dog", 7: "ice_cream",
51
+ 8: "omelette", 9: "pizza", 10: "sushi"
52
+ }
53
+ return label_mapping[top_label]
54
 
55
  # Create the Gradio interface
56
  iface = gr.Interface(
 
63
 
64
  # Launch the app
65
  iface.launch()
66
+