# Scraped page header (not code): author nelbarman053 · commit 18b304b "XAI implementation added" · 3.45 kB
# -*- coding: utf-8 -*-
"""xai_app.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1XRB_m0JRoi0KugiHw5WSPJzV0udfU69O
"""
import os
import pathlib

# NOTE(review): keep the third-party order below. `import matplotlib as plt`
# binds the matplotlib package, and the fastai star import that follows
# presumably rebinds `plt` (and other names) — confirm before reordering.
import cv2
import numpy as np
import gradio as gr
import matplotlib as plt
from fastai.vision.all import *
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
# The learner pickle presumably comes from Colab (Linux), so it stores
# ``pathlib.PosixPath`` objects, which cannot be instantiated on Windows.
# On Windows only, alias PosixPath to WindowsPath while unpickling. Doing this
# on a POSIX system would rebuild every unpickled PosixPath as a WindowsPath,
# which pathlib refuses to instantiate there. The original class is always
# restored afterwards so the patch does not leak into the rest of the process.
_original_posix_path = pathlib.PosixPath
if os.name == "nt":
    pathlib.PosixPath = pathlib.WindowsPath
try:
    model = load_learner("models/recognizer_model.pkl")
finally:
    pathlib.PosixPath = _original_posix_path

# The learner's underlying torch ``nn.Module`` switched to eval mode
# (``Module.eval()`` returns the module itself); this is what Grad-CAM hooks into.
pytorch_model = model.model.eval()

# Class names in the order of the model's output indices.
# NOTE(review): presumably mirrors ``model.dls.vocab`` — re-check if the model is retrained.
labels = ['Ayre', 'Catla', 'Chital', 'Ilish', 'Kachki', 'Kajoli', 'Koi', 'Magur', 'Mola Dhela', 'Mrigal', 'Pabda', 'Pangash', 'Poa', 'Puti', 'Rui', 'Shing', 'Silver Carp', 'Taki', 'Telapia', 'Tengra']
def xai_visualization(image, image_tensor, targeted_category, model, target_layers,
                      output_path="xai/xai_visualization.png"):
    """Render a Grad-CAM heatmap for one class over ``image`` and save it as a PNG.

    Args:
        image: Image array drawn as the background. It should have the same
            spatial size as ``image_tensor`` (``preprocess_image`` returns both).
        image_tensor: Batched input tensor passed to GradCAM as ``input_tensor``.
        targeted_category: Index of the output class whose activation map is shown.
        model: torch module being explained.
        target_layers: List of modules inside ``model`` that GradCAM hooks into.
        output_path: File the figure is written to. Its parent directory is
            created if it does not exist.
    """
    # Import pyplot explicitly instead of relying on the module-level ``plt``
    # name, which the file binds with ``import matplotlib as plt``.
    import matplotlib.pyplot as plt

    # Used as a context manager so the hooks GradCAM registers on the model
    # are released as soon as the map has been computed.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        targets = [ClassifierOutputTarget(targeted_category)]
        grayscale_cam = cam(input_tensor=image_tensor, targets=targets)
    # Map for the first (only) image in the batch.
    mask = grayscale_cam[0, :]

    # Object-oriented API: no reliance on pyplot's global "current figure".
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.axis('off')
        ax.imshow(image)
        ax.imshow(mask * 255, cmap="plasma", alpha=0.7)
        directory = os.path.dirname(output_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        fig.savefig(output_path, dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each request would leak one figure.
        plt.close(fig)
def preprocess_image(image_path, size=224):
    """Resize an image array and convert it into a single-image batch tensor.

    Args:
        image_path: Despite its name, an image array: ``cv2.resize`` is applied
            to it directly, e.g. the array Gradio's ``gr.Image`` supplies.
            Presumably H x W x 3 RGB — TODO confirm.
        size: Side length in pixels of the square output (default 224).

    Returns:
        ``(image, img_tensor)``: the resized array, and the float tensor produced
        by ``transforms.ToTensor`` with a leading batch dimension added,
        i.e. shape ``(1, C, size, size)``.
    """
    # Bicubic resize to a size x size square.
    image = cv2.resize(
        image_path,
        dsize=(size, size),
        interpolation=cv2.INTER_CUBIC)
    # ToTensor converts HWC -> CHW (and scales uint8 input to [0, 1]);
    # unsqueeze(0) adds the batch axis without hard-coding channels or size.
    img_tensor = transforms.ToTensor()(image).unsqueeze(0)
    # NOTE(review): no mean/std normalization is applied. If the learner
    # normalizes its inputs, Grad-CAM sees a different input than
    # model.predict does — confirm against the training pipeline.
    return image, img_tensor
def target_layers_finding(model, offset=20):
    """Select the layer Grad-CAM should hook into, by position from the end.

    Args:
        model: torch module whose ``named_modules()`` are enumerated (the
            module itself comes first, followed by its submodules).
        offset: Position counted from the end of that enumeration. The default
            of 20 is the value this app uses for its ResNet-50 model.

    Returns:
        A one-element list with the selected module, the form GradCAM's
        ``target_layers`` argument takes.

    Raises:
        ValueError: if ``offset`` is not between 1 and the number of modules.
            The previous ``len(layers) - offset`` indexing silently wrapped
            around to the end of the list for small models.
    """
    modules = [module for _, module in model.named_modules()]
    if not 1 <= offset <= len(modules):
        raise ValueError(
            f"offset {offset} is out of range for a model with {len(modules)} modules")
    return [modules[-offset]]
def classify_image(image_path):
    """Classify an image and produce a Grad-CAM visualization for the top class.

    Args:
        image_path: Despite its name, the image array Gradio's ``gr.Image``
            supplies. It is passed to both ``model.predict`` and
            ``preprocess_image``. ``None`` when the user submits without an image.

    Returns:
        ``(xai_image, confidences)``: the path of the saved Grad-CAM PNG and a
        ``{label: probability}`` dict for ``gr.Label``. Returns ``(None, None)``
        when no image was given, which clears both outputs.
    """
    if image_path is None:
        return None, None
    # model.predict returns (decoded label, predicted index, probabilities);
    # only the probabilities are used here.
    _, _, probs = model.predict(image_path)
    # Explain the most probable class.
    targeted_category = int(np.argmax(probs))
    image, img_tensor = preprocess_image(image_path)
    target_layers = target_layers_finding(pytorch_model)
    xai_visualization(image, img_tensor, targeted_category, pytorch_model, target_layers)
    # Must match the path xai_visualization writes to.
    xai_image = "xai/xai_visualization.png"
    return xai_image, dict(zip(labels, map(float, probs)))
# Gradio UI: one image in; the Grad-CAM overlay and top-5 class scores out.
inputs = gr.Image(
    label="Input Image"
)
outputs = [
    gr.Image(
        label="GradCAM visualization",
        show_label=True
    ),
    gr.Label(
        num_top_classes=5,
        label="Predicted Category"
    )
]
# Sample images offered as clickable examples below the interface.
examples = [
    'test images/unknown_01.jpg',
    'test images/unknown_02.png',
    'test images/unknown_03.jpg',
    'test images/unknown_04.jpg',
    'test images/unknown_05.jpg',
    'test images/unknown_06.jpg',
    'test images/unknown_07.jpg',
    'test images/unknown_08.jpg',
    'test images/unknown_09.jpg',
    'test images/unknown_10.jpg',
    'test images/unknown_11.jpg',
    'test images/unknown_12.png',
    'test images/unknown_13.jpg',
    'test images/unknown_14.png',
    'test images/unknown_15.png',
    'test images/unknown_16.png',
    'test images/unknown_17.jpg'
]
interface = gr.Interface(
    fn=classify_image,
    inputs=inputs,
    outputs=outputs,
    examples=examples
)

# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    interface.launch()