Mohammed Innat
Update app.py
9f2a615
raw
history blame
3.02 kB
import os
import gdown
import gradio as gr
import tensorflow as tf
from config import Parameters
from models.hybrid_model import GradientAccumulation
from utils.model_utils import *
from utils.viz_utils import make_gradcam_heatmap
from utils.viz_utils import save_and_display_gradcam
# Side length (in pixels) that input images are resized to before inference.
image_size = Parameters().image_size

# Class names; predict_fn indexes this list with the argmax of the model's
# predictions, so the order here must match the model's output order.
str_labels = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
def get_model():
    """Build the hybrid model and run one dummy forward pass through it.

    The configuration is read from a fresh ``Parameters()`` instance (the
    class imported at the top of this file) instead of relying on a global
    ``params`` name that this module never defines itself.

    Returns:
        A ``GradientAccumulation`` model named ``"HybridModel"`` that has been
        called once on an all-ones batch of shape
        ``(1, image_size, image_size, 3)``.
    """
    params = Parameters()
    model = GradientAccumulation(
        n_gradients=params.num_grad_accumulation, model_name="HybridModel"
    )
    # Call the model once on a dummy batch; presumably this creates the
    # model's variables so that weights can be loaded into it afterwards.
    model(tf.ones((1, params.image_size, params.image_size, 3)))
    return model
def get_model_weight(model_id):
    """Return a path to the trained weights, downloading them if needed.

    Args:
        model_id: File id forwarded to ``gdown.download`` when no local copy
            of the weights is found.

    Returns:
        ``"model.h5"`` if that file already exists in the working directory,
        otherwise the path returned by ``gdown.download``.

    Raises:
        RuntimeError: If ``gdown.download`` returns ``None`` instead of a path.
    """
    if os.path.exists("model.h5"):
        return "model.h5"

    # No output name is passed, so gdown picks the filename itself.
    # NOTE(review): the cache check above only hits on later runs if the
    # downloaded file is actually named "model.h5" -- confirm.
    model_weight = gdown.download(id=model_id, quiet=False)
    if model_weight is None:
        # Fail here with a clear message rather than letting a None path
        # reach model.load_weights().
        raise RuntimeError(
            f"Could not download model weights for id {model_id!r}"
        )
    return model_weight
# Models that have already been built and had their weights loaded, keyed by
# the model id they were loaded for.
_LOADED_MODELS = {}


def load_model(model_id):
    """Load the trained model, reusing a previously loaded one if available.

    The first call for a given ``model_id`` fetches the weight file, builds
    the model and loads the weights into it. Later calls with the same id
    return that same model object, so the expensive build/load step is not
    repeated on every prediction request.

    Args:
        model_id: Identifier of the weight file, forwarded to
            ``get_model_weight``.

    Returns:
        The model with the trained weights loaded.
    """
    if model_id not in _LOADED_MODELS:
        weight = get_model_weight(model_id)
        model = get_model()
        model.load_weights(weight)
        _LOADED_MODELS[model_id] = model
    return _LOADED_MODELS[model_id]
def image_process(image):
    """Turn a raw input image into a model-ready batch.

    The image is cast to float32, resized to ``image_size`` x ``image_size``
    and given a leading batch axis.

    Returns:
        A ``(batched_image, original_shape)`` pair, where ``original_shape``
        is the shape of the image before resizing.
    """
    as_float = tf.cast(image, dtype=tf.float32)
    resized = tf.image.resize(as_float, [image_size, image_size])
    return resized[tf.newaxis, ...], as_float.shape
def predict_fn(image):
    """Gradio callback: classify ``image`` and render two Grad-CAM overlays.

    Returns:
        A three-element list: the prediction text followed by the overlays
        produced from the first and second heatmaps, in that order.
    """
    model = load_model(model_id="1y6tseN0194T6d-4iIh5wo7RL9ttQERe0")
    batch, source_shape = image_process(image)
    first_heatmap, second_heatmap, preds = make_gradcam_heatmap(batch, model)

    # Map the highest-scoring class index of the single batch item to its name.
    class_index = tf.argmax(preds, axis=-1).numpy()[0]
    class_name = str_labels[class_index]

    # Both overlays are rendered at the height/width of the original input.
    overlays = [
        save_and_display_gradcam(
            batch[0], heatmap, image_shape=source_shape[:2]
        )
        for heatmap in (first_heatmap, second_heatmap)
    ]
    return [f"Predicted: {class_name}", *overlays]
# Gradio UI: one input image in; the predicted label and two Grad-CAM
# overlay images out (matching the three values returned by predict_fn).
iface = gr.Interface(
    fn=predict_fn,
    inputs=gr.inputs.Image(label="Input Image"),
    outputs=[
        gr.outputs.Label(label="Prediction"),
        # Output slots use the output-side Image component, consistent with
        # the Label above, rather than the input-side gr.inputs.Image.
        gr.outputs.Image(label="CNN GradCAM"),
        gr.outputs.Image(label="Transformer GradCAM"),
    ],
    title="Hybrid EfficientNet Swin Transformer Demo",
    description="The model is trained on tf_flowers dataset <a href='https://www.kaggle.com/datasets/alxmamaev/flowers-recognition'>Flowers Recognition Dataset</a>. It provides 5 categories, namely: `daisy`, `rose`, `sunflower`, `tulip`, `dandelion`. One example from each class is provided in the Example section.",
    article="<div><center><img src='https://visitor-badge.glitch.me/badge?page_id=hybrid-gradcam' alt='visitor badge'></center></div>",
    examples=[
        ["examples/dandelion.jpg"],
        ["examples/sunflower.jpg"],
        ["examples/tulip.jpg"],
        ["examples/daisy.jpg"],
        ["examples/rose.jpg"],
    ],
)

# Start the web app as soon as this script is run.
iface.launch()