Spaces:
Runtime error
Runtime error
import matplotlib.cm as cm | |
import numpy as np | |
import tensorflow as tf | |
from tensorflow import keras | |
def make_gradcam_heatmap(img_array, grad_model, pred_index=None):
    """Compute Grad-CAM heatmaps for the two feature maps of ``grad_model``.

    ``grad_model`` is called on ``img_array`` and must return a 3-tuple
    ``(preds, base_top, swin_top)``: the prediction scores followed by two
    intermediate feature tensors. For each feature tensor, its channels are
    weighted by the averaged gradient of the selected class score, summed,
    rectified, and normalised to ``[0, 1]``.

    Args:
        img_array: Model input batch. Only sample 0 of each feature tensor is
            used to build the heatmaps.
        grad_model: Callable (e.g. a ``keras.Model``) returning
            ``(preds, base_top, swin_top)``.
        pred_index: Class index to explain. Defaults to the arg-max of
            ``preds[0]``.

    Returns:
        ``(heatmap_a, heatmap_b, preds)``. ``heatmap_a`` is built from
        ``base_top`` and ``heatmap_b`` from ``swin_top``, both as NumPy arrays
        with values in ``[0, 1]``. ``preds`` is the raw model output.
    """
    # persistent=True because two separate gradients are taken from this tape.
    with tf.GradientTape(persistent=True) as tape:
        preds, base_top, swin_top = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    heatmap_a = _gradcam_from_features(tape, class_channel, base_top)
    heatmap_b = _gradcam_from_features(tape, class_channel, swin_top)
    # Drop the persistent tape so its recorded resources can be freed.
    del tape
    return heatmap_a, heatmap_b, preds


def _gradcam_from_features(tape, class_channel, features):
    """Return a ``[0, 1]``-normalised Grad-CAM map for one feature tensor."""
    grads = tape.gradient(class_channel, features)
    # Averages over axes 0-2 and keeps the last axis as channels.
    # assumes features is (batch, H, W, C) — TODO confirm for swin_top
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = features[0] @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    # Rectify before taking the max. If nothing is positive the max is 0, and
    # divide_no_nan returns zeros instead of NaN.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = tf.math.divide_no_nan(heatmap, tf.math.reduce_max(heatmap))
    return heatmap.numpy()
def save_and_display_gradcam(
    img,
    heatmap,
    target=None,
    pred=None,
    cam_path="cam.jpg",
    cmap="jet",  # inferno, viridis
    alpha=0.6,
    plot=None,
    image_shape=None,
):
    """Overlay a colour-mapped heatmap on ``img`` and return a PIL image.

    Args:
        img: Image array indexed as ``(height, width, ...)``. It is added
            element-wise to a ``(height, width, 3)`` colour map, so presumably
            an RGB array — TODO confirm with callers.
        heatmap: 2-D array expected to hold values in ``[0, 1]``. NaNs are
            treated as 0 and out-of-range values are clipped.
        target, pred, cam_path, plot: Accepted for interface compatibility.
            They are not used by this function.
        cmap: Matplotlib colormap name or ``Colormap`` instance.
        alpha: Weight of the coloured heatmap in the overlay.
        image_shape: Optional ``(height, width, ...)`` for the returned image.
            If ``None``, the overlay is returned at ``img``'s size.

    Returns:
        The superimposed image as a PIL image.
    """
    import matplotlib

    # Rescale heatmap to integer indices 0-255. Sanitise first, because a NaN
    # or out-of-range value would give an undefined or wrapped uint8.
    heatmap = np.uint8(255 * np.clip(np.nan_to_num(heatmap), 0.0, 1.0))

    # cm.get_cmap was removed in Matplotlib 3.9. Use the colormap registry
    # where it exists (>= 3.6) and fall back for older releases.
    try:
        colormap = matplotlib.colormaps.get_cmap(cmap)
    except AttributeError:
        colormap = cm.get_cmap(cmap)

    # Keep only the RGB channels of the 256-entry lookup table.
    colormap_colors = colormap(np.arange(256))[:, :3]
    colored_heatmap = colormap_colors[heatmap]

    # Create an image from the colourised heatmap and resize it to img.
    # PIL's resize takes (width, height), while img.shape is (height, width).
    colored_img = keras.utils.array_to_img(colored_heatmap)
    colored_img = colored_img.resize((img.shape[1], img.shape[0]))
    colored_heatmap = keras.utils.img_to_array(colored_img)

    # Superimpose the heatmap on the original image.
    superimposed_img = keras.utils.array_to_img(img + colored_heatmap * alpha)
    if image_shape is not None:
        out_height, out_width = image_shape[:2]
        superimposed_img = superimposed_img.resize((out_width, out_height))
    return superimposed_img