# HybridModel-GradCAM / utils/viz_utils.py
# Author: innat — commit 0f09377 ("init"), 2.11 kB
import matplotlib.cm as cm
import numpy as np
import tensorflow as tf
from tensorflow import keras
def _heatmap_from_grads(feature_map, grads):
    """Turn one feature map and its class-score gradients into a [0, 1] heatmap.

    The pooling over axes (0, 1, 2) followed by a matmul against a trailing
    channel axis implies a channels-last, rank-4 feature map
    (batch, height, width, channels) -- NOTE(review): confirm for both model
    outputs. Only the first sample of the batch is turned into a heatmap.
    """
    # Channel importance weights: mean gradient over batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    # Gradient-weighted sum over channels of the first sample -> (H, W, 1).
    heatmap = feature_map[0] @ pooled_grads[..., tf.newaxis]
    # Drop only the channel axis so a spatial dimension of size 1 survives.
    heatmap = tf.squeeze(heatmap, axis=-1)
    # Keep positive evidence only, then scale to [0, 1]. divide_no_nan returns
    # an all-zero map instead of NaNs when nothing is positive (max == 0).
    heatmap = tf.maximum(heatmap, 0)
    heatmap = tf.math.divide_no_nan(heatmap, tf.math.reduce_max(heatmap))
    return heatmap.numpy()


def make_gradcam_heatmap(img_array, grad_model, pred_index=None):
    """Compute Grad-CAM heatmaps for the two feature outputs of ``grad_model``.

    ``grad_model`` is called on ``img_array`` and must return a 3-tuple
    ``(preds, base_top, swin_top)``: class scores followed by two feature maps.
    Each feature map is weighted by the gradient of the selected class score
    to produce one heatmap.

    Args:
        img_array: Input batch passed to ``grad_model``. Heatmaps are built
            from the first sample's feature maps; gradient pooling averages
            over the whole batch.
        grad_model: Callable returning ``(preds, base_top, swin_top)``.
        pred_index: Class index to explain. Defaults to the argmax of the
            first sample's predictions.

    Returns:
        ``(heatmap_a, heatmap_b, preds)``: NumPy heatmaps for ``base_top`` and
        ``swin_top`` with values in [0, 1] (all zeros when the class has no
        positive evidence), and the model's prediction tensor.
    """
    with tf.GradientTape() as tape:
        preds, base_top, swin_top = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    # One gradient call for both sources, so the tape need not be persistent.
    base_grads, swin_grads = tape.gradient(class_channel, [base_top, swin_top])

    heatmap_a = _heatmap_from_grads(base_top, base_grads)
    heatmap_b = _heatmap_from_grads(swin_top, swin_grads)
    return heatmap_a, heatmap_b, preds
def _resolve_colormap(cmap):
    """Return a matplotlib Colormap for ``cmap`` (a registered name or a Colormap)."""
    if not isinstance(cmap, str):
        # Assume an already-constructed Colormap instance was passed.
        return cmap
    try:
        from matplotlib import colormaps  # registry available since matplotlib 3.5
    except ImportError:  # older matplotlib: legacy lookup (removed in 3.9)
        return cm.get_cmap(cmap)
    return colormaps[cmap]


def save_and_display_gradcam(
    img,
    heatmap,
    target=None,
    pred=None,
    cam_path="cam.jpg",
    cmap="jet",  # inferno, viridis
    alpha=0.6,
    plot=None,
    image_shape=None,
):
    """Colorize ``heatmap`` and blend it onto ``img``, returning a PIL image.

    Args:
        img: Image array indexed as ``img.shape[0]`` = height and
            ``img.shape[1]`` = width (presumably (H, W, C) -- confirm with
            callers).
        heatmap: 2-D array with values expected in [0, 1]. Values outside
            that range are clipped and NaNs are treated as 0.
        target, pred, cam_path, plot: Not used by this function; kept so
            existing call sites remain valid.
        cmap: Name of a registered matplotlib colormap, or a Colormap object.
        alpha: Weight applied to the colorized heatmap before it is added
            to ``img``.
        image_shape: Optional sequence whose first two entries are the
            output (height, width). When None, the result keeps ``img``'s
            size.

    Returns:
        The superimposed image as returned by ``keras.utils.array_to_img``.
    """
    # Map [0, 1] to integer LUT indices 0..255. Clipping (and NaN -> 0)
    # prevents out-of-range values from wrapping around in the uint8 cast.
    heatmap = np.uint8(255 * np.clip(np.nan_to_num(heatmap), 0.0, 1.0))

    # Sample 256 evenly spaced RGB colors from the colormap. For 256-entry
    # colormaps such as "jet" this selects exactly LUT rows 0..255; for other
    # sizes it still spans the whole colormap.
    colormap = _resolve_colormap(cmap)
    lut = colormap(np.linspace(0.0, 1.0, 256))[:, :3]
    colored = lut[heatmap]

    # Upsample the colorized heatmap to the image size. PIL's resize takes
    # (width, height), while img.shape is (height, width, ...).
    overlay = keras.utils.array_to_img(colored)
    overlay = overlay.resize((img.shape[1], img.shape[0]))
    overlay = keras.utils.img_to_array(overlay)

    # Superimpose the heatmap on the original image.
    superimposed_img = keras.utils.array_to_img(img + overlay * alpha)

    if image_shape is not None:
        height, width = image_shape[:2]
        superimposed_img = superimposed_img.resize((width, height))
    return superimposed_img