# app.py by nragrawal
# Last commit 33566fb: "Update username in app.py" (1.97 kB)
import functools
import sys
import traceback

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoModelForImageClassification
# Load model from Hub instead of local file
@functools.lru_cache(maxsize=1)
def load_model():
    """Load the ResNet image classifier from the Hub and return it in eval mode.

    The result is memoized, so the download and model construction happen once
    per process instead of on every ``predict`` call. ``lru_cache`` does not
    cache exceptions, so a failed load is retried on the next call.

    Returns:
        The model built by ``AutoModelForImageClassification.from_pretrained``
        for ``"nragrawal/resnet-imagenet"``, switched to inference mode with
        ``model.eval()``.

    Raises:
        Exception: whatever ``from_pretrained`` raises, re-raised unchanged
            after the error message and traceback are printed.
    """
    try:
        # SECURITY: trust_remote_code=True executes Python code shipped inside
        # the Hub repository. Only use it with repositories you control or trust.
        model = AutoModelForImageClassification.from_pretrained(
            "nragrawal/resnet-imagenet",
            trust_remote_code=True,
        )
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        print(traceback.format_exc())
        raise  # bare raise keeps the original exception and traceback intact
    model.eval()
    return model
# Preprocessing
# Evaluation pipeline applied to every uploaded image:
#   Resize(256)     -> scale the shorter edge to 256 px
#   CenterCrop(224) -> keep the central 224x224 patch
#   ToTensor()      -> PIL image to a float CHW tensor in [0, 1]
#   Normalize(...)  -> per-channel standardization with the standard ImageNet mean/std
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Inference function
def _top_k_confidences(output, k=5):
    """Convert one forward-pass result into ``{"Class <index>": probability}`` for the top *k* classes.

    ``output`` may be a plain logits tensor or an object exposing ``.logits``
    (such as a transformers ``ModelOutput``). A tuple/list is unwrapped to its
    first element. A leading batch axis of size 1 is removed if present.
    """
    logits = getattr(output, "logits", output)
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    # Softmax over the class axis (last), then drop the size-1 batch axis.
    probabilities = torch.softmax(logits, dim=-1).squeeze(0)
    # Never ask topk for more classes than the model has.
    k = min(k, probabilities.numel())
    top_prob, top_catid = torch.topk(probabilities, k)
    # .tolist() yields plain Python ints/floats, so labels read "Class 281"
    # rather than "Class tensor(281)".
    return {
        f"Class {idx}": prob
        for idx, prob in zip(top_catid.tolist(), top_prob.tolist())
    }


def predict(image):
    """Classify one uploaded image and return its top-5 class confidences.

    Args:
        image: the value produced by the ``gr.Image`` input. It is passed to
            ``Image.fromarray``, so a numpy array is expected (``None`` when
            nothing was uploaded).

    Returns:
        ``{"Class <index>": probability}`` for the five most likely classes,
        the mapping format ``gr.Label`` displays.

    Raises:
        gr.Error: if no image was supplied or inference fails. Gradio shows the
            message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    try:
        model = load_model()
        # Force three channels: Normalize carries exactly three means/stds, so
        # grayscale or RGBA arrays would otherwise fail.
        img = Image.fromarray(image).convert("RGB")
        batch = transform(img).unsqueeze(0)  # add a batch axis of size 1
        with torch.no_grad():
            output = model(batch)
        return _top_k_confidences(output, k=5)
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        print(traceback.format_exc())
        # gr.Label expects numeric confidences, so returning {"error": message}
        # does not display the error properly. gr.Error surfaces it in the UI.
        raise gr.Error(f"Prediction failed: {e}") from e
# Create Gradio interface with error handling
iface = gr.Interface(
    fn=predict,
    # type="numpy" is spelled out because predict() hands the value to
    # Image.fromarray. It is also gr.Image's default.
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=5),
    title="ResNet Image Classification",
    description="Upload an image to classify it using ResNet",
    # NOTE(review): newer Gradio releases rename this option to flagging_mode.
    # Update it if the pinned Gradio version requires it.
    allow_flagging="never",
)
# Add error handling to launch
try:
    iface.launch()
except Exception as e:
    print(f"Error launching interface: {str(e)}")
    print(traceback.format_exc())
    # Exit non-zero so whatever runs this script sees the launch failure
    # instead of a clean exit.
    sys.exit(1)