Spaces:
Runtime error
Runtime error
import os | |
import gdown | |
import gradio as gr | |
import tensorflow as tf | |
from config import Parameters | |
from models.hybrid_model import GradientAccumulation | |
from utils.model_utils import * | |
from utils.viz_utils import make_gradcam_heatmap | |
from utils.viz_utils import save_and_display_gradcam | |
image_size = Parameters().image_size
# Class names looked up by the argmax index of the model's predictions
# (see `predict_fn`); the order must match the order used during training.
str_labels = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
def get_model():
    """Build the hybrid model and run one dummy forward pass so it has weights.

    Returns:
        A ``GradientAccumulation`` model wrapping ``"HybridModel"``, called once
        on a ``(1, image_size, image_size, 3)`` batch of ones so that
        ``load_weights`` can be applied to it afterwards.
    """
    # `params` used to be referenced here without ever being defined in this
    # module (only `image_size` is read from `Parameters()` at import time),
    # which raises NameError unless `from utils.model_utils import *` happens
    # to export it. Build the configuration locally instead.
    params = Parameters()
    model = GradientAccumulation(
        n_gradients=params.num_grad_accumulation, model_name="HybridModel"
    )
    # Presumably a subclassed Keras model whose variables are created lazily;
    # one call on a dummy batch builds them before weights are loaded.
    model(tf.ones((1, params.image_size, params.image_size, 3)))
    return model
def get_model_weight(model_id, output="model.h5"):
    """Return a local path to the trained weights, downloading them if needed.

    Args:
        model_id: Google Drive file id passed to ``gdown.download``.
        output: Local file used to cache the weights. Defaults to
            ``"model.h5"``, the path the cache check has always looked for.

    Returns:
        The local path of the weights file.

    Raises:
        RuntimeError: If the download does not yield a file path.
    """
    if os.path.exists(output):
        return output
    # Save under `output` explicitly: without it gdown keeps the Drive file's
    # own name, so the exists() check above may never find the cached copy and
    # every call would download the weights again.
    downloaded = gdown.download(id=model_id, output=output, quiet=False)
    if downloaded is None:
        # Some gdown versions report failure by returning None instead of raising.
        raise RuntimeError(f"Failed to download model weights (id={model_id!r}).")
    return downloaded
def load_model(model_id):
    """Build the hybrid model and restore its trained weights into it.

    Args:
        model_id: Google Drive file id forwarded to ``get_model_weight``.

    Returns:
        The model from ``get_model`` with weights loaded from the path
        returned by ``get_model_weight``.
    """
    # Resolve (and, if necessary, download) the weights before building.
    weights_path = get_model_weight(model_id)
    net = get_model()
    net.load_weights(weights_path)
    return net
def image_process(image):
    """Turn a raw image into a single-item float32 batch for the model.

    Args:
        image: Image array as handed over by Gradio (assumed H x W x C — TODO
            confirm channel count matches the model's 3-channel input).

    Returns:
        A tuple ``(batch, shape)`` where ``batch`` is the image resized to
        ``image_size`` x ``image_size`` with a leading batch axis added, and
        ``shape`` is the shape of the image before resizing.
    """
    as_float = tf.cast(image, dtype=tf.float32)
    shape_before_resize = as_float.shape
    resized = tf.image.resize(as_float, [image_size, image_size])
    return tf.expand_dims(resized, axis=0), shape_before_resize
_MODEL_ID = "1y6tseN0194T6d-4iIh5wo7RL9ttQERe0"
# Loaded models keyed by Drive id, so the model is built and its weights are
# loaded once per process instead of on every Gradio request.
_MODEL_CACHE = {}


def _get_cached_model():
    """Return the demo model, loading it on first use only."""
    if _MODEL_ID not in _MODEL_CACHE:
        _MODEL_CACHE[_MODEL_ID] = load_model(model_id=_MODEL_ID)
    return _MODEL_CACHE[_MODEL_ID]


def predict_fn(image):
    """A predict function that will be invoked by gradio.

    Args:
        image: The input image passed in by the Gradio image component.

    Returns:
        ``[label_text, overlay_a, overlay_b]``: the predicted class as text and
        the two Grad-CAM overlays from ``make_gradcam_heatmap``, each resized
        back to the input's original height and width.
    """
    loaded_model = _get_cached_model()
    loaded_image, original_shape = image_process(image)
    heatmap_a, heatmap_b, preds = make_gradcam_heatmap(loaded_image, loaded_model)
    int_label = tf.argmax(preds, axis=-1).numpy()[0]
    str_label = str_labels[int_label]
    # Overlays are rendered at the original (height, width), not the model size.
    original_hw = original_shape[:2]
    overlay_a = save_and_display_gradcam(
        loaded_image[0], heatmap_a, image_shape=original_hw
    )
    overlay_b = save_and_display_gradcam(
        loaded_image[0], heatmap_b, image_shape=original_hw
    )
    return [f"Predicted: {str_label}", overlay_a, overlay_b]
# Gradio demo: one input image; outputs are the predicted label and the two
# Grad-CAM overlays returned by `predict_fn`, in that order.
iface = gr.Interface(
    fn=predict_fn,
    inputs=gr.inputs.Image(label="Input Image"),
    outputs=[
        gr.outputs.Label(label="Prediction"),
        # Result components belong to `gr.outputs`; `gr.inputs.Image` was used
        # here before, which older Gradio releases may not accept as an output.
        gr.outputs.Image(label="CNN GradCAM"),
        gr.outputs.Image(label="Transformer GradCAM"),
    ],
    title="Hybrid EfficientNet Swin Transformer Demo",
    description="The model is trained on tf_flowers dataset <a href='https://www.kaggle.com/datasets/alxmamaev/flowers-recognition'>Flowers Recognition Dataset</a>. It provides 5 categories, namely: `daisy`, `rose`, `sunflower`, `tulip`, `dandelion`. One example from each class is provided in the Example section.",
    article="<div><center><img src='https://visitor-badge.glitch.me/badge?page_id=hybrid-gradcam' alt='visitor badge'></center></div>",
    examples=[
        ["examples/dandelion.jpg"],
        ["examples/sunflower.jpg"],
        ["examples/tulip.jpg"],
        ["examples/daisy.jpg"],
        ["examples/rose.jpg"],
    ],
)
iface.launch()