# Spaces status at capture time: Sleeping (page-scrape residue, not part of the program)
import functools
import os
import traceback

import gradio as gr
import torch
from torchvision.models import get_model
from torchvision.transforms import transforms
from torchvision.transforms.functional import InterpolationMode
# Build the classifier network and restore its trained weights from disk.
def load_model(weights_path):
    """Return a ResNet-50 in evaluation mode, initialised from a checkpoint.

    Args:
        weights_path: Location of the checkpoint file handed to ``torch.load``.
            The loaded object must expose the network weights under the key
            ``"model_state_dict"``.

    Returns:
        The ``resnet50`` model (1000 output classes) with the checkpoint
        weights loaded, switched to eval mode.
    """
    cpu = torch.device("cpu")
    # Tensors are mapped onto the CPU regardless of where they were saved.
    checkpoint = torch.load(weights_path, map_location=cpu)
    state_dict = checkpoint["model_state_dict"]

    net = get_model("resnet50", num_classes=1000)
    net.load_state_dict(state_dict)
    # Module.eval() returns the module itself, so it can be returned directly.
    return net.eval()
# Helpers for making predictions and returning the top 5 classes with confidence
@functools.lru_cache(maxsize=1)
def _load_categories(path="imagenet_classes.txt"):
    """Read the class-name file once and return its stripped lines as a tuple.

    The result is cached so the file is not re-opened on every prediction.
    Line ``i`` of the file is used as the label for model output index ``i``.
    """
    with open(path, "r", encoding="utf-8") as f:
        return tuple(line.strip() for line in f)


def classify_image(image):
    """Classify one image and return its five most probable classes.

    Args:
        image: The image supplied by the Gradio ``Image`` input, or ``None``
            when the callback is invoked without an image.

    Returns:
        A dict mapping class name to probability (Python float) for the five
        highest-scoring classes, or an empty dict when ``image`` is ``None``.
    """
    # The interface runs with live=True, so the callback can be triggered
    # without an image (e.g. after the input is cleared). Return an empty
    # result instead of letting the preprocessing pipeline fail on None.
    if image is None:
        return {}

    # Preprocess the input image and add the batch dimension.
    batch = transform(image).unsqueeze(0)

    with torch.no_grad():
        output = model(batch)

    # The model emits unnormalised scores; softmax turns them into probabilities.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    categories = _load_categories()
    top5_prob, top5_catid = torch.topk(probabilities, 5)
    return {
        categories[class_id]: prob
        for prob, class_id in zip(top5_prob.tolist(), top5_catid.tolist())
    }
# Preprocessing pipeline applied to every input image before inference:
# resize, centre-crop to 224, convert to a float tensor, then normalise
# each channel with the mean/std below.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Resize(256, interpolation=InterpolationMode.BILINEAR, antialias=True),
    transforms.CenterCrop(224),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)

# Checkpoint with the trained weights (adjust this path for your own setup);
# the model is loaded once at import time and shared by all predictions.
model_weights_path = "best.pth"
model = load_model(model_weights_path)
# Texts and sample inputs shown in the Gradio UI.
_TITLE = "Image Recognition using ResNet-50 trained on Imagenet-1K"
_DESCRIPTION = "<p style='text-align: center'> Gradio demo for ResNet, Deep residual networks pre-trained on ImageNet. To use it, simply upload your image, or click one of the examples to load them. </p>"
# Footer links: the ResNet paper and the project repository.
_ARTICLE = (
    "<p style='text-align: center'> "
    "<a href='https://arxiv.org/abs/1512.03385' target='_blank'>Deep Residual Learning for Image Recognition</a> | "
    "<a href='https://github.com/KD1994/session-9-imagenet-resnet50' target='_blank'>Github Repo</a> "
    "</p>"
)
_EXAMPLES = [
    ['examples/dog.jpg'],
    ['examples/great-white-shark.jpg'],
    ['examples/american-goldfinch.jpg'],
    ['examples/hognose-snake.jpg'],
]

# Wire classify_image into a Gradio interface: a PIL image goes in, a label
# widget showing the top 5 classes with their confidence comes out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title=_TITLE,
    live=True,
    description=_DESCRIPTION,
    article=_ARTICLE,
    examples=_EXAMPLES,
)
# Start the app. A launch failure is reported on stdout (message followed by
# the full traceback) rather than propagating out of the script.
try:
    iface.launch(share=True)
except Exception as launch_error:
    details = traceback.format_exc()
    print(f"Error launching interface: {launch_error}")
    print(details)