# nragrawal's picture
# Update app to load from checkpoint
# d957efc
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import traceback
import sys
from network import create_model # Import our model architecture
# Load model from local checkpoint
def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from each key."""
    prefix = 'module.'
    # Only a *leading* prefix is dropped; a plain str.replace would also
    # mangle keys that merely contain the text (e.g. 'submodule.weight').
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model():
    """Build the network and load its weights from ``model/model_best.pth``.

    The checkpoint is expected to be a mapping with a ``'model_state_dict'``
    entry. Keys saved from a DataParallel-wrapped model carry a ``module.``
    prefix, which is removed before the weights are loaded.

    Returns:
        The model created by ``create_model(num_classes=1000)`` with the
        checkpoint weights loaded, switched to eval mode.

    Raises:
        Exception: Any error raised while building the model or reading the
            checkpoint is printed with its traceback and then re-raised.
    """
    try:
        model = create_model(num_classes=1000)
        # NOTE(security): depending on the installed PyTorch version's
        # `weights_only` default, torch.load may unpickle arbitrary objects.
        # Only load checkpoints from a trusted source.
        checkpoint = torch.load('model/model_best.pth', map_location='cpu')
        state_dict = checkpoint['model_state_dict']
        model.load_state_dict(_strip_module_prefix(state_dict))
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        print(traceback.format_exc())
        # Bare raise keeps the original traceback intact.
        raise
# Load ImageNet class labels
def load_labels(path='model/classes.txt', num_classes=1000):
    """Read class labels, one per line, falling back to generic names.

    Args:
        path: Text file holding one label per line. Each line is stripped
            of surrounding whitespace; line order defines the class index.
        num_classes: Number of generic ``Class_<i>`` labels to generate when
            the file cannot be read. Defaults to 1000 so the fallback covers
            every index a 1000-class model can predict.

    Returns:
        A list of label strings.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]
    except (OSError, UnicodeDecodeError) as e:
        # Best-effort: a missing or unreadable label file must not stop the
        # app, but the reason is reported instead of being silently dropped.
        print(f"Could not read labels from {path}: {e}; using generic labels")
        return [f"Class_{i}" for i in range(num_classes)]
# Preprocessing
# Resize the shorter side to 256, take the central 224x224 crop, convert to a
# tensor and normalize each channel with the commonly used ImageNet statistics.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Module-level state: built once at import time and read by predict().
model = load_model()
labels = load_labels()
# Inference function
def predict(image):
    """Classify an image and return its most probable classes.

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` (presumably the
            numpy array produced by ``gr.Image`` -- confirm against the
            interface configuration), or ``None`` when nothing was uploaded.

    Returns:
        A dict mapping class label to probability (float) for the top
        predictions: five entries, or fewer if the model outputs fewer
        than five classes.

    Raises:
        gr.Error: If no image was supplied or inference failed. The failure
            is also printed with its traceback. A ``{"error": message}`` dict
            is not returned because the output component expects
            label -> float confidences, not a string value.
    """
    if image is None:
        raise gr.Error("Please provide an image to classify.")
    try:
        # Force three channels so the 3-element mean/std in the module-level
        # transform always match the input (grayscale / RGBA arrays included).
        img = Image.fromarray(image).convert("RGB")
        batch = transform(img).unsqueeze(0)

        with torch.no_grad():
            output = model(batch)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)

        # Never ask topk for more entries than there are classes.
        k = min(5, probabilities.shape[0])
        top_prob, top_catid = torch.topk(probabilities, k)

        # Fall back to a generic name when the label list is shorter than the
        # model's output, instead of failing with IndexError.
        return {
            (labels[idx] if idx < len(labels) else f"Class_{idx}"): prob
            for prob, idx in zip(top_prob.tolist(), top_catid.tolist())
        }
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        print(traceback.format_exc())
        raise gr.Error(f"Error during prediction: {e}") from e
# Create Gradio interface with error handling
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=5),
    title="ResNet Image Classification",
    description="Upload an image to classify it using ResNet trained on ImageNet subset",
    allow_flagging="never",
)

# Start the app with share=True. A launch failure is printed together with
# its traceback rather than propagating out of the script.
try:
    iface.launch(share=True)
except Exception as exc:
    print(f"Error launching interface: {exc}")
    print(traceback.format_exc())