Spaces:
Sleeping
Sleeping
File size: 5,384 Bytes
bd47fa2 3e1ae80 bd47fa2 3e1ae80 bd47fa2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import torch, torchvision
from torchvision import transforms
from torchvision import datasets
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
import itertools
import matplotlib.pyplot as plt
from utils import LitCIFAR10
# Human-readable class names; looked up by integer label / prediction index.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Per-channel mean / std handed to every normalization step in this file.
# NOTE(review): presumably the CIFAR-10 training-set statistics -- confirm.
means = [0.4914, 0.4822, 0.4465]
stds = [0.2470, 0.2435, 0.2616]

# Restore the trained network and switch it to evaluation mode.
model = LitCIFAR10.load_from_checkpoint("model/model.ckpt")
model.eval()

# Test split (downloaded into the working directory when missing). No
# transform is attached here; callers apply `transform` themselves.
cifar_testset = datasets.CIFAR10(root='.', train=False, download=True)

# Image -> normalized tensor pipeline used before feeding the model.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means, stds),
])
class ClassifierOutputTarget:
    """Callable that picks one category's score out of a model output.

    A 1-D output is indexed directly with the category; any other output is
    indexed as ``output[:, category]``, i.e. the category is selected along
    the second axis for every entry of the first.
    """

    def __init__(self, category):
        # Index of the category whose score should be extracted.
        self.category = category

    def __call__(self, model_output):
        is_flat = len(model_output.shape) == 1
        # `(slice(None), c)` is exactly the key produced by `[:, c]`.
        key = self.category if is_flat else (slice(None), self.category)
        return model_output[key]
def _gradcam_gallery(n_gradcam, target_layer_number, transparency):
    """Return GradCAM overlays for the first `n_gradcam` test-set images."""
    n_images = int(n_gradcam)
    if n_images <= 0:
        # Asking for zero images must yield zero images.
        return []

    # The target layer does not depend on the image, so build the CAM object
    # once; the context manager releases its hooks on the model afterwards.
    target_layers = [model.model.layer3[int(target_layer_number)]]
    overlays = []
    with GradCAM(model=model, target_layers=target_layers, use_cuda=False) as cam:
        for image, label in itertools.islice(cifar_testset, n_images):
            input_tensor = preprocess_image(image, mean=means, std=stds)
            # Explain the ground-truth class of each image.
            targets = [ClassifierOutputTarget(label)]
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
            rgb_img = np.float32(image) / 255
            overlay = show_cam_on_image(
                rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency
            )
            overlays.append(np.array(overlay))
    return overlays


def _render_prediction(image, label, predicted):
    """Draw `image` titled with its target and predicted class; return RGB pixels."""
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.imshow(image)
        ax.set_title(f'Target: {classes[label]}\nPred: {classes[predicted]}')
        ax.axis('off')
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb() and already
        # carries the right (height, width, 4) shape; drop the alpha channel.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return np.ascontiguousarray(rgba[..., :3])
    finally:
        # Always free the figure, even when drawing fails.
        plt.close(fig)


def _misclassified_gallery(n_misclassified):
    """Return rendered figures for the first `n_misclassified` wrong predictions."""
    limit = int(n_misclassified)
    figures = []
    if limit <= 0:
        return figures

    for image, label in cifar_testset:
        batch = transform(image).unsqueeze(0)
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            predicted = model(batch).argmax(dim=1, keepdim=True).item()
        if predicted == label:
            continue
        figures.append(_render_prediction(image, label, predicted))
        if len(figures) >= limit:
            break
    return figures


def _top_confidences(input_img, n_top_classes):
    """Classify `input_img`; return the highest class probabilities, best first."""
    batch = transform(input_img).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.softmax(logits.flatten(), dim=0)
    scores = {name: float(p) for name, p in zip(classes, probabilities)}
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    if n_top_classes is not None:
        # None (as passed by the example rows) means "keep every class".
        ranked = ranked[:max(int(n_top_classes), 0)]
    return dict(ranked)


def inference(wants_gradcam, n_gradcam, target_layer_number, transparency, wants_misclassified, n_misclassified, input_img = None, n_top_classes=10):
    """Run the three optional parts of the demo and return their results.

    Args:
        wants_gradcam: When truthy, produce GradCAM overlays.
        n_gradcam: Number of test-set images to produce overlays for.
        target_layer_number: Index into ``model.model.layer3`` selecting the
            layer GradCAM is computed on.
        transparency: Passed to ``show_cam_on_image`` as ``image_weight``.
        wants_misclassified: When truthy, collect misclassified test images.
        n_misclassified: Number of misclassified images to collect.
        input_img: Optional image to classify; skipped when ``None``.
        n_top_classes: How many of the most probable classes to report;
            ``None`` reports all of them.

    Returns:
        A 3-tuple ``(gradcam_images, misclassified_images, confidences)``.
        Each element is ``None`` when its part was not requested; otherwise
        the first two are lists of image arrays and the last is a dict
        mapping class name to probability, ordered from most to least likely.
    """
    outputs_inference_gc = (
        _gradcam_gallery(n_gradcam, target_layer_number, transparency)
        if wants_gradcam else None
    )
    outputs_inference_mis = (
        _misclassified_gallery(n_misclassified) if wants_misclassified else None
    )
    confidences = (
        _top_confidences(input_img, n_top_classes)
        if input_img is not None else None
    )
    return outputs_inference_gc, outputs_inference_mis, confidences
title = "CIFAR10 trained on Custom ResNet Model with GradCAM"
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"

# One example row per bundled test image. Only the image input is filled in;
# every other input is left as None, which `inference` treats as "not requested".
examples = [
    [None, None, None, None, None, None, f'Images/test_{i}.jpg', None]
    for i in range(10)
]

# Input order must match the positional parameters of `inference`.
demo = gr.Interface(
    inference,
    inputs=[
        gr.Checkbox(False, label='Do you want to see GradCAM outputs?'),
        gr.Slider(0, 10, value=0, step=1, label="How many?"),
        gr.Slider(-2, -1, value=-2, step=1, label="Which target layer?"),
        gr.Slider(0, 1, value=0, label="Opacity of GradCAM"),
        gr.Checkbox(False, label='Do you want to see misclassified images?'),
        gr.Slider(0, 10, value=0, step=1, label="How many?"),
        gr.Image(shape=(32, 32), label="Input image"),
        gr.Slider(0, 10, value=0, step=1, label="How many top classes you want to see?"),
    ],
    outputs=[
        # Each gallery gets its own elem_id: sharing one id between two
        # components produces duplicate DOM ids in the rendered page.
        gr.Gallery(label="GradCAM Outputs", show_label=True, elem_id="gradcam_gallery").style(columns=[2], rows=[2], object_fit="contain", height="auto"),
        gr.Gallery(label="Misclassified Images", show_label=True, elem_id="misclassified_gallery").style(columns=[2], rows=[2], object_fit="contain", height="auto"),
        gr.Label(num_top_classes=None),
    ],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()