"""Gradio demo: ImageNet-1k image classification with a ResNet-50 checkpoint.

On start-up the script downloads the ImageNet class list and an example image
(if they are not already present), loads the model weights from
``model_weights_path`` and launches a Gradio interface around
``classify_image``.
"""

import os
import traceback

import gradio as gr
import torch
from torchvision.models import get_model
from torchvision.transforms import v2
from torchvision.transforms.functional import InterpolationMode

# Remote assets and the local paths they are cached under.
CLASSES_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
CLASSES_PATH = "imagenet_classes.txt"
EXAMPLE_URL = "https://github.com/pytorch/hub/raw/master/images/dog.jpg"
EXAMPLE_PATH = "dog.jpg"

# Number of predictions reported per image.
TOP_K = 5


def _ensure_downloaded(url, path):
    """Download ``url`` to ``path`` unless ``path`` already exists."""
    if not os.path.exists(path):
        # Same helper for every asset: no dependency on an external `wget`
        # binary, and a failed download raises instead of being ignored.
        torch.hub.download_url_to_file(url, path)


def _read_categories(path):
    """Return the class names stored one per line in ``path``."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def load_model(weights_path):
    """Build a ResNet-50 and load custom weights into it.

    Args:
        weights_path: Path to a checkpoint file that holds a dict with a
            ``"model_state_dict"`` entry.

    Returns:
        The model, on CPU and switched to eval mode.
    """
    model = get_model("resnet50", num_classes=1000)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. Consider weights_only=True if the checkpoint
    # contains nothing but tensors and plain containers -- confirm first.
    ckpt = torch.load(weights_path, map_location=torch.device("cpu"))
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
    return model


def classify_image(image):
    """Classify an image and return the top predictions with confidences.

    Args:
        image: The input image as delivered by the Gradio image component
            (configured below with ``type="pil"``), or ``None`` when nothing
            was uploaded.

    Returns:
        A dict mapping class name to probability for the ``TOP_K`` most
        likely classes.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of a transform traceback.
        raise gr.Error("Please upload an image before submitting.")

    batch = transform(image).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        logits = model(batch)

    # The model emits unnormalized scores; softmax turns them into probabilities.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    k = min(TOP_K, probabilities.numel())
    top_prob, top_catid = torch.topk(probabilities, k)
    return {
        categories[class_id]: prob
        for class_id, prob in zip(top_catid.tolist(), top_prob.tolist())
    }


# Fetch the ImageNet-1k class list and the example image on first run.
_ensure_downloaded(CLASSES_URL, CLASSES_PATH)
_ensure_downloaded(EXAMPLE_URL, EXAMPLE_PATH)

# Read the class names once at start-up rather than on every prediction.
categories = _read_categories(CLASSES_PATH)

# Image transformation applied to every input before it reaches the model.
transform = v2.Compose([
    v2.Resize(256, interpolation=InterpolationMode.BILINEAR, antialias=True),
    v2.CenterCrop(224),
    v2.PILToTensor(),
    v2.ToDtype(torch.float, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    v2.ToPureTensor(),
])

# Path to the pre-trained model weights (should be set by the user).
model_weights_path = "best.pth"
model = load_model(model_weights_path)

# Define the Gradio interface.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="Image Recognition using ResNet-50 trained on Imagenet-1K",
    description=(
        "<p style='text-align: center'> Gradio demo for ResNet, Deep residual "
        "networks pre-trained on ImageNet. To use it, simply upload your image, "
        "or click one of the examples to load them. </p>"
    ),
    article=(
        "<p style='text-align: center'> "
        "<a href='https://arxiv.org/abs/1512.03385' target='_blank'>"
        "Deep Residual Learning for Image Recognition</a> | "
        "<a href='https://github.com/KD1994/session-9-imagenet-resnet50' "
        "target='_blank'>Github Repo</a> "
        "</p>"
    ),
    examples=[[EXAMPLE_PATH]],
)

# Report launch failures with a full traceback instead of dying silently.
try:
    iface.launch(share=True)
except Exception as e:
    print(f"Error launching interface: {e}")
    print(traceback.format_exc())