"""Gradio demo that classifies an uploaded image with a ResNet from the Hub."""

import functools
import sys
import traceback

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoModelForImageClassification

# Hugging Face Hub repository holding the model weights and its custom code.
MODEL_ID = "nragrawal/resnet-imagenet1k"

# Number of top predictions returned by `predict` and shown by the Label output.
TOP_K = 5


# Load model from Hub instead of local file
@functools.lru_cache(maxsize=1)
def load_model():
    """Load the classifier from the Hub and switch it to evaluation mode.

    The result is cached, so the download/initialisation happens once per
    process rather than on every request. ``lru_cache`` does not cache a
    raised exception, so a failed load is retried on the next call.

    Returns:
        The model returned by ``AutoModelForImageClassification.from_pretrained``.

    Raises:
        Exception: whatever ``from_pretrained`` raises; it is logged, then
            re-raised unchanged.
    """
    try:
        # trust_remote_code executes Python code shipped in the Hub repo;
        # acceptable only because MODEL_ID is a fixed, known repository.
        model = AutoModelForImageClassification.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
        )
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        print(traceback.format_exc())
        raise  # bare raise keeps the original traceback
    model.eval()
    return model


# Preprocessing: resize, center-crop to 224x224, convert to a CHW float tensor,
# then normalize with the standard ImageNet channel mean/std (presumably what
# the model was trained with -- confirm against the model card).
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def _to_logits(output):
    """Return the logits tensor from a model forward-pass result.

    Accepts a ``transformers`` ``ModelOutput`` (uses ``.logits``), a bare
    tensor, or a tuple/list whose first item is the logits. The model runs
    remote code, so the exact return type is not fixed here.
    """
    logits = getattr(output, "logits", output)
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    return logits


# Inference function
def predict(image):
    """Classify one image and return its top-``TOP_K`` class probabilities.

    Args:
        image: the value passed by Gradio's ``gr.Image()`` component (a NumPy
            array with the default settings), or ``None`` when the user
            submits without an image.

    Returns:
        dict mapping ``"Class <index>"`` to its softmax probability (float),
        the format ``gr.Label`` expects.

    Raises:
        gr.Error: if no image was provided, or if model loading, preprocessing
            or inference fails. Gradio displays the message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    try:
        model = load_model()  # cached after the first successful call

        # convert("RGB") so grayscale / RGBA inputs get the 3 channels that
        # the Normalize step above expects.
        img = Image.fromarray(image).convert("RGB")
        batch = transform(img).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)

        with torch.no_grad():
            logits = _to_logits(model(batch))

        # The batch holds exactly one image, so flatten to a 1-D vector of
        # class scores; the softmax must run over classes, not the batch axis.
        probabilities = torch.softmax(logits.reshape(-1), dim=0)

        k = min(TOP_K, probabilities.numel())
        top_prob, top_idx = torch.topk(probabilities, k)
        return {
            f"Class {idx}": prob
            for idx, prob in zip(top_idx.tolist(), top_prob.tolist())
        }
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        print(traceback.format_exc())
        # gr.Label cannot render {"error": "..."}; raising gr.Error shows the
        # failure to the user instead.
        raise gr.Error(f"Prediction failed: {e}") from e


# Gradio interface definition (errors are handled inside `predict`).
# NOTE(review): newer Gradio releases rename `allow_flagging` to
# `flagging_mode`; confirm against the Gradio version pinned for this app.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=TOP_K),
    title="ResNet Image Classification",
    description="Upload an image to classify it using ResNet",
    allow_flagging="never",
)


if __name__ == "__main__":
    try:
        iface.launch()
    except Exception as e:
        print(f"Error launching interface: {str(e)}")
        print(traceback.format_exc())
        # Exit non-zero so whatever supervises the process sees the failure
        # instead of a clean exit.
        sys.exit(1)