File size: 2,498 Bytes
122654c 5cf4eea 75a5b88 d97b868 75a5b88 954ac21 122654c 75a5b88 d97b868 75a5b88 d97b868 9f320da 75a5b88 954ac21 75a5b88 954ac21 75a5b88 d97b868 75a5b88 d97b868 75a5b88 9f320da 75a5b88 5cf4eea 75a5b88 9f320da 954ac21 f63495a 954ac21 6e9fd21 954ac21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import os
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import json
import gradio as gr
import requests
# Local filename of the fine-tuned checkpoint and the Hugging Face URL it is fetched from
model_path = 'food_classification_model.pth'
model_url = "https://huggingface.co/KabeerAmjad/food_classification_model/resolve/main/food_classification_model.pth"

# Download the model file if it's not already available
if not os.path.exists(model_path):
    print(f"Downloading the model from {model_url}...")
    # Write to a temporary name first so an interrupted download is never
    # mistaken for a complete checkpoint by the existence check on the next run.
    partial_path = model_path + '.part'
    # stream=True avoids holding the whole checkpoint in memory; the timeout
    # keeps a stalled connection from hanging start-up indefinitely.
    with requests.get(model_url, stream=True, timeout=60) as response:
        # Fail loudly on HTTP errors instead of saving an error page as the model.
        response.raise_for_status()
        with open(partial_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
    # Atomically move the finished download into place.
    os.replace(partial_path, model_path)
    print("Model downloaded successfully.")
# Build a ResNet50 (initialised with torchvision's default pretrained weights)
# and then overwrite its parameters with the custom checkpoint at model_path.
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Load the model's custom state_dict
try:
    # NOTE(review): torch.load unpickles the file, and the checkpoint is
    # downloaded from the network; only use a source you trust (consider
    # weights_only=True if the installed torch version supports it).
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    # The stock ResNet50 head has a fixed number of outputs. If the checkpoint
    # was saved with a differently sized final layer, resize `fc` to match so
    # load_state_dict does not fail with a size mismatch.
    fc_weight = state_dict.get('fc.weight') if isinstance(state_dict, dict) else None
    if fc_weight is not None and fc_weight.shape[0] != model.fc.out_features:
        model.fc = torch.nn.Linear(model.fc.in_features, fc_weight.shape[0])

    model.load_state_dict(state_dict)
except RuntimeError as e:
    # Kept as best-effort: report the problem and continue with whatever
    # weights the model currently holds.
    print("Error loading state_dict:", e)
    print("Ensure that the saved model architecture matches ResNet50.")

# Set model to evaluation mode (done last so a replaced head is covered too)
model.eval()
# Preprocessing pipeline applied to every input image: Resize(256),
# CenterCrop(224), conversion to a tensor, then per-channel normalization.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
])
# Load the label mapping from config.json. JSON text is UTF-8, so the encoding
# is given explicitly rather than relying on the platform's locale default
# (which would garble non-ASCII label names on some systems).
with open("config.json", encoding="utf-8") as f:
    labels = json.load(f)
# Function to predict image class
def predict(image):
    """Classify an image and return the predicted label as display text.

    Args:
        image: A PIL image, or None (e.g. when the form is submitted
            without an image).

    Returns:
        "Predicted class: <label>" for the highest-scoring class, or a short
        prompt asking for an image when none was provided.

    Raises:
        KeyError: If the predicted class index has no entry in ``labels``.
    """
    # Guard against a missing upload instead of failing on image.convert.
    if image is None:
        return "No image provided. Please upload an image of food."

    # Force three channels, preprocess, and add the batch dimension.
    input_batch = preprocess(image.convert("RGB")).unsqueeze(0)

    # Run on the GPU when one is available; model and input must share a device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    input_batch = input_batch.to(device)

    # Perform inference without tracking gradients.
    with torch.no_grad():
        output = model(input_batch)

    # Index of the highest score along the class dimension.
    predicted_idx = output.argmax(dim=1).item()
    # labels is keyed by the class index as a string.
    predicted_class = labels[str(predicted_idx)]
    return f"Predicted class: {predicted_class}"
# Gradio UI: one image input (delivered to predict as a PIL image) and a
# plain-text output showing the prediction.
image_input = gr.Image(type="pil")
iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs="text",
    description="Upload an image of food to classify it.",
    title="Food Classification Model",
)

# Start serving the app
iface.launch()
|