"""Gradio demo: ImageNet-1K classification with a custom-trained ResNet-50."""
import os
import traceback

import gradio as gr
import torch
from torchvision.models import get_model
from torchvision.transforms import transforms
from torchvision.transforms.functional import InterpolationMode

# Number of highest-probability classes returned and displayed.
TOP_K = 5

# Path to the pre-trained model weights (should be set by the user).
model_weights_path = "best.pth"

# One class name per line, in the order of the model's output indices
# (assumed to match the checkpoint's class ordering — confirm).
CLASSES_PATH = "imagenet_classes.txt"

# Directory holding the example images shown under the input widget.
EXAMPLES_DIR = "examples"


def load_model(weights_path, num_classes=1000):
    """Build a ResNet-50 and load custom weights from *weights_path*.

    The checkpoint may be a dict storing the weights under
    ``"model_state_dict"`` (the format this app was written for) or a bare
    state dict. Tensors are mapped to the CPU and the model is put in
    evaluation mode.

    Args:
        weights_path: Path to a checkpoint file readable by ``torch.load``.
        num_classes: Size of the classification head; must match the
            checkpoint (default 1000, as for ImageNet-1K).

    Returns:
        The loaded model, in eval mode.

    Raises:
        RuntimeError: If the state dict does not fit the architecture
            (raised by ``load_state_dict``).
    """
    model = get_model("resnet50", num_classes=num_classes)
    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # trusted checkpoints. Consider weights_only=True if the file allows it.
    ckpt = torch.load(weights_path, map_location=torch.device("cpu"))
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    else:
        state_dict = ckpt
    model.load_state_dict(state_dict)
    model.eval()
    return model


def load_categories(path):
    """Return the class names listed one per line in *path*, stripped of whitespace."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


# Define image transformation to match the model input.
transform = transforms.Compose([
    transforms.Resize(256, interpolation=InterpolationMode.BILINEAR, antialias=True),
    transforms.CenterCrop(224),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def classify_image(image):
    """Classify *image* and return its most likely classes with probabilities.

    Args:
        image: A PIL image from the Gradio ``Image`` component, or ``None``
            (e.g. when the input is cleared while ``live=True``).

    Returns:
        Dict mapping class name -> softmax probability for the ``TOP_K`` most
        likely classes, the format ``gr.Label`` consumes. Empty when *image*
        is ``None``.
    """
    if image is None:
        return {}
    # Normalize expects exactly 3 channels; RGBA or grayscale uploads would
    # otherwise fail, so force RGB (a no-op for images that already are).
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        logits = model(batch)
    # The output has unnormalized scores; softmax turns them into probabilities.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    k = min(TOP_K, probabilities.numel())
    top_prob, top_catid = torch.topk(probabilities, k)
    return {
        categories[idx]: prob
        for prob, idx in zip(top_prob.tolist(), top_catid.tolist())
    }


model = load_model(model_weights_path)
# Loaded once at startup instead of on every prediction.
categories = load_categories(CLASSES_PATH)

# Define the Gradio interface.
iface = gr.Interface(
    fn=classify_image,  # The function to run on input
    inputs=gr.Image(type="pil"),  # Image input (in PIL format)
    outputs=gr.Label(num_top_classes=TOP_K),  # Top classes with confidence scores
    title="Image Recognition using ResNet-50 trained on Imagenet-1K",
    live=True,
    description=(
        "Gradio demo for ResNet, Deep residual networks pre-trained on ImageNet. "
        "To use it, simply upload your image, or click one of the examples to load them."
    ),
    # TODO: the original article linked the paper and the GitHub repository;
    # those URLs were lost from the source, so only the link text remains.
    article="Deep Residual Learning for Image Recognition | Github Repo",
    examples=[
        [os.path.join(EXAMPLES_DIR, name)]
        for name in (
            "dog.jpg",
            "great-white-shark.jpg",
            "american-goldfinch.jpg",
            "hognose-snake.jpg",
        )
    ],
)

# Add error handling to launch.
try:
    iface.launch(share=True)
except Exception as e:
    print(f"Error launching interface: {str(e)}")
    print(traceback.format_exc())