# Author: nikhiljais — "Initial commit" (7f2d78e)
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import json
import os
from train_resnet50 import ResNet, Bottleneck
# Load ImageNet class labels
def _as_label_list(raw):
    """Normalize parsed JSON labels into a list indexed by class id.

    A JSON list is kept in order. A JSON object is keyed by the decimal class
    id (``{"0": "tench", ...}``); ids missing from it become ``"class_<id>"``.
    A list/tuple value such as ``["n01440764", "tench"]`` contributes its last
    element. Any other JSON shape raises TypeError (and non-numeric object
    keys raise ValueError).
    """

    def text(value):
        # Pick the human-readable name out of [wnid, name]-style entries.
        return value[-1] if isinstance(value, (list, tuple)) and value else value

    if isinstance(raw, list):
        return [text(v) for v in raw]
    if isinstance(raw, dict):
        by_id = {int(k): v for k, v in raw.items()}
        size = max(by_id, default=-1) + 1
        return [text(by_id[i]) if i in by_id else f"class_{i}" for i in range(size)]
    raise TypeError(f"expected a JSON list or object, got {type(raw).__name__}")


def load_class_labels(path="imagenet_classes.json", num_classes=1000):
    """Load class names from *path* as a list indexed by integer class id.

    Always returns a list, so callers can index it directly with the class id.
    A missing, malformed, or unexpectedly-shaped file yields the placeholder
    list ``["class_0", ..., f"class_{num_classes - 1}"]`` after printing a
    warning.
    """
    fallback = [f"class_{i}" for i in range(num_classes)]
    try:
        with open(path, "r", encoding="utf-8") as f:
            labels = _as_label_list(json.load(f))
    except FileNotFoundError:
        print(f"Warning: {path} not found, creating simplified labels")
        return fallback
    except json.JSONDecodeError:
        print(f"Warning: Error parsing {path}, using simplified labels")
        return fallback
    except Exception as e:
        # Best-effort: any other failure (permissions, bad schema) must not
        # stop the app from starting, so fall back to placeholder names.
        print(f"Warning: Unexpected error loading class labels: {e}")
        return fallback
    print(f"Loaded {len(labels)} class labels")
    return labels


class_labels = load_class_labels()
def create_model():
    """Build a ResNet-50 from the training script's ``ResNet``/``Bottleneck`` classes.

    The network is returned as constructed by ``ResNet`` (no weights are
    loaded here); ResNet-50 uses 3, 4, 6 and 3 bottleneck blocks per stage.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths)
def load_model(model_path):
model = create_model()
try:
checkpoint = torch.load(model_path, map_location="cpu")
# Handle DataParallel/DDP state dict
state_dict = checkpoint["model_state_dict"]
new_state_dict = {}
for k, v in state_dict.items():
name = k.replace("module.", "") if k.startswith("module.") else k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
model.eval()
print("Model loaded successfully!")
return model
except Exception as e:
print(f"Error loading model: {e}")
print("Loading pretrained ResNet50 as fallback...")
model = torch.hub.load("pytorch/vision:v0.10.0", "resnet50", pretrained=True)
model.eval()
return model
# Preprocessing pipeline applied to every uploaded image before inference:
# resize (int size -> shorter edge 256), centre-crop to 224x224, convert to a
# tensor, then normalize each channel with the usual ImageNet mean/std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
transform = transforms.Compose(_preprocess_steps)

# Model cache shared across requests; predict() populates it on first use.
global_model = None
def predict(image):
global global_model
# Load model only once
if global_model is None:
try:
global_model = load_model("best_model.pth")
except Exception as e:
print(f"Error loading model: {e}")
return None
# Preprocess image
if image is None:
return None
try:
image = Image.fromarray(image)
image = transform(image).unsqueeze(0)
# Make prediction
with torch.no_grad():
outputs = global_model(image)
probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
# Get top 5 predictions
top5_prob, top5_catid = torch.topk(probabilities, 5)
# Create results dictionary
results = []
for i in range(5):
class_idx = top5_catid[i].item()
# Use list indexing instead of dictionary get()
class_label = (
class_labels[class_idx]
if class_idx < len(class_labels)
else f"class_{class_idx}"
)
results.append(
{
"label": class_label,
"class_id": class_idx,
"confidence": float(top5_prob[i].item()),
}
)
return results
except Exception as e:
print(f"Error during prediction: {e}")
print(f"Class indices: {[idx.item() for idx in top5_catid]}") # Debug info
return None
# Gradio UI: one image input is passed to predict(), and whatever it returns
# is rendered with a JSON output component.
_APP_TITLE = "ResNet50 ImageNet Classifier"
_APP_DESCRIPTION = "Upload an image to get top-5 predictions from our trained ResNet50 model."

iface = gr.Interface(
    predict,
    gr.Image(),
    gr.JSON(),
    title=_APP_TITLE,
    description=_APP_DESCRIPTION,
)

if __name__ == "__main__":
    # share=True requests a public Gradio share link alongside the local server.
    iface.launch(share=True)