# import the necessary packages
from utilities import config
from utilities import model
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf
import matplotlib.pyplot as plt
import math
import gradio as gr

# load the models from disk
conv_stem = keras.models.load_model(
    config.IMAGENETTE_STEM_PATH,
    compile=False,
)
conv_trunk = keras.models.load_model(
    config.IMAGENETTE_TRUNK_PATH,
    compile=False,
)
conv_attn = keras.models.load_model(
    config.IMAGENETTE_ATTN_PATH,
    compile=False,
)


def plot_attention(image):
    # convert the image to float32, resize it to 224x224, and
    # add a batch dimension
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, (224, 224))
    image = image[tf.newaxis, ...]

    # pass the image through the convolutional stem
    test_x = conv_stem(image)

    # pass the stem output through the convolutional trunk
    test_x = conv_trunk(test_x)

    # pass the trunk output through the attention pooling block and
    # grab the attention weights used for visualization
    _, test_viz_weights = conv_attn(test_x)
    test_viz_weights = test_viz_weights[tf.newaxis, ...]

    # reshape the visualization weights into a square spatial map
    num_patches = tf.shape(test_viz_weights)[-1]
    height = width = int(math.sqrt(num_patches))
    test_viz_weights = layers.Reshape((height, width))(test_viz_weights)

    # select the first (and only) image and its attention map
    index = 0
    selected_image = image[index]
    selected_weight = test_viz_weights[index]

    # plot the original image and the attention overlay side by side
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    ax[0].imshow(selected_image)
    ax[0].set_title("Original")
    ax[0].axis("off")

    img = ax[1].imshow(selected_image)
    ax[1].imshow(
        selected_weight,
        cmap="inferno",
        alpha=0.6,
        extent=img.get_extent(),
    )
    ax[1].set_title("Attended")
    ax[1].axis("off")

    plt.axis("off")
    return plt


# build the gradio interface around the attention plotting
# function and launch it
iface = gr.Interface(
    fn=plot_attention,
    inputs=[gr.inputs.Image(label="Input Image")],
    outputs="image",
).launch()