File size: 5,972 Bytes
aaad6b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a898d5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1b499f
c480201
 
 
bc0c135
 
c480201
c1b499f
e4cf61a
cc3a014
 
 
 
c1b499f
 
aaad6b0
69d871d
bc0c135
144a4d9
8ad0298
ec10c40
 
 
 
 
 
405db4f
 
 
c1b499f
ec10c40
c480201
 
 
 
 
 
 
6ced9ca
e204026
bc0c135
 
 
c480201
 
 
 
f690810
a898d5e
 
341395d
8ad0298
a898d5e
 
 
 
c1b499f
a898d5e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import torch, torchvision
from torchvision import transforms
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from resnet import ResNet18
import gradio as gr

# Build the network and load the trained weights onto the CPU.
# NOTE(review): strict=False silently skips missing/unexpected keys in
# model.pth; confirm the checkpoint really matches ResNet18().
model = ResNet18()
model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
# Put the network in inference mode so normalization/dropout layers (if
# ResNet18 has any) use their evaluation behaviour instead of per-batch
# training statistics when classifying a single uploaded image.
model.eval()

# Inverse of Normalize(mean=0.5, std=0.23) per channel:
# (x - (-0.5/0.23)) / (1/0.23) == x * 0.23 + 0.5.
# NOTE(review): the image preprocessing in this file applies only
# ToTensor() (no Normalize); confirm that matches how the model was trained.
inv_normalize = transforms.Normalize(
    mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
    std=[1/0.23, 1/0.23, 1/0.23]
)
# CIFAR-10 class names, in the order of the model's output logits.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

def inference(input_img, transparency = 0.5, target_layer_number = -1):
    """Classify an image and overlay a Grad-CAM heatmap on it.

    Args:
        input_img: image array as delivered by the Gradio ``Image`` input.
            It is converted with ``transforms.ToTensor()`` for the model and,
            divided by 255, used as the base image for the overlay
            (presumably an HxWx3 uint8 array — confirm against the caller).
        transparency: passed as ``image_weight`` to ``show_cam_on_image``.
        target_layer_number: index into ``model.layer2`` selecting the layer
            whose activations Grad-CAM uses; coerced to ``int`` because UI
            sliders may deliver floats.

    Returns:
        A tuple ``(confidences, visualization)``. ``confidences`` maps each
        class name in ``classes`` to its softmax probability.
        ``visualization`` is the overlay returned by ``show_cam_on_image``.
    """
    # Single-image batch of shape (1, C, H, W).
    # NOTE(review): no Normalize is applied here; confirm this matches training.
    tensor = transforms.ToTensor()(input_img).unsqueeze(0)

    # Plain forward pass for the class probabilities; no autograd graph needed.
    with torch.no_grad():
        outputs = model(tensor)
    probs = torch.softmax(outputs.flatten(), dim=0)
    confidences = {name: float(p) for name, p in zip(classes, probs)}

    target_layers = [model.layer2[int(target_layer_number)]]
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
    # targets=None lets pytorch-grad-cam explain the highest-scoring class.
    grayscale_cam = cam(input_tensor=tensor, targets=None)[0, :]

    visualization = show_cam_on_image(input_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
    return confidences, visualization

def inference_confidences(input_img, transparency = 0.5, target_layer_number = -1):
    """Return the model's class probabilities for a single image.

    Args:
        input_img: image accepted by ``transforms.ToTensor()`` (presumably
            the HxWx3 array from the Gradio ``Image`` input — confirm).
        transparency: unused; accepted so the signature matches
            ``inference`` and ``inference_visualization``.
        target_layer_number: unused; see ``transparency``.

    Returns:
        Dict mapping each class name in ``classes`` to its softmax
        probability as a Python float.
    """
    # Single-image batch of shape (1, C, H, W).
    tensor = transforms.ToTensor()(input_img).unsqueeze(0)
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        outputs = model(tensor)
    probs = torch.softmax(outputs.flatten(), dim=0)
    return {name: float(p) for name, p in zip(classes, probs)}

def inference_visualization(input_img, transparency = 0.5, target_layer_number = -1):
    """Return a Grad-CAM heatmap for ``input_img`` blended onto the image.

    Args:
        input_img: image array as delivered by the Gradio ``Image`` input.
            It is converted with ``transforms.ToTensor()`` for the model and,
            divided by 255, used as the base image for the overlay
            (presumably an HxWx3 uint8 array — confirm against the caller).
        transparency: passed as ``image_weight`` to ``show_cam_on_image``.
        target_layer_number: index into ``model.layer2`` selecting the layer
            whose activations Grad-CAM uses; coerced to ``int`` because UI
            sliders may deliver floats.

    Returns:
        The overlay image produced by ``show_cam_on_image``.
    """
    # Single-image batch of shape (1, C, H, W).
    tensor = transforms.ToTensor()(input_img).unsqueeze(0)

    target_layers = [model.layer2[int(target_layer_number)]]
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
    # targets=None lets pytorch-grad-cam explain the highest-scoring class,
    # so no separate forward pass / argmax is needed here.
    grayscale_cam = cam(input_tensor=tensor, targets=None)[0, :]

    return show_cam_on_image(input_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)

# Callback function for the Gradio interface
def gradio_callback(view_grad_cam, num_gradcam_images, view_misclassified, num_misclassified_images,
                    input_img, transparency = 0.5, target_layer_number = -1):
    """Classify ``input_img`` and optionally produce a Grad-CAM overlay.

    Args:
        view_grad_cam: the string "Yes" requests the Grad-CAM overlay; any
            other value (including None when nothing is selected) skips it.
        num_gradcam_images: currently unused.
        view_misclassified: currently unused.
        num_misclassified_images: currently unused.
        input_img: image from the Gradio ``Image`` input.
        transparency: overlay image weight, forwarded to the visualizer.
        target_layer_number: index into ``model.layer2`` for Grad-CAM;
            coerced to ``int`` because the slider may deliver a float.

    Returns:
        ``(confidences, visualization)``. ``visualization`` is None when no
        overlay was requested, so the interface always receives exactly one
        value per declared output.
    """
    layer = int(target_layer_number)
    confidence = inference_confidences(input_img, transparency=transparency, target_layer_number=layer)
    if view_grad_cam != "Yes":
        # Two outputs are declared; an empty image slot clears the viewer.
        return confidence, None
    visualization = inference_visualization(input_img, transparency=transparency, target_layer_number=layer)
    return confidence, visualization


title = "CIFAR10 trained on ResNet18 Model with GradCAM"
description = "Gradio interface to infer on ResNet18 model, and get GradCAM results"
# One row per example; values follow the order of `inputs` below:
# [view GradCAM, #GradCAM images, view misclassified, #misclassified,
#  image path, opacity, layer index].
examples = [["Yes",5,"Yes",5,"cat.jpg", 0.5, -1], ["Yes",5,"Yes",5,"dog.jpg", 0.5, -1]]

demo = gr.Interface(
    fn=gradio_callback,
    # Order must match gradio_callback's positional parameters.
    inputs=[
        gr.Radio(["Yes", "No"], label="GradCAM images", info="View GradCAM images?"),
        gr.Number(label="Number of GradCAM images to view", value=5, maximum=10),
        gr.Radio(["Yes", "No"], label="View misclassified images?"),
        gr.Number(label="Number of misclassified images to view", value=5, minimum=1, maximum=10),
        gr.Image(shape=(32, 32), label="Input Image"),
        gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"),
        gr.Slider(-2, -1, value = -2, step=1, label="Which Layer?"),
    ],
    # gradio_callback always returns (confidences, image-or-None).
    outputs = [gr.Label(num_top_classes=3), gr.Image(shape=(32, 32), label="Output").style(width=128, height=128)],
    title=title,
    description=description,
    examples=examples,
)

demo.launch()