File size: 3,237 Bytes
71413cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import traceback
from functools import lru_cache

import gradio as gr
import torch
from torchvision.models import get_model
from torchvision.transforms import v2
from torchvision.transforms.functional import InterpolationMode


# Imagenet-1k class names, one per line; read at prediction time.
# torch.hub.download_url_to_file is used instead of shelling out to `wget`:
# it needs no external binary and raises on failure, so a broken download
# stops the script here instead of surfacing later as a missing file.
if not os.path.exists("imagenet_classes.txt"):
    torch.hub.download_url_to_file(
        "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
        "imagenet_classes.txt",
    )

# Download an example image from the pytorch website
if not os.path.exists("dog.jpg"):
    torch.hub.download_url_to_file("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")


# Function to load the model with custom weights
def load_model(weights_path, arch="resnet50", num_classes=1000):
    """Build a torchvision model and load trained weights into it.

    Args:
        weights_path: Path to a file written by ``torch.save``. It may be a
            training checkpoint dict holding the weights under
            ``"model_state_dict"``, or a bare ``state_dict``.
        arch: Model name passed to torchvision's ``get_model``.
        num_classes: Size of the classifier head; must match the checkpoint.

    Returns:
        The model with the loaded weights, on CPU, in eval mode.
    """
    model = get_model(arch, num_classes=num_classes)
    # map_location keeps GPU-trained checkpoints loadable on CPU-only hosts.
    ckpt = torch.load(weights_path, map_location=torch.device("cpu"))

    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    else:
        state_dict = ckpt

    # Weights saved from a DataParallel/DDP wrapper carry a "module." prefix
    # on every key; strip it so the keys match the unwrapped model.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    model.load_state_dict(state_dict)
    model.eval()
    return model


@lru_cache(maxsize=None)
def _load_categories(path="imagenet_classes.txt"):
    """Return the class names stored one per line in *path* (read once, then cached)."""
    with open(path, "r", encoding="utf-8") as f:
        return tuple(line.strip() for line in f)


# Function for making predictions and returning top 5 predictions with confidence
def classify_image(image):
    """Classify *image* and return its five most likely classes.

    Args:
        image: A PIL image, as configured by ``gr.Image(type="pil")``, or
            ``None`` when the user submits without uploading an image.

    Returns:
        A dict mapping class name to softmax probability for the top-5
        classes, or an empty dict when no image was provided.
    """
    if image is None:
        return {}

    # Uploads can be RGBA, palette or greyscale (e.g. PNGs); the 3-channel
    # Normalize step and the model's first conv both need RGB input.
    image = image.convert("RGB")

    batch = transform(image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        logits = model(batch)

    # The output has unnormalized scores; softmax turns them into probabilities.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    categories = _load_categories()
    top5_prob, top5_catid = torch.topk(probabilities, 5)
    return {
        categories[cat_id]: prob
        for cat_id, prob in zip(top5_catid.tolist(), top5_prob.tolist())
    }

# Preprocessing applied to every input before it reaches the model:
# resize (shorter edge to 256), center-crop to 224x224, convert the PIL image
# to a float tensor scaled to [0, 1], then normalize per channel.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = v2.Compose(
    [
        v2.Resize(256, interpolation=InterpolationMode.BILINEAR, antialias=True),
        v2.CenterCrop(224),
        v2.PILToTensor(),
        v2.ToDtype(torch.float, scale=True),
        v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        v2.ToPureTensor(),
    ]
)

# Path to the trained model weights. Defaults to "best.pth" in the working
# directory; set the MODEL_WEIGHTS_PATH environment variable to use another
# file without editing this script.
model_weights_path = os.environ.get("MODEL_WEIGHTS_PATH", "best.pth")
model = load_model(model_weights_path)

# Gradio UI: one PIL image in, a top-5 label/confidence widget out.
_DESCRIPTION = "<p style='text-align: center'> Gradio demo for ResNet, Deep residual networks pre-trained on ImageNet. To use it, simply upload your image, or click one of the examples to load them. </p>"
# NOTE: the backslash line continuations below are inside the string literal,
# so the leading spaces of the continued lines are part of the HTML text.
_ARTICLE = "<p style='text-align: center'> \
        <a href='https://arxiv.org/abs/1512.03385' target='_blank'>Deep Residual Learning for Image Recognition</a> | \
            <a href='https://github.com/KD1994/session-9-imagenet-resnet50' target='_blank'>Github Repo</a> \
                </p>"

iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="Image Recognition using ResNet-50 trained on Imagenet-1K",
    description=_DESCRIPTION,
    article=_ARTICLE,
    examples=[["dog.jpg"]],
)

# Launch the app with a public share link. If launching fails, print the
# traceback and message to stderr and exit with a non-zero status, so process
# supervisors can see the failure instead of a clean exit.
try:
    iface.launch(share=True)
except Exception as e:
    traceback.print_exc()
    raise SystemExit(f"Error launching interface: {e}") from e