gupta1912 commited on
Commit
bd47fa2
·
1 Parent(s): 6ecbdd9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -0
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, torchvision
2
+ from torchvision import transforms
3
+ from torchvision import datasets
4
+ import numpy as np
5
+ import gradio as gr
6
+ from PIL import Image
7
+ from pytorch_grad_cam import GradCAM
8
+ from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
9
+ import itertools
10
+ import matplotlib.pyplot as plt
11
+ from utils import LitCIFAR10
12
+
13
+ model = LitCIFAR10.load_from_checkpoint("model.ckpt")
14
+ model.eval()
15
+
16
+ classes = ('plane', 'car', 'bird', 'cat',
17
+ 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
18
+
19
+ means = [0.4914, 0.4822, 0.4465]
20
+ stds = [0.2470, 0.2435, 0.2616]
21
+
22
+ cifar_testset = datasets.CIFAR10(root='.', train=False, download=True)
23
+
24
+ transform=transforms.Compose([
25
+ transforms.ToTensor(),
26
+ transforms.Normalize(means, stds)
27
+ ])
28
+
29
+ class ClassifierOutputTarget:
30
+ def __init__(self, category):
31
+ self.category = category
32
+
33
+ def __call__(self, model_output):
34
+ if len(model_output.shape) == 1:
35
+ return model_output[self.category]
36
+ return model_output[:, self.category]
37
+
38
+ def inference(wants_gradcam, n_gradcam, target_layer_number, transparency, wants_misclassified, n_misclassified, input_img = None, n_top_classes=10):
39
+
40
+ if wants_gradcam:
41
+
42
+ outputs_inference_gc = []
43
+ count_gradcam = 1
44
+
45
+ for data, target in cifar_testset:
46
+
47
+ input_tensor = preprocess_image(data,
48
+ mean=means,
49
+ std=stds)
50
+ target_layers = [model.model.layer3[target_layer_number]]
51
+ targets = [ClassifierOutputTarget(target)]
52
+ cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
53
+ grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
54
+ grayscale_cam = grayscale_cam[0, :]
55
+ rgb_img = np.float32(data) / 255
56
+ visualization = np.array(show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency))
57
+ outputs_inference_gc.append(visualization)
58
+
59
+ count_gradcam += 1
60
+ if count_gradcam > n_gradcam:
61
+ break
62
+ else:
63
+ outputs_inference_gc = None
64
+
65
+ if wants_misclassified:
66
+ outputs_inference_mis = []
67
+ count_mis = 1
68
+
69
+ for data_, target in cifar_testset:
70
+
71
+ data = transform(data_)
72
+ data = data.unsqueeze(0)
73
+ output = model(data)
74
+ pred = output.argmax(dim=1, keepdim=True)
75
+
76
+ if pred.item()!=target:
77
+
78
+ count_mis += 1
79
+
80
+
81
+ fig = plt.figure()
82
+ fig.add_subplot(111)
83
+
84
+ plt.imshow(data_)
85
+ plt.title(f'Target: {classes[target]}\nPred: {classes[pred.item()]}')
86
+ plt.axis('off')
87
+
88
+ fig.canvas.draw()
89
+
90
+ fig_img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
91
+ fig_img = fig_img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
92
+
93
+ plt.close(fig)
94
+
95
+ outputs_inference_mis.append(fig_img)
96
+
97
+ if count_mis > n_misclassified:
98
+ break
99
+
100
+ else:
101
+ outputs_inference_mis = None
102
+
103
+ if input_img is not None:
104
+ data = transform(input_img)
105
+ data = data.unsqueeze(0)
106
+ output = model(data)
107
+ softmax = torch.nn.Softmax(dim=0)
108
+ o = softmax(output.flatten())
109
+ confidences = {classes[i]: float(o[i]) for i in range(10)}
110
+ _, prediction = torch.max(output, 1)
111
+
112
+ confidences = {k: v for k, v in sorted(confidences.items(), key=lambda item: item[1], reverse=True)}
113
+ confidences = dict(itertools.islice(confidences.items(), n_top_classes))
114
+ else:
115
+ confidences = None
116
+
117
+
118
+ return outputs_inference_gc, outputs_inference_mis, confidences
119
+
120
+
121
+ title = "CIFAR10 trained on Custom ResNet Model with GradCAM"
122
+ description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
123
+ examples = [[None, None, None, None, None, None, 'examples/test_'+str(i)+'.jpg', None] for i in range(10)]
124
+
125
+ demo = gr.Interface(inference,
126
+ inputs = [gr.Checkbox(False, label='Do you want to see GradCAM outputs?'),
127
+ gr.Slider(0, 10, value = 0, step=1, label="How many?"),
128
+ gr.Slider(-2, -1, value = -2, step=1, label="Which target layer?"),
129
+ gr.Slider(0, 1, value = 0, label="Opacity of GradCAM"),
130
+ gr.Checkbox(False, label='Do you want to see misclassified images?'),
131
+ gr.Slider(0, 10, value = 0, step=1, label="How many?"),
132
+ gr.Image(shape=(32, 32), label="Input image"),
133
+ gr.Slider(0, 10, value = 0, step=1, label="How many top classes you want to see?")
134
+ ],
135
+ outputs = [
136
+ gr.Gallery(label="GradCAM Outputs", show_label=True, elem_id="gallery").style(columns=[2], rows=[2], object_fit="contain", height="auto"),
137
+ gr.Gallery(label="Misclassified Images", show_label=True, elem_id="gallery").style(columns=[2], rows=[2], object_fit="contain", height="auto"),
138
+ gr.Label(num_top_classes=None)
139
+ ],
140
+ title = title,
141
+ description = description,
142
+ examples = examples
143
+ )
144
+ demo.launch()