# geninhu's picture
# Add application file
# history blame
# 4.04 kB
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
import gradio as gr
from huggingface_hub import from_pretrained_keras
# Load the pretrained Keras model 'geninhu/attention_mil' via huggingface_hub.
# This runs at import time, so the same `model` object is shared by every call to infer().
model = from_pretrained_keras('geninhu/attention_mil')
# functions for inference
# resize the image and convert it to a float between 0 and 1
def plot(input_images=None, predictions=None, attention_weights=None):
    """Build the two outputs shown for one bag of images.

    Args:
        input_images: Optional sequence of image arrays. When given, the
            images are drawn side by side in a single matplotlib figure.
        predictions: Scores for the bag. An ``np.argmax`` of 0 is reported
            as "does not contain number 8", any other index as "contains".
        attention_weights: Optional sequence with one weight per image.
            Used as subplot titles, or formatted into the text summary
            when no images are supplied.

    Returns:
        ``[message, figure]`` when ``input_images`` is given, otherwise
        ``[message, summary_string]``.
    """
    if np.argmax(predictions) == 0:
        bag_class = 'This set of image does not contain number 8'
    else:
        bag_class = 'This set of image contains number 8'

    if input_images is not None:
        figure = plt.figure(figsize=(8, 8))
        # NOTE(review): the drawing calls inside this loop were reconstructed;
        # the reviewed source had lost the loop's drawing statements and the
        # body of the attention-weight check. Confirm against the deployed app.
        for j, image in enumerate(input_images):
            figure.add_subplot(1, len(input_images), j + 1)
            plt.grid(False)
            if attention_weights is not None:
                plt.title(f"{attention_weights[j]:.2f}")
            # Drop singleton axes (e.g. a leading batch axis) before drawing.
            plt.imshow(np.squeeze(image))
        return [bag_class, figure]

    # Text fallback. The summary is only built here, so a missing
    # attention_weights no longer raises before the figure branch is reached.
    if attention_weights is None:
        prob_str = "Each image probability: n/a"
    else:
        formatted = ", ".join(f"{weight:.2f}" for weight in attention_weights)
        prob_str = f"Each image probability: {formatted}"
    return [bag_class, prob_str]
def preprocess_image(image):
    """Divide the pixel values by 255.0 and prepend an axis of length 1.

    Args:
        image: Array-like image supporting division by a float.

    Returns:
        The scaled array with a new leading axis.
    """
    return np.expand_dims(image / 255.0, axis=0)
def infer(input_images_1, input_images_2, input_images_3):
    """Classify a bag of three images and return what plot() produces for it.

    Args:
        input_images_1: First image array, or None if not provided yet.
        input_images_2: Second image array, or None if not provided yet.
        input_images_3: Third image array, or None if not provided yet.

    Returns:
        The result of ``plot(...)`` for the bag, or None when any of the
        three images is missing.
    """
    images = (input_images_1, input_images_2, input_images_3)
    if any(image is None for image in images):
        # Incomplete bag: nothing to predict yet.
        return None

    # Normalize input data
    bag = [preprocess_image(image) for image in images]

    # Bag-level prediction.
    prediction = model.predict(bag)
    prediction = np.squeeze(np.swapaxes(prediction, 1, 0))

    # Read the output of the layer named "alpha" to get one weight per image.
    intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)
    intermediate_predictions = intermediate_model.predict(bag)
    attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))

    return plot(bag, prediction, attention_weights)
# get the inputs: three grayscale images passed to infer() as numpy arrays
input1 = gr.Image(shape=(28, 28), type='numpy', image_mode='L', label='First image', show_label=True, visible=True)
input2 = gr.Image(shape=(28, 28), type='numpy', image_mode='L', label='Second image', show_label=True, visible=True)
input3 = gr.Image(shape=(28, 28), type='numpy', image_mode='L', label='Third image', show_label=True, visible=True)
# the app outputs a label (the bag-level message) and a plot (the figure)
output = [gr.Label(), gr.Plot()]
# it's good practice to pass examples, description and a title to guide users
title = 'Image classification'
description = 'Upload an image'
gr_interface = gr.Interface(
    infer, inputs=[input1, input2, input3], outputs=output, allow_flagging='never',
    analytics_enabled=False, title=title, description=description, live=True,
    # Each example is one bag of three sample images.
    examples=[
        ['samples/0.png', 'samples/6.png', 'samples/2.png'],
        ['samples/1.png', 'samples/2.png', 'samples/3.png'],
        ['samples/4.png', 'samples/8.png', 'samples/7.png'],
        ['samples/8.png', 'samples/0.png', 'samples/9.png'],
        ['samples/5.png', 'samples/6.png', 'samples/3.png'],
        ['samples/7.png', 'samples/8.png', 'samples/9.png'],
    ],
)
gr_interface.launch(enable_queue=True, debug=True, inbrowser=True)
# gr_interface = gr.Interface(infer, input, output, examples=examples, allow_flagging=False, analytics_enabled=False, title=title, description=description).launch(enable_queue=True, debug=False)
# gr_interface.launch()