import torch
import numpy as np
import gradio as gr
from torchvision import transforms
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from resnet import ResNet18
model = ResNet18()
model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
model.eval()  # inference mode for BatchNorm/Dropout; GradCAM still gets gradients via its backward hooks
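# `inv_normalize` below undoes a transforms.Normalize(mean=0.5, std=0.23) step
# (assuming those were the training-time statistics, as the constants suggest):
# if y = (x - 0.5) / 0.23 then x = y * 0.23 + 0.5, which Normalize expresses as
# mean' = -mean/std = -0.50/0.23 and std' = 1/std = 1/0.23.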
inv_normalize = transforms.Normalize(
    mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
    std=[1/0.23, 1/0.23, 1/0.23]
)
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
def inference(input_img, transparency=0.5, target_layer_number=-1):
    transform = transforms.ToTensor()
    org_img = input_img
    input_img = transform(input_img).unsqueeze(0)
    outputs = model(input_img)
    softmax = torch.nn.Softmax(dim=0)
    o = softmax(outputs.flatten())
    confidences = {classes[i]: float(o[i]) for i in range(10)}
    _, prediction = torch.max(outputs, 1)
    target_layers = [model.layer2[target_layer_number]]
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
    grayscale_cam = cam(input_tensor=input_img, targets=None)
    grayscale_cam = grayscale_cam[0, :]
    # De-normalized copy of the network input (convert to numpy before
    # transposing; the overlay below uses the original uint8 image instead)
    img = inv_normalize(input_img.squeeze(0))
    rgb_img = np.transpose(img.detach().numpy(), (1, 2, 0))
    visualization = show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
    return confidences, visualization
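# Optional local sanity check for `inference` (a hypothetical helper, not wired
# into the Gradio app; assumes a "cat.jpg" exists next to this script, matching
# the examples list further below):
def run_local_check(image_path="cat.jpg"):
    img = np.array(Image.open(image_path).convert("RGB").resize((32, 32)))
    confidences, overlay = inference(img)
    print("top class:", max(confidences, key=confidences.get))
    return overlay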
def inference_confidences(input_img, transparency=0.5, target_layer_number=-1):
    # Confidence-only half of `inference`; the extra arguments are accepted for
    # a uniform signature but unused here
    transform = transforms.ToTensor()
    input_img = transform(input_img).unsqueeze(0)
    outputs = model(input_img)
    softmax = torch.nn.Softmax(dim=0)
    o = softmax(outputs.flatten())
    confidences = {classes[i]: float(o[i]) for i in range(10)}
    return confidences
def inference_visualization(input_img, transparency=0.5, target_layer_number=-1):
    # GradCAM-only half of `inference`: returns just the heatmap overlay
    transform = transforms.ToTensor()
    org_img = input_img
    input_img = transform(input_img).unsqueeze(0)
    target_layers = [model.layer2[target_layer_number]]
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
    grayscale_cam = cam(input_tensor=input_img, targets=None)[0, :]
    visualization = show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
    return visualization
# Callback function for the Gradio interface
def gradio_callback(view_grad_cam, num_gradcam_images, view_misclassified, num_misclassified_images,
                    input_img, transparency=0.5, target_layer_number=-1):
    confidence = inference_confidences(input_img, transparency=transparency, target_layer_number=target_layer_number)
    if view_grad_cam == "Yes":
        visualization = inference_visualization(input_img, transparency=transparency, target_layer_number=target_layer_number)
        return confidence, visualization
    # Always return a value for every output component; Gradio shows None as an empty image
    return confidence, None
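# The callback's two return values map positionally onto the `outputs` list of
# the Interface below: confidences -> gr.Label, visualization -> gr.Image.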
title = "CIFAR10 trained on ResNet18 Model with GradCAM"
description = "Gradio interface to infer on ResNet18 model, and get GradCAM results"
examples = [["Yes",5,"Yes",5,"cat.jpg", 0.5, -1], ["Yes",5,"Yes",5,"dog.jpg", 0.5, -1]]
demo = gr.Interface(
    fn=gradio_callback,
    inputs=[
        gr.Radio(["Yes", "No"], label="GradCAM images", info="View GradCAM images?"),
        gr.Number(value=5, label="Number of GradCAM images to view"),  # intended range: up to 10 (not enforced)
        gr.Radio(["Yes", "No"], label="View misclassified images?"),
        gr.Number(value=5, label="Number of misclassified images to view"),  # intended range: 1-10 (not enforced)
        gr.Image(shape=(32, 32), label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM"),
        gr.Slider(-2, -1, value=-2, step=1, label="Which Layer?")
    ],
    outputs=[gr.Label(num_top_classes=3), gr.Image(shape=(32, 32), label="Output").style(width=128, height=128)],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()