File size: 809 Bytes
310a06c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# import the necessary packages
from utilities import config
from utilities import model
from utilities import visualization
from tensorflow import keras
import gradio as gr

# Load the three saved Imagenette sub-models (stem, trunk and attention
# pooling) from the paths in `config`. Each is loaded with compile=False,
# i.e. without compiling it after loading; nothing in this script trains them.
def _load_uncompiled_model(path):
    """Load the saved Keras model at *path* without compiling it."""
    return keras.models.load_model(path, compile=False)


conv_stem = _load_uncompiled_model(config.IMAGENETTE_STEM_PATH)
conv_trunk = _load_uncompiled_model(config.IMAGENETTE_TRUNK_PATH)
conv_attn = _load_uncompiled_model(config.IMAGENETTE_ATTN_PATH)

# Assemble the PatchConvNet from the stem, trunk and attention-pooling
# sub-models loaded from disk earlier in this script.
patch_conv_net = model.PatchConvNet(
    stem=conv_stem, trunk=conv_trunk, attention_pooling=conv_attn
)

# Wrap the network in the attention-plotting callable and expose it through
# a Gradio UI: the user uploads an image and the interface shows the image
# that `plot_attention` returns.
plot_attention = visualization.PlotAttention(model=patch_conv_net)

iface = gr.Interface(
    fn=plot_attention,
    # `gr.Image` supersedes the deprecated `gr.inputs.Image` (removed in
    # Gradio 4); like the old input, it passes the upload to `fn` as a numpy
    # array by default.
    inputs=[gr.Image(label="Input Image")],
    outputs="image",
)

# Launch separately so that `iface` keeps referring to the Interface object.
# `launch()` returns server/URL information, not the Interface.
iface.launch()