# Gradio demo script: loads the three saved sub-models, assembles them into a
# PatchConvNet, and serves the attention-plotting callable through a web UI.

# import the necessary packages
from utilities import config
from utilities import model
from utilities import visualization
from tensorflow import keras
import gradio as gr

# load the models from disk
# compile=False: the models are restored without their training configuration
conv_stem = keras.models.load_model(
    config.IMAGENETTE_STEM_PATH,
    compile=False,
)
conv_trunk = keras.models.load_model(
    config.IMAGENETTE_TRUNK_PATH,
    compile=False,
)
conv_attn = keras.models.load_model(
    config.IMAGENETTE_ATTN_PATH,
    compile=False,
)

# create the patch conv net from the stem, trunk and attention-pooling parts
patch_conv_net = model.PatchConvNet(
    stem=conv_stem,
    trunk=conv_trunk,
    attention_pooling=conv_attn,
)

# get the plot attention function (a callable wrapping the assembled model)
plot_attention = visualization.PlotAttention(model=patch_conv_net)

# build the interface: one image input, one image output
# NOTE(review): `gr.inputs` is the legacy Gradio namespace; newer releases
# expose the component as `gr.Image` -- confirm against the pinned version
iface = gr.Interface(
    fn=plot_attention,
    inputs=[gr.inputs.Image(label="Input Image")],
    outputs="image",
)

# launch separately so that `iface` refers to the Interface itself rather
# than to whatever `launch()` returns
iface.launch()