# Spaces: Running
import os | |
from io import BytesIO | |
from pathlib import Path | |
from random import shuffle | |
import cv2 | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
from mini_resnet import CustomResNet | |
from PIL import Image | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
from torchvision import transforms as T | |
# Per-channel statistics (three values each) used to normalize model inputs.
mean = (0.49139968, 0.48215841, 0.44653091)
std = (0.24703223, 0.24348513, 0.26158784)

# Preprocessing pipeline: convert the image to a tensor, then normalize it
# with the statistics above.
transforms = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ]
)

# Class names; position i names the model's i-th output.
classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)

# Softmax over dimension 0, i.e. meant for a flattened 1-D vector of logits.
softmax = torch.nn.Softmax(dim=0)
# Build the network and restore its weights, mapping all tensors to the CPU.
model = CustomResNet()
model.load_state_dict(
    torch.load(
        "model_weights/weights.pt",
        map_location=torch.device("cpu"),
    )
)
model.eval()  # switch the module to evaluation mode

# Every entry of the misclassified-examples directory, collected once at import.
misclf_path = "images/miss_classified"
mis_classified_imgs = [*Path(misclf_path).glob("*")]
def get_traget_layer(block: str, layer: int, net=None):
    """Return the layer of a network block that Grad-CAM should hook.

    Args:
        block: One of ``"block1"``, ``"block2"`` or ``"block3"``; these map to
            the ``layer1``, ``layer2`` and ``layer3`` attributes of the network.
        layer: ``0`` selects the first element of the block; any other value
            selects the last element.
        net: Object exposing ``layer1``/``layer2``/``layer3``. Defaults to the
            module-level ``model``.

    Returns:
        The selected element of the requested block.

    Raises:
        ValueError: If ``block`` is not a known block name. Previously this
            case fell through and returned ``None``, which only failed later
            when the result was used as a Grad-CAM target layer.
    """
    block_attrs = {"block1": "layer1", "block2": "layer2", "block3": "layer3"}
    attr = block_attrs.get(block)
    if attr is None:
        raise ValueError(
            f"Unknown block {block!r}; expected one of {sorted(block_attrs)}"
        )

    net = model if net is None else net
    layer_num = 0 if layer == 0 else -1
    return getattr(net, attr)[layer_num]
# Grad-CAM instance applied to the user-supplied image; it targets the layer
# returned for ("block3", -1).
default_cam = GradCAM(
    model=model,
    target_layers=[get_traget_layer(block="block3", layer=-1)],
)
def make_image(p: Path | str, pred: str, label: str):
    """Render an image with a ``"<pred> / <label>"`` title and return the plot.

    The image at ``p`` is read with OpenCV, resized to 64x64, drawn on a
    matplotlib figure with the axes hidden, saved as PNG into memory, and the
    PNG is decoded back into an array with ``cv2.imdecode``.

    Args:
        p: Path of the image file to render.
        pred: Predicted class name shown before the slash in the title.
        label: Ground-truth class name shown after the slash in the title.

    Returns:
        The rendered figure as decoded by ``cv2.imdecode(..., cv2.IMREAD_COLOR)``.

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from ``p``.
    """
    im = cv2.imread(str(p))
    if im is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {p}")
    im = cv2.resize(im, (64, 64))

    # NOTE: the array from cv2.imread is drawn as-is and the PNG is decoded
    # again by OpenCV, exactly as before; channel handling is intentionally
    # left unchanged.
    #
    # A dedicated figure is used (and always closed) so repeated calls do not
    # keep piling artists onto pyplot's implicit global figure.
    fig, ax = plt.subplots()
    try:
        ax.imshow(im)
        ax.set_title(f"{pred} / {label}")
        ax.axis("off")
        with BytesIO() as buffer:
            fig.savefig(buffer, format="png")
            # getvalue() returns a bytes copy, so it outlives the buffer.
            img_array = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
    finally:
        plt.close(fig)

    # Decode the PNG bytes back into an image array using OpenCV.
    return cv2.imdecode(img_array, cv2.IMREAD_COLOR)
def predict_img(img: torch.Tensor, top_k: int = 10):
    """Classify one preprocessed image and return its top-k class scores.

    Args:
        img: Input passed straight to ``model``. The output is flattened
            before the softmax, so a single-image batch is expected.
        top_k: Number of highest-scoring classes to keep. It is cast to
            ``int`` so that a float coming from a UI slider still slices.

    Returns:
        A dict mapping class name to softmax score, ordered from the highest
        score to the lowest and truncated to ``top_k`` entries.
    """
    # Pure inference: no gradients are needed, so skip building the graph.
    with torch.no_grad():
        probs = softmax(model(img).flatten())

    # Pair names with scores directly instead of hard-coding the class count.
    scores = {name: float(prob) for name, prob in zip(classes, probs)}
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[: int(top_k)])
def display_cam(cam: GradCAM, org_img: np.ndarray, img: torch.Tensor, transparency: float):
    """Overlay the Grad-CAM heatmap computed for ``img`` on ``org_img``.

    Args:
        cam: Grad-CAM object called with ``input_tensor=img`` and no explicit
            targets.
        org_img: Image the heatmap is blended onto; it is divided by 255
            before blending.
        img: Input tensor handed to ``cam``.
        transparency: Forwarded as ``image_weight`` to ``show_cam_on_image``.

    Returns:
        The blended visualization produced by ``show_cam_on_image``.
    """
    # Only the first heatmap of the returned batch is used.
    heatmap = cam(input_tensor=img, targets=None)[0, :]
    return show_cam_on_image(
        org_img / 255,
        heatmap,
        use_rgb=True,
        image_weight=transparency,
    )
def _cam_gallery(paths, cam_block, target_layer_num, transparency):
    """Return one Grad-CAM overlay per image in ``paths`` for the chosen layer."""
    target_layers = [get_traget_layer(cam_block, target_layer_num)]
    overlays = []
    # The context manager releases the CAM's hooks even if an image fails.
    # ``use_cuda`` is omitted, matching how ``default_cam`` is constructed.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        for path in paths:
            # Load as RGB (like the misclassified-images branch) so that the
            # per-channel normalization and ``use_rgb=True`` overlay see the
            # channel order they expect; cv2.imread would yield BGR.
            rgb = np.array(Image.open(path).convert("RGB"))
            batch = transforms(rgb).unsqueeze(0)
            overlays.append(display_cam(cam, rgb, batch, transparency))
    return overlays


def _misclassified_gallery(paths):
    """Return titled "pred / label" renderings for the images in ``paths``."""
    if not paths:
        # torch.stack raises on an empty list; nothing to show in that case.
        return []

    batch = torch.stack([transforms(Image.open(path).convert("RGB")) for path in paths])
    with torch.no_grad():
        # argmax over the class axis of the raw outputs. The module-level
        # softmax works on dim 0, which on an (N, classes) batch normalizes
        # across images and can change each row's argmax, so it is not used.
        pred_indices = model(batch).argmax(dim=1).tolist()

    # The ground-truth label is the file-name prefix before the first "_".
    return [
        make_image(path, classes[idx], path.name.split("_")[0])
        for path, idx in zip(paths, pred_indices)
    ]


def inference(
    org_img: np.ndarray,
    top_k: int,
    show_cam: str,
    num_cam_imgs: int,
    cam_block: str,
    target_layer_num: int,
    transparency: float,
    show_misclf: str,
    num_misclf: int,
):
    """Run the full demo pipeline for one request.

    Args:
        org_img: Input image; it is preprocessed with ``transforms``.
        top_k: Number of top predictions to report.
        show_cam: Truthy to produce Grad-CAM overlays for misclassified images.
        num_cam_imgs: How many misclassified images to run Grad-CAM on.
        cam_block: Block name forwarded to ``get_traget_layer``.
        target_layer_num: Layer selector forwarded to ``get_traget_layer``.
        transparency: ``image_weight`` used when blending CAM overlays.
        show_misclf: Truthy to produce the misclassified-images gallery.
        num_misclf: How many misclassified images to show in that gallery.

    Returns:
        A 4-tuple ``(cam_overlay, preds, cam_outputs, misclf_images_output)``:
        the default Grad-CAM overlay on the input image, the top-k prediction
        dict, the list of Grad-CAM overlays (empty unless ``show_cam``), and
        the list of titled misclassified images (empty unless ``show_misclf``).
    """
    input_img = transforms(org_img).unsqueeze(0)
    preds = predict_img(input_img, top_k)
    org_img = display_cam(default_cam, org_img, input_img, transparency)

    # Shuffle a private copy: the module-level list is shared by all requests
    # and must not be reordered in place.
    candidates = list(mis_classified_imgs)
    shuffle(candidates)

    cam_outputs = []
    if show_cam:
        # int(): slider values may arrive as floats, which cannot slice.
        cam_outputs = _cam_gallery(
            candidates[: int(num_cam_imgs)], cam_block, target_layer_num, transparency
        )

    misclf_images_output = []
    if show_misclf:
        misclf_images_output = _misclassified_gallery(candidates[: int(num_misclf)])

    return org_img, preds, cam_outputs, misclf_images_output
title = "CIFAR10 trained on Custom Model inspired by ResNet with GradCAM"
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
# examples = [["cat.jpg", 0.5, -1], ["dog.jpg", 0.5, -1]]

# Gradio UI. The inputs are listed in the same order as the parameters of
# ``inference``, and the outputs in the same order as its return tuple.
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(shape=(32, 32), label="Input Image"),
        gr.Slider(1, 10, value=3, step=1, label="Top K predictions"),
        gr.Checkbox(label="Show Grad Cam"),
        gr.Slider(1, 20, value=5, step=1, label="Number of images"),
        # A default is required: without one the callback receives None when
        # "Show Grad Cam" is ticked and no block has been chosen. "block3"
        # matches the block used by ``default_cam``.
        gr.Radio(
            label="Which Block?",
            choices=["block1", "block2", "block3"],
            value="block3",
        ),
        gr.Slider(0, 1, value=1, step=1, label="Which Layer?"),
        gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM"),
        gr.Checkbox(label="Show Misclassified Images"),
        # NOTE(review): with minimum 1 and step 5 the default value 5 is not
        # on the slider's step grid - confirm the intended range/step.
        gr.Slider(1, 20, value=5, step=5, label="Number of Misclassification Images"),
    ],
    outputs=[
        gr.Image(shape=(32, 32), label="Output", width=128, height=128),
        "label",
        gr.Gallery(label="GradCAM Output"),
        gr.Gallery(
            label="Misclassified Images Pred/G.T.",
            columns=[2],
            rows=[2],
            object_fit="contain",
            height="auto",
        ),
    ],
    title=title,
    description=description,
    # examples=examples,
)
demo.launch()