|
import os |
|
import torch |
|
import torchvision.transforms as transforms |
|
import torchvision.models as models |
|
from PIL import Image |
|
import json |
|
import gradio as gr |
|
import requests |
|
|
|
|
|
model_path = 'food_classification_model.pth'
model_url = "https://huggingface.co/KabeerAmjad/food_classification_model/resolve/main/food_classification_model.pth"


def download_file(url, dest, timeout=60, chunk_size=1 << 20):
    """Download ``url`` to ``dest`` without ever leaving a truncated file at ``dest``.

    The body is streamed to ``dest + ".part"`` in ``chunk_size`` pieces and
    only renamed onto ``dest`` once the whole response has been written. An
    HTTP error, a timeout or an interrupted run therefore can't leave behind a
    file that a later ``os.path.exists(dest)`` check would mistake for a good
    checkpoint.

    Args:
        url: Remote file to fetch.
        dest: Final local path.
        timeout: Seconds passed to ``requests`` for connect/read timeouts.
        chunk_size: Bytes per streamed chunk.

    Raises:
        requests.HTTPError: On a non-2xx response (e.g. 404 or 401), instead
            of saving the error page as the model file.
        requests.RequestException / OSError: On network or disk failures.
    """
    tmp_path = f"{dest}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Atomic rename on the same filesystem: dest is either absent or complete.
        os.replace(tmp_path, dest)
    except BaseException:
        # Also covers KeyboardInterrupt, so no stale partial file is left behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


if not os.path.exists(model_path):
    print(f"Downloading the model from {model_url}...")
    download_file(model_url, model_path)
    print("Model downloaded successfully.")
|
|
|
|
|
# Build the bare ResNet-50 architecture. ImageNet weights are not requested:
# the strict load_state_dict below overwrites every parameter and buffer from
# the checkpoint, so fetching them only added a large download at startup.
model = models.resnet50(weights=None)

try:
    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source; on torch
    # versions that support it, consider passing weights_only=True.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    # A fine-tuned food classifier may have a different number of output
    # classes than the stock 1000-way ImageNet head. Size the final layer from
    # the checkpoint's own fc.weight so the strict load below can succeed.
    fc_weight = state_dict.get("fc.weight") if isinstance(state_dict, dict) else None
    if fc_weight is not None and fc_weight.shape[0] != model.fc.out_features:
        model.fc = torch.nn.Linear(model.fc.in_features, fc_weight.shape[0])

    model.load_state_dict(state_dict)
    print("Model loaded successfully.")
except RuntimeError as e:
    print("Error loading state_dict:", e)
    print("Ensure that the saved model architecture matches ResNet50.")
    # Continuing with unloaded weights would serve meaningless predictions
    # labelled with food classes, so fail loudly instead.
    raise

# Switch to inference mode (e.g. BatchNorm uses its running statistics) once
# every layer change above is in place.
model.eval()
|
|
|
|
|
# Per-channel (R, G, B) normalization statistics applied after ToTensor.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Input pipeline applied to every uploaded image before inference:
# resize, center-crop to 224x224, convert to a tensor, then normalize.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
preprocess = transforms.Compose(_preprocess_steps)
|
|
|
|
|
def load_labels(path="config.json"):
    """Load the class-index -> label mapping used by ``predict``.

    Three JSON layouts are accepted:

    * a flat object: ``{"0": "apple_pie", "1": "baklava", ...}``
    * a Hugging Face style config whose ``"id2label"`` object holds that mapping
    * a list of labels ordered by class index

    Keys are normalized to strings because ``predict`` looks labels up with
    ``str(index)``.

    Returns:
        A ``dict[str, object]`` mapping. If the file is missing, unreadable,
        not valid JSON, or of an unsupported shape, the error is printed and
        ``{}`` is returned. ``predict`` then reports an unknown class instead
        of failing on an undefined ``labels`` name.
    """
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError / UnicodeDecodeError
        print("Error loading labels:", e)
        return {}

    if isinstance(data, dict) and isinstance(data.get("id2label"), dict):
        data = data["id2label"]

    if isinstance(data, list):
        mapping = {str(i): name for i, name in enumerate(data)}
    elif isinstance(data, dict):
        mapping = {str(k): v for k, v in data.items()}
    else:
        print("Error loading labels: unsupported JSON type", type(data).__name__)
        return {}

    print("Labels loaded successfully.")
    return mapping


labels = load_labels()
|
|
|
|
|
def predict(image):
    """Classify a food image and return a human-readable result string.

    Args:
        image: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        ``"Predicted class: <label>"`` on success, where ``<label>`` comes from
        the module-level ``labels`` mapping (or an "Unknown class index"
        note when the index has no entry). A short instruction is returned
        when no image was given. Any exception raised during inference is
        reported as ``"An error occurred during prediction: <error>"``.
    """
    if image is None:
        # Gradio passes None for an empty input; that is a user action, not a failure.
        return "Please upload an image to classify."

    try:
        print("Starting prediction...")

        # Drop alpha / palette modes so the 3-channel Normalize step applies.
        input_image = image.convert("RGB")
        print(f"Image converted to RGB: {input_image.size}")

        # The model expects a batch, so add a leading batch dimension of 1.
        input_batch = preprocess(input_image).unsqueeze(0)
        print(f"Input tensor shape after unsqueeze: {input_batch.shape}")

        if torch.cuda.is_available():
            device = torch.device("cuda")
            print("Using GPU for inference.")
        else:
            device = torch.device("cpu")
            print("GPU not available, using CPU.")
        # Model and input must live on the same device; moving an already
        # placed model is cheap.
        model.to(device)
        input_batch = input_batch.to(device)

        with torch.no_grad():
            output = model(input_batch)
        print(f"Inference output shape: {output.shape}")

        predicted_idx = int(output.argmax(dim=1).item())
        print(f"Predicted class index: {predicted_idx}")

        key = str(predicted_idx)  # label files use string keys ("0", "1", ...)
        if key in labels:
            predicted_class = labels[key]
        else:
            predicted_class = f"Unknown class index: {predicted_idx}. Please check the label mapping."
        print(predicted_class)

        return f"Predicted class: {predicted_class}"

    except Exception as e:
        # Top-level UI boundary: report the failure to the user rather than crash the app.
        print(f"Error during prediction: {e}")
        return f"An error occurred during prediction: {e}"
|
|
|
|
|
# Gradio UI: a single PIL image input wired to predict(), text output.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Food Classification Model",
    description="Upload an image of food to classify it.",
)

# Start the server only when this file is executed as a script, so importing
# the module (e.g. from tests or another app) does not block on launch().
if __name__ == "__main__":
    iface.launch()
|
|