Spaces:
Sleeping
Sleeping
File size: 4,043 Bytes
7f2d78e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import json
import os
from train_resnet50 import ResNet, Bottleneck
# Load ImageNet class labels
try:
with open("imagenet_classes.json", "r") as f:
class_labels = json.load(f)
print(f"Loaded {len(class_labels)} class labels")
except FileNotFoundError:
print("Warning: imagenet_classes.json not found, creating simplified labels")
# Fallback to a simplified version
class_labels = {str(i): f"class_{i}" for i in range(1000)}
except json.JSONDecodeError:
print("Warning: Error parsing imagenet_classes.json, using simplified labels")
class_labels = {str(i): f"class_{i}" for i in range(1000)}
except Exception as e:
print(f"Warning: Unexpected error loading class labels: {e}")
class_labels = {str(i): f"class_{i}" for i in range(1000)}
def create_model():
model = ResNet(Bottleneck, [3, 4, 6, 3])
return model
def load_model(model_path):
model = create_model()
try:
checkpoint = torch.load(model_path, map_location="cpu")
# Handle DataParallel/DDP state dict
state_dict = checkpoint["model_state_dict"]
new_state_dict = {}
for k, v in state_dict.items():
name = k.replace("module.", "") if k.startswith("module.") else k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
model.eval()
print("Model loaded successfully!")
return model
except Exception as e:
print(f"Error loading model: {e}")
print("Loading pretrained ResNet50 as fallback...")
model = torch.hub.load("pytorch/vision:v0.10.0", "resnet50", pretrained=True)
model.eval()
return model
# Preprocessing transform
transform = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
# Global variable for model
global_model = None
def predict(image):
global global_model
# Load model only once
if global_model is None:
try:
global_model = load_model("best_model.pth")
except Exception as e:
print(f"Error loading model: {e}")
return None
# Preprocess image
if image is None:
return None
try:
image = Image.fromarray(image)
image = transform(image).unsqueeze(0)
# Make prediction
with torch.no_grad():
outputs = global_model(image)
probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
# Get top 5 predictions
top5_prob, top5_catid = torch.topk(probabilities, 5)
# Create results dictionary
results = []
for i in range(5):
class_idx = top5_catid[i].item()
# Use list indexing instead of dictionary get()
class_label = (
class_labels[class_idx]
if class_idx < len(class_labels)
else f"class_{class_idx}"
)
results.append(
{
"label": class_label,
"class_id": class_idx,
"confidence": float(top5_prob[i].item()),
}
)
return results
except Exception as e:
print(f"Error during prediction: {e}")
print(f"Class indices: {[idx.item() for idx in top5_catid]}") # Debug info
return None
# Create Gradio interface
iface = gr.Interface(
fn=predict,
inputs=gr.Image(),
outputs=gr.JSON(),
title="ResNet50 ImageNet Classifier",
description="Upload an image to get top-5 predictions from our trained ResNet50 model.",
)
# Launch the app
if __name__ == "__main__":
iface.launch(share=True)
|