Spaces:
Runtime error
Runtime error
File size: 809 Bytes
310a06c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
# import the necessary packages
from utilities import config
from utilities import model
from utilities import visualization
from tensorflow import keras
import gradio as gr
def _load_model_for_inference(path):
    """Load a saved Keras model from *path* without compiling it.

    ``compile=False`` tells Keras not to compile the model after loading.
    This script only passes the models on to build ``PatchConvNet`` for the
    Gradio handler and never trains them, so no optimizer/loss is needed
    (assumption: ``PlotAttention`` only runs inference — TODO confirm).
    """
    return keras.models.load_model(path, compile=False)


# load the three pretrained sub-models from disk
conv_stem = _load_model_for_inference(config.IMAGENETTE_STEM_PATH)
conv_trunk = _load_model_for_inference(config.IMAGENETTE_TRUNK_PATH)
conv_attn = _load_model_for_inference(config.IMAGENETTE_ATTN_PATH)

# create the patch conv net from the stem, trunk and attention-pooling parts
patch_conv_net = model.PatchConvNet(
    stem=conv_stem,
    trunk=conv_trunk,
    attention_pooling=conv_attn,
)

# get the plot attention function; it is used directly as the Gradio handler
plot_attention = visualization.PlotAttention(model=patch_conv_net)

# `gr.inputs.*` was deprecated in Gradio 3 and removed in Gradio 4, so the
# top-level `gr.Image` component is used for the input. Keep a reference to
# the Interface object itself; `launch()` returns something else entirely.
iface = gr.Interface(
    fn=plot_attention,
    inputs=[gr.Image(label="Input Image")],
    outputs="image",
)

if __name__ == "__main__":
    # only start the (blocking) web server when run as a script
    iface.launch()