import sys
import traceback

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

from network import create_model  # Import our model architecture

# Number of output classes the network is built with. The generic-label
# fallback must cover the same range, otherwise label lookups can fail.
NUM_CLASSES = 1000

# Number of predictions shown in the UI.
TOP_K = 5


# Load model from local checkpoint
def load_model():
    """Build the network and load its weights from the local checkpoint.

    Returns:
        The model in eval mode, with weights mapped to the CPU.

    Raises:
        Exception: Any failure while building the model or reading the
            checkpoint is logged and re-raised unchanged.
    """
    try:
        model = create_model(num_classes=NUM_CLASSES)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load('model/model_best.pth', map_location='cpu')

        # Handle DataParallel state dict: drop the 'module.' marker from keys.
        # NOTE(review): str.replace removes 'module.' anywhere in a key, not
        # only as a leading prefix - confirm no parameter name contains it.
        state_dict = checkpoint['model_state_dict']
        new_state_dict = {
            key.replace('module.', ''): value
            for key, value in state_dict.items()
        }

        # Load the modified state dict
        model.load_state_dict(new_state_dict)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        print(traceback.format_exc())
        raise  # bare raise keeps the original traceback intact


# Load ImageNet class labels
def load_labels():
    """Read one class label per line from 'model/classes.txt'.

    Returns:
        A list of label strings. If the file cannot be read or decoded, a
        list of generic 'Class_<i>' names covering NUM_CLASSES entries.
    """
    try:
        with open('model/classes.txt', 'r') as f:
            # Blank lines are kept so that line position == class index.
            return [line.strip() for line in f]
    except (OSError, UnicodeDecodeError) as e:
        # Best-effort: the app still works with generic names.
        print(f"Error loading labels: {str(e)}")
        return [f"Class_{i}" for i in range(NUM_CLASSES)]


# Preprocessing
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Global variables
model = load_model()
labels = load_labels()


def _label_for(index):
    """Return the label for class *index*, or a generic name if none exists."""
    if 0 <= index < len(labels):
        return labels[index]
    return f"Class_{index}"


# Inference function
def predict(image):
    """Classify an image and return the top predictions.

    Args:
        image: The array handed over by the ``gr.Image`` input, or ``None``
            when nothing was uploaded.

    Returns:
        A dict mapping class label to probability for the top predictions.

    Raises:
        gr.Error: If no image was supplied or inference fails. The message is
            shown in the UI instead of being passed to the Label output.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    try:
        # Preprocess image. Normalize uses 3-channel statistics, so force RGB
        # in case the array is grayscale or carries an alpha channel.
        img = Image.fromarray(image).convert('RGB')
        batch = transform(img).unsqueeze(0)

        # Inference
        with torch.no_grad():
            output = model(batch)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)

        # Get top predictions; never ask for more entries than there are classes.
        k = min(TOP_K, probabilities.numel())
        top_prob, top_catid = torch.topk(probabilities, k)

        return {
            _label_for(idx): prob
            for prob, idx in zip(top_prob.tolist(), top_catid.tolist())
        }
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        print(traceback.format_exc())
        # gr.Label expects float confidences; a {"error": "<text>"} dict would
        # fail while rendering and hide the real problem.
        raise gr.Error(str(e)) from e


# Create Gradio interface with error handling
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=TOP_K),
    title="ResNet Image Classification",
    description="Upload an image to classify it using ResNet trained on ImageNet subset",
    allow_flagging="never",
)

# Add error handling to launch
try:
    iface.launch(share=True)
except Exception as e:
    print(f"Error launching interface: {str(e)}")
    print(traceback.format_exc())