Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import torchvision.transforms as transforms | |
from PIL import Image | |
import traceback | |
import sys | |
from network import create_model # Import our model architecture | |
# Load model from local checkpoint
def load_model(checkpoint_path='model/model_best.pth', num_classes=1000):
    """Build the network and load trained weights from a local checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint file. It may be either a dict
            holding the weights under ``'model_state_dict'`` or a bare
            state dict.
        num_classes: Number of output classes passed to ``create_model``;
            must match the checkpoint's classifier.

    Returns:
        The model with weights loaded (tensors read with
        ``map_location='cpu'``), switched to eval mode.

    Raises:
        Any exception from building the model or loading the checkpoint is
        printed with its traceback and then re-raised unchanged.
    """
    try:
        model = create_model(num_classes=num_classes)
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        # Training checkpoints usually wrap the weights with extra metadata;
        # otherwise treat the loaded object itself as the state dict.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        # Weights saved from nn.DataParallel carry a leading 'module.' on each
        # key. Strip only that leading prefix: str.replace would also remove
        # 'module.' from the middle of a key and corrupt it.
        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(state_dict)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        print(traceback.format_exc())
        raise
# Load ImageNet class labels
def load_labels(path='model/classes.txt', num_classes=1000):
    """Read class labels, one per line, from a text file.

    Args:
        path: Label file; line ``i`` holds the label for class index ``i``.
        num_classes: How many generic placeholder labels to produce when the
            file cannot be read. Defaults to 1000 to match the 1000-way model
            this app builds, so every predicted index has a label.

    Returns:
        The labels with surrounding whitespace stripped, or
        ``["Class_0", ..., f"Class_{num_classes - 1}"]`` as a fallback.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]
    except (OSError, UnicodeDecodeError) as e:
        # Best-effort: fall back to generic names, but say so.
        print(f"Could not read labels from {path!r} ({e}); using generic labels")
        return [f"Class_{i}" for i in range(num_classes)]
# Preprocessing pipeline applied to every input image: resize the short side
# to 256, center-crop to 224x224, convert to a tensor, then normalize each of
# the three channels with the ImageNet mean/std below.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Module-level state shared by every request: the network is loaded first,
# then the label list used to name its outputs.
model, labels = load_model(), load_labels()
# Inference function
def predict(image):
    """Classify one image and return its most likely labels.

    Args:
        image: The value delivered by the ``gr.Image()`` input -- presumably
            a NumPy array (its default type; confirm if the input is
            reconfigured), or ``None`` when nothing was uploaded.

    Returns:
        A ``{label: probability}`` dict for the (up to) five highest-scoring
        classes, in the form ``gr.Label`` renders.

    Raises:
        gr.Error: if no image was supplied or inference fails. Gradio shows
            the message in the UI; returning an ``{"error": str}`` dict
            instead cannot be rendered by ``gr.Label``, which expects float
            confidences.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    try:
        # Uploads may be grayscale or RGBA; the Normalize step in `transform`
        # expects exactly three channels.
        img = Image.fromarray(image).convert('RGB')
        batch = transform(img).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            output = model(batch)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        # topk fails if k exceeds the number of classes.
        k = min(5, probabilities.numel())
        top_prob, top_catid = torch.topk(probabilities, k)
        # Guard against a label file shorter than the model's output size.
        return {
            (labels[idx] if idx < len(labels) else f"Class_{idx}"): float(prob)
            for prob, idx in zip(top_prob.tolist(), top_catid.tolist())
        }
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        print(traceback.format_exc())
        raise gr.Error(f"Error during prediction: {e}") from e
# Gradio UI: an uploaded image goes in, the top-5 label confidences produced
# by `predict` come out.
_ui_options = {
    "title": "ResNet Image Classification",
    "description": "Upload an image to classify it using ResNet trained on ImageNet subset",
    "allow_flagging": "never",
}
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=5),
    **_ui_options,
)


def _launch(interface):
    """Start the Gradio server with ``share=True``.

    Any launch failure is printed together with its traceback instead of
    propagating.
    """
    try:
        interface.launch(share=True)
    except Exception as exc:
        print(f"Error launching interface: {str(exc)}")
        print(traceback.format_exc())


_launch(iface)