import os
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import json
import gradio as gr
import requests

# Path to the model file and Hugging Face URL
model_path = 'food_classification_model.pth'
model_url = "https://huggingface.co/KabeerAmjad/food_classification_model/resolve/main/food_classification_model.pth"


def _download_model(url, path, timeout=60):
    """Download the checkpoint at *url* to *path* unless *path* already exists.

    The response is streamed into a temporary ``.part`` file. That file is
    moved into place only after the whole body has been written. A failed
    or interrupted download therefore never leaves a truncated file or an
    HTML error page at *path*, which later runs would otherwise treat as a
    valid checkpoint.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
    """
    if os.path.exists(path):
        return
    print(f"Downloading the model from {url}...")
    tmp_path = path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check, a 404/500 error page would be saved as the model.
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, path)
    finally:
        # Clean up a partial download; this is a no-op after a successful os.replace.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Model downloaded successfully.")


def _load_model(path, device):
    """Build a ResNet-50, load the custom state_dict at *path*, and move it to *device*.

    Returns:
        The model in evaluation mode, on *device*.

    Raises:
        RuntimeError: if the checkpoint does not match the ResNet-50 architecture.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code, and this checkpoint comes from a remote URL. Consider
    # torch.load(..., weights_only=True) once torch >= 1.13 is guaranteed.
    state_dict = torch.load(path, map_location=torch.device('cpu'))

    # The checkpoint overwrites every parameter, so the ImageNet weights
    # would be thrown away; don't download them.
    net = models.resnet50(weights=None)

    # A fine-tuned checkpoint may have a classifier head with a different
    # number of classes than ImageNet's 1000. Resize fc to match before
    # loading; otherwise load_state_dict fails with a size mismatch.
    fc_weight = state_dict.get("fc.weight") if isinstance(state_dict, dict) else None
    if fc_weight is not None and fc_weight.shape[0] != net.fc.out_features:
        net.fc = torch.nn.Linear(net.fc.in_features, fc_weight.shape[0])

    try:
        net.load_state_dict(state_dict)
    except RuntimeError as e:
        # Fail loudly. Continuing would serve predictions from a model that
        # does not correspond to the food labels.
        raise RuntimeError(
            "Error loading state_dict. "
            "Ensure that the saved model architecture matches ResNet50."
        ) from e

    net.eval()  # Set model to evaluation mode
    return net.to(device)


def _load_labels(path="config.json"):
    """Load the class-index -> label mapping from the JSON file at *path*.

    Two layouts are accepted: a bare ``{"0": "label", ...}`` mapping, or a
    Hugging Face style config whose ``id2label`` key holds that mapping.
    """
    with open(path, encoding="utf-8") as f:
        config = json.load(f)
    if isinstance(config, dict) and "id2label" in config:
        return config["id2label"]
    return config


def _label_for(idx):
    """Return the label for class index *idx*, or a placeholder if none is defined."""
    if isinstance(labels, list):
        return labels[idx] if 0 <= idx < len(labels) else f"class {idx}"
    return labels.get(str(idx), f"class {idx}")


# Download the model file if it's not already available
_download_model(model_url, model_path)

# Pick the inference device once, instead of moving the model on every request
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the model with the custom state_dict
model = _load_model(model_path, device)

# Define the image transformations
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Load labels
labels = _load_labels("config.json")


def predict(image):
    """Classify a food image and return a human-readable prediction.

    Args:
        image: PIL image supplied by the Gradio ``Image(type="pil")`` input.
            It is None when the user submits without uploading anything.

    Returns:
        A string of the form ``"Predicted class: <label>"``, or a prompt to
        upload an image when *image* is None.
    """
    if image is None:
        return "Please upload an image."

    # Convert the uploaded image to 3-channel RGB (handles RGBA, grayscale, ...)
    input_image = image.convert("RGB")

    # Preprocess the image and add a batch dimension
    input_batch = preprocess(input_image).unsqueeze(0).to(device)

    # Perform inference
    with torch.no_grad():
        output = model(input_batch)

    # Get the predicted class with the highest score
    _, predicted_idx = torch.max(output, 1)
    predicted_class = _label_for(int(predicted_idx.item()))
    return f"Predicted class: {predicted_class}"


# Set up the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Food Classification Model",
    description="Upload an image of food to classify it."
)

# Launch the Gradio app
iface.launch()