# patch-conv-net / app.py
# ariG23498 (HF staff) - chore: house cleaning (0f130d4)
# import the necessary packages
from utilities import config
from utilities import model
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf
import matplotlib.pyplot as plt
import math
import gradio as gr
# restore the three saved sub-models (stem, trunk, attention pooling) from
# the paths listed in the config; they are loaded in that order and left
# uncompiled
conv_stem, conv_trunk, conv_attn = (
    keras.models.load_model(saved_path, compile=False)
    for saved_path in (
        config.IMAGENETTE_STEM_PATH,
        config.IMAGENETTE_TRUNK_PATH,
        config.IMAGENETTE_ATTN_PATH,
    )
)
def plot_attention(image):
    """Plot an image next to the same image overlaid with its attention map.

    The image is converted to float32, resized to 224x224 and passed through
    the stem, the trunk and the attention pooling block. The second output
    of the attention block is reshaped into a square grid and blended over
    the resized image.

    Args:
        image: The image to analyse, in any form accepted by
            ``tf.image.convert_image_dtype`` (presumably the H x W x C array
            handed over by the Gradio image input -- TODO confirm against
            the installed Gradio version).

    Returns:
        The ``matplotlib.pyplot`` module. Its current figure holds two
        panels: the resized input ("Original") and the input with the
        attention weights drawn on top ("Attended").
    """
    # Figures created through pyplot stay registered until they are closed.
    # The figure built below cannot be closed before returning (the caller
    # still has to render it), so discard the ones left over from earlier
    # calls instead; otherwise every request leaks one more figure.
    plt.close("all")

    # resize the image to a 224, 224 dim and add a leading batch axis
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, (224, 224))
    image = image[tf.newaxis, ...]

    # pass through the stem, the trunk and the attention pooling block
    test_x = conv_stem(image)
    test_x = conv_trunk(test_x)
    _, test_viz_weights = conv_attn(test_x)
    test_viz_weights = test_viz_weights[tf.newaxis, ...]

    # reshape the visualization weights into a square grid; this assumes the
    # size of the last axis is a perfect square (TODO confirm with the model)
    num_patches = tf.shape(test_viz_weights)[-1]
    height = width = int(math.sqrt(num_patches))
    test_viz_weights = layers.Reshape((height, width))(test_viz_weights)

    # a single image was fed in, so the first entry is the one to draw
    selected_image = image[0]
    selected_weight = test_viz_weights[0]

    _, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

    ax[0].imshow(selected_image)
    ax[0].set_title("Original")
    ax[0].axis("off")

    # stretch the coarse attention grid over the exact extent of the image
    # so that both layers line up
    img = ax[1].imshow(selected_image)
    ax[1].imshow(
        selected_weight,
        cmap='inferno',
        alpha=0.6,
        extent=img.get_extent(),
    )
    ax[1].set_title("Attended")
    ax[1].axis("off")

    return plt
# expose plot_attention through a Gradio UI with one image input and an
# image output, and start serving right away; note that `iface` ends up
# bound to whatever `launch()` returns
_image_input = gr.inputs.Image(label="Input Image")
iface = gr.Interface(
    fn=plot_attention,
    inputs=[_image_input],
    outputs="image",
).launch()