File size: 1,715 Bytes
310a06c
 
 
 
0f130d4
 
 
 
310a06c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f130d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310a06c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# import the necessary packages
from utilities import config
from utilities import model
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf
import matplotlib.pyplot as plt
import math
import gradio as gr

# restore the stem, trunk and attention-pooling models from the paths
# listed in config; compile=False loads them without compiling
conv_stem, conv_trunk, conv_attn = (
	keras.models.load_model(saved_path, compile=False)
	for saved_path in (
		config.IMAGENETTE_STEM_PATH,
		config.IMAGENETTE_TRUNK_PATH,
		config.IMAGENETTE_ATTN_PATH,
	)
)

def plot_attention(image):
	"""Overlay the attention-pooling weights of the model on an input image.

	The image is converted to float32, resized to 224x224 and pushed
	through the stem, trunk and attention-pooling models loaded at module
	level. The weights returned by the attention model are reshaped into
	a square grid and drawn on top of the resized image.

	Args:
		image: the input image as delivered by the gradio Image input
			(presumably an HxWxC array -- confirm against the gradio
			version in use).

	Returns:
		The `matplotlib.pyplot` module; its current figure holds the
		"Original" and "Attended" panels side by side.

	Raises:
		ValueError: if the number of attention weights is not a perfect
			square, so they cannot be laid out as a square grid.
	"""
	# convert to float32, resize to a 224, 224 dim and add a leading
	# batch axis
	image = tf.image.convert_image_dtype(image, tf.float32)
	image = tf.image.resize(image, (224, 224))
	image = image[tf.newaxis, ...]

	# pass through the stem, the trunk and the attention pooling block
	test_x = conv_stem(image)
	test_x = conv_trunk(test_x)
	_, test_viz_weights = conv_attn(test_x)
	# NOTE(review): presumably the attention model hands the weights back
	# without a batch axis, hence the extra leading axis -- confirm
	test_viz_weights = test_viz_weights[tf.newaxis, ...]

	# reshape the visualization weights into a (batch, height, width) grid
	num_patches = int(tf.shape(test_viz_weights)[-1])
	height = width = int(math.sqrt(num_patches))
	if height * width != num_patches:
		raise ValueError(
			f"cannot arrange {num_patches} attention weights "
			"into a square grid"
		)
	test_viz_weights = tf.reshape(test_viz_weights, (-1, height, width))

	# only a single image is processed, so take the first batch entry
	selected_image = image[0]
	selected_weight = test_viz_weights[0]

	# pyplot keeps every figure alive until it is closed; drop the ones
	# left over from earlier calls so repeated requests do not pile up
	plt.close("all")
	_, (ax_original, ax_attended) = plt.subplots(
		nrows=1, ncols=2, figsize=(10, 5)
	)

	ax_original.imshow(selected_image)
	ax_original.set_title("Original")
	ax_original.axis("off")

	# stretch the coarse attention grid over the full image extent
	img = ax_attended.imshow(selected_image)
	ax_attended.imshow(
		selected_weight,
		cmap='inferno',
		alpha=0.6,
		extent=img.get_extent(),
	)
	ax_attended.set_title("Attended")
	ax_attended.axis("off")

	# the pyplot module itself is returned; the freshly drawn figure is
	# its current figure
	return plt

# wrap plot_attention in an image-in / image-out demo and start serving it
# NOTE: `iface` receives the return value of launch(), not the Interface
_demo = gr.Interface(
	fn=plot_attention,
	inputs=[gr.inputs.Image(label="Input Image")],
	outputs="image",
)
iface = _demo.launch()