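# Explainable AI demo: Grad-CAM++ heatmaps for the top-3 ImageNet predictions of a
# pretrained ResNet18, served through a Gradio interface.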
import warnings

warnings.filterwarnings("ignore")
import torch
import cv2
import numpy as np
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
import urllib.request
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import gradio as gr

IMG_SIZE = 224
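# Human-readable ImageNet class names bundled with the torchvision weights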
CLASSES = ResNet18_Weights.IMAGENET1K_V1.meta["categories"]
TOP_NUM_CLASSES = 3

# Download two sample images that are offered below as Gradio examples
urllib.request.urlretrieve(
    "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg",
    filename="./cat.jpg",
)
urllib.request.urlretrieve(
    "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg",
    filename="./dog.jpg",
)

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

model = resnet18(pretrained=True)

data_transforms = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def grad_campp(img, cls_ids):
    # Resize to the network resolution and normalize with the ImageNet mean/std
    img_rz = cv2.resize(np.array(img), (IMG_SIZE, IMG_SIZE))
    img = np.float32(img_rz) / 255
    input_tensor = preprocess_image(
        img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    ).to(device)

    # Target the last residual block of layer4 (the final convolutional feature maps)
    target_layers = [model.layer4[-1]]

    # Grad-CAM++ over the chosen target layers
    gradcampp = GradCAMPlusPlus(model=model, target_layers=target_layers)

    # One heatmap per top-ranked class, overlaid on the resized input image
    lst_gradcam = []
    for i in range(TOP_NUM_CLASSES):
        targets = [ClassifierOutputTarget(cls_ids[i])]
        grayscale_gradcampp = gradcampp(
            input_tensor=input_tensor,
            targets=targets,
            eigen_smooth=False,
            aug_smooth=False,
        )
        grayscale_gradcampp = grayscale_gradcampp[0, :]
        gradcampp_image = show_cam_on_image(img, grayscale_gradcampp, use_rgb=True)
        lst_gradcam.append(gradcampp_image)

    return img_rz, lst_gradcam


def do_inference(img):
    # Classify the image, then explain the top predictions with Grad-CAM++
    img_t = data_transforms(img)
    batch_t = torch.unsqueeze(img_t, 0)
    model.eval()
    # Gradients are not needed for inference, so run the forward pass under no_grad
    with torch.no_grad():
        batch_t = batch_t.to(device)
        # forward propagation
        output = model(batch_t)
        # get prediction
        probs = torch.nn.functional.softmax(output, dim=1)
        cls_ids = (
            torch.argsort(probs, dim=1, descending=True).cpu().numpy()[0].astype(int)
        )[:TOP_NUM_CLASSES]
        probs = probs.cpu().numpy()[0]
        probs = probs[cls_ids]
        labels = np.array(CLASSES)[cls_ids]
    img_rz, lst_gradcam = grad_campp(img, cls_ids)
    return (
        {labels[i]: round(float(probs[i]), 2) for i in range(len(labels))},
        img_rz,
        lst_gradcam[0],
        lst_gradcam[1],
        lst_gradcam[2],
    )


# Gradio interface (written against Gradio 3+ components; the older
# gr.inputs / gr.outputs namespaces have been removed from current Gradio)
im = gr.Image(type="pil", image_mode="RGB", label="Input image")

title = "Explainable AI - PyTorch"
description = "Playground: Grad-CAM++ inference of object classification using a ResNet18 model. Libraries: PyTorch, Gradio, Grad-CAM"
examples = [["./cat.jpg"], ["./dog.jpg"]]
article = "<p style='text-align: center'><a href='https://github.com/mawady' target='_blank'>By Dr. Mohamed Elawady</a></p>"
iface = gr.Interface(
    fn=do_inference,
    inputs=im,
    outputs=[
        gr.Label(num_top_classes=TOP_NUM_CLASSES, label="Top predictions"),
        gr.Image(label="Resized input"),
        gr.Image(label="Grad-CAM++ (top-1 class)"),
        gr.Image(label="Grad-CAM++ (top-2 class)"),
        gr.Image(label="Grad-CAM++ (top-3 class)"),
    ],
    live=False,
    title=title,
    description=description,
    article=article,
    examples=examples,
)

iface.launch()