# Spaces status: Runtime error
# import the necessary packages | |
from utilities import config | |
from tensorflow.keras import layers | |
from tensorflow import keras | |
import tensorflow as tf | |
import matplotlib.pyplot as plt | |
import math | |
import gradio as gr | |
# Restore the three pre-trained sub-networks (stem, trunk, attention pooling)
# that plot_attention chains together.
def _load_inference_model(path):
    """Load the saved Keras model at ``path`` without compiling it.

    ``compile=False`` skips restoring any optimizer/loss configuration; the
    models are only called for forward passes in this app.
    """
    return keras.models.load_model(path, compile=False)


conv_stem = _load_inference_model(config.IMAGENETTE_STEM_PATH)
conv_trunk = _load_inference_model(config.IMAGENETTE_TRUNK_PATH)
conv_attn = _load_inference_model(config.IMAGENETTE_ATTN_PATH)
def plot_attention(image):
    """Plot an image next to the same image overlaid with its attention map.

    The image is passed through the loaded sub-models in sequence
    (``conv_stem`` -> ``conv_trunk`` -> ``conv_attn``). The attention weights
    returned by ``conv_attn`` are laid out as a square grid and drawn over
    the image with a semi-transparent colormap.

    Args:
        image: the uploaded image from the Gradio input component
            (presumably an ``H x W x 3`` uint8 array — TODO confirm), or
            ``None`` when nothing was uploaded.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure holds the
        two-panel plot, or ``None`` when ``image`` is ``None``.

    Raises:
        ValueError: if the number of attention weights along the last axis
            is not a perfect square, so it cannot form a square grid.
    """
    if image is None:
        return None

    # pyplot keeps every figure alive until it is explicitly closed, so a
    # long-running server would otherwise accumulate one figure per request.
    # Close the previous call's figure; the one created below stays open so
    # the caller can still render it through the returned pyplot module.
    # NOTE(review): pyplot state is global; this assumes requests are not
    # processed concurrently.
    plt.close("all")

    # Convert to float32, resize to 224x224, and add a leading batch axis.
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, (224, 224))
    image = image[tf.newaxis, ...]

    # stem -> trunk -> attention pooling; only the attention weights are used.
    features = conv_stem(image)
    features = conv_trunk(features)
    _, viz_weights = conv_attn(features)

    # The last axis holds one weight per patch; arrange them as a square grid.
    # isqrt gives an exact integer side length (no float rounding).
    num_patches = int(tf.shape(viz_weights)[-1])
    side = math.isqrt(num_patches)
    if side * side != num_patches:
        raise ValueError(
            f"expected a square number of attention weights, got {num_patches}"
        )
    # assumes viz_weights holds exactly one value per patch for this single
    # image (e.g. batch 1, one head) — TODO confirm against conv_attn.
    selected_weight = tf.reshape(viz_weights, (side, side))
    selected_image = image[0]

    _, (ax_orig, ax_attn) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

    ax_orig.imshow(selected_image)
    ax_orig.set_title("Original")
    ax_orig.axis("off")

    base = ax_attn.imshow(selected_image)
    # Stretch the coarse weight grid over the full extent of the image.
    ax_attn.imshow(
        selected_weight, cmap="inferno", alpha=0.6, extent=base.get_extent()
    )
    ax_attn.set_title("Attended")
    ax_attn.axis("off")
    return plt
# Web UI: one uploaded image in, the two-panel attention plot out.
# ``gr.inputs`` no longer exists in current Gradio releases, and a plain
# "image" output cannot render matplotlib output; ``gr.Plot`` is Gradio's
# component for matplotlib plots.
iface = gr.Interface(
    fn=plot_attention,
    inputs=[gr.Image(label="Input Image")],
    outputs=gr.Plot(),
)
# Launch separately so ``iface`` stays bound to the Interface object rather
# than to launch()'s return value.
iface.launch()