File size: 3,589 Bytes
122654c
5cf4eea
75a5b88
 
d97b868
75a5b88
954ac21
122654c
 
 
 
 
 
 
 
 
 
 
 
 
75a5b88
 
 
 
 
 
 
 
 
d9d7936
75a5b88
 
 
 
 
 
 
 
d97b868
75a5b88
 
 
 
d97b868
9f320da
75a5b88
d9d7936
 
 
 
 
 
75a5b88
 
954ac21
d9d7936
3e8dce3
 
d9d7936
 
3e8dce3
d9d7936
 
 
 
3e8dce3
d9d7936
 
 
 
 
3e8dce3
d9d7936
 
 
 
 
 
3e8dce3
d9d7936
 
 
e6daf1b
 
 
 
 
 
 
 
 
9f320da
d9d7936
 
 
 
3e8dce3
f63495a
954ac21
 
 
6e9fd21
954ac21
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import json
import gradio as gr
import requests

# Path to the model file and Hugging Face URL
model_path = 'food_classification_model.pth'
model_url = "https://huggingface.co/KabeerAmjad/food_classification_model/resolve/main/food_classification_model.pth"


def _download_file(url, dest, timeout=60, chunk_size=1024 * 1024):
    """Download *url* to *dest*, replacing *dest* only after a complete transfer.

    The response body is streamed to ``dest + ".part"`` and then moved into
    place with ``os.replace``. An interrupted or failed download therefore
    never leaves a truncated or bogus file at *dest*. This matters because the
    module-level code below only downloads when *dest* is missing, so a bad
    file there would otherwise be reused on every subsequent start.

    Args:
        url: URL to fetch.
        dest: Destination file path.
        timeout: Seconds passed to ``requests`` as the connect/read timeout
            (not a limit on the total transfer time).
        chunk_size: Bytes per streamed chunk.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
        OSError: If the file cannot be written or moved into place.
    """
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an error page (e.g. a 404 HTML body) would be
            # saved as the model file.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        os.replace(tmp_path, dest)
    except BaseException:
        # Do not leave a partial download behind, then propagate the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Download the model file if it's not already available
if not os.path.exists(model_path):
    print(f"Downloading the model from {model_url}...")
    _download_file(model_url, model_path)
    print("Model downloaded successfully.")

# Load the model: build a ResNet50 whose head matches the saved checkpoint.
def _load_model(path):
    """Build a ResNet50 sized to the checkpoint at *path* and load its weights.

    The number of output classes is read from the checkpoint's ``fc.weight``
    tensor (ResNet50's final ``nn.Linear`` stores it as
    ``[num_classes, in_features]``). A fine-tuned head with a non-ImageNet
    class count therefore loads correctly instead of clashing with the default
    1000-class head. If ``fc.weight`` is absent, 1000 classes are assumed.

    Args:
        path: Path to a file produced by ``torch.save(model.state_dict(), ...)``.
            The original code also assumed a state dict here.

    Returns:
        The loaded model, on the CPU, in evaluation mode.

    Raises:
        RuntimeError: If the checkpoint cannot be read or its tensors do not
            match the ResNet50 architecture.
    """
    try:
        # NOTE(review): torch.load unpickles the file, and this file is fetched
        # from a remote URL. Where the installed torch supports it, consider
        # passing weights_only=True (the default from torch 2.6 onward).
        state_dict = torch.load(path, map_location=torch.device('cpu'))
        fc_weight = state_dict.get("fc.weight") if isinstance(state_dict, dict) else None
        num_classes = fc_weight.shape[0] if fc_weight is not None else 1000
        # weights=None: every parameter is overwritten by the checkpoint, so
        # downloading the pretrained ImageNet weights first would be wasted work.
        model = models.resnet50(weights=None, num_classes=num_classes)
        model.load_state_dict(state_dict)
    except RuntimeError as e:
        print("Error loading state_dict:", e)
        print("Ensure that the saved model architecture matches ResNet50.")
        # Re-raise: continuing would serve a model whose outputs do not
        # correspond to the food labels.
        raise
    print("Model loaded successfully.")
    model.eval()  # Set model to evaluation mode
    return model


model = _load_model(model_path)

# Evaluation-time image pipeline:
# 1. resize the shorter side to 256 px;
# 2. take the central 224x224 crop;
# 3. convert to a float tensor;
# 4. normalise each RGB channel with the standard ImageNet mean/std.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

def load_labels(path="config.json"):
    """Load the class-index -> label mapping from a JSON file.

    Accepted layouts:
      * a flat object ``{"0": "apple_pie", ...}``, used as-is;
      * a Hugging Face-style config containing an ``"id2label"`` object, whose
        mapping is used;
      * a JSON array ``["apple_pie", ...]``, indexed by position.

    Keys are normalised to strings because the prediction code looks labels up
    by ``str(index)``.

    Args:
        path: Path to the JSON file.

    Returns:
        A dict mapping stringified class indices to labels. If the file is
        missing, unreadable, not valid JSON, or of an unexpected shape, the
        error is printed and an empty dict is returned. Prediction can then
        still run and report unknown indices instead of failing on an unbound
        name.
    """
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        print("Error loading labels:", e)
        return {}

    if isinstance(data, dict) and isinstance(data.get("id2label"), dict):
        data = data["id2label"]
    if isinstance(data, list):
        data = dict(enumerate(data))
    if not isinstance(data, dict):
        print("Error loading labels:", f"unexpected JSON type {type(data).__name__}")
        return {}

    print("Labels loaded successfully.")
    return {str(key): value for key, value in data.items()}


# Load labels
labels = load_labels("config.json")

# Function to predict image class
def predict(image):
    """Classify a food image and return a human-readable result string.

    Uses the module-level ``model``, ``preprocess`` and ``labels``.

    Args:
        image: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        ``"Predicted class: <label>"`` on success. If the index has no label,
        the label part explains that the index is unknown. If no image was
        given, or inference raised, an explanatory message is returned instead.
        Errors are reported in the text rather than raised, so the UI always
        shows something.
    """
    if image is None:
        # Gradio delivers None when Submit is pressed with an empty image input.
        return "Please upload an image of food to classify."

    try:
        print("Starting prediction...")

        # Normalise to 3-channel RGB (drops alpha, expands grayscale/palette).
        input_image = image.convert("RGB")
        print(f"Image converted to RGB: {input_image.size}")

        input_batch = preprocess(input_image).unsqueeze(0)  # Add batch dimension
        print(f"Input tensor shape after unsqueeze: {input_batch.shape}")

        # Check if a GPU is available and move the input and model to GPU
        if torch.cuda.is_available():
            device = "cuda"
            input_batch = input_batch.to(device)
            model.to(device)
            print("Using GPU for inference.")
        else:
            print("GPU not available, using CPU.")

        with torch.no_grad():
            output = model(input_batch)
        print(f"Inference output shape: {output.shape}")

        # Index of the highest-scoring class for the single image in the batch.
        predicted_idx = output.argmax(dim=1).item()
        print(f"Predicted class index: {predicted_idx}")

        key = str(predicted_idx)  # label keys are stringified indices
        if key in labels:
            predicted_class = labels[key]
        else:
            predicted_class = f"Unknown class index: {predicted_idx}. Please check the label mapping."
            print(predicted_class)

        return f"Predicted class: {predicted_class}"

    except Exception as e:  # top-level UI boundary: report instead of crashing the request
        print(f"Error during prediction: {e}")
        return f"An error occurred during prediction: {e}"

# Set up the Gradio interface: one PIL image in, one text result out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Food Classification Model",
    description="Upload an image of food to classify it.",
)

# Launch the Gradio app only when this file is run as a script
# (e.g. ``python app.py``). Importing the module, for example from tests or a
# notebook, then does not start a blocking web server.
if __name__ == "__main__":
    iface.launch()