File size: 3,022 Bytes
0f09377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42f5de6
9f2a615
0f09377
 
 
 
 
 
 
 
 
42f5de6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import functools
import os

import gdown
import gradio as gr
import tensorflow as tf

from config import Parameters
from models.hybrid_model import GradientAccumulation
from utils.model_utils import *
from utils.viz_utils import make_gradcam_heatmap
from utils.viz_utils import save_and_display_gradcam

# Side length (pixels) that inputs are resized to in both dimensions before
# being fed to the model (see ``image_process``); read from the project config.
image_size = Parameters().image_size

# Human-readable class names, indexed by the arg-max of the model output.
# NOTE(review): order presumably matches the label order used in training — confirm.
str_labels = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]


def get_model():
    """Construct the hybrid model and build its variables.

    The network is wrapped in ``GradientAccumulation`` with
    ``params.num_grad_accumulation`` accumulation steps, then called once on
    an all-ones batch of shape ``(1, params.image_size, params.image_size, 3)``
    whose output is discarded. Presumably this forward pass exists so the
    layers create their variables before ``load_weights`` is called — confirm.

    NOTE(review): ``params`` is not defined in the visible part of this
    module; it presumably comes from ``from utils.model_utils import *`` —
    confirm.

    Returns:
        The freshly initialised model (no trained weights loaded yet).
    """
    hybrid = GradientAccumulation(
        model_name="HybridModel",
        n_gradients=params.num_grad_accumulation,
    )
    side = params.image_size
    dummy_batch = tf.ones((1, side, side, 3))
    _ = hybrid(dummy_batch)[0].shape
    return hybrid


def get_model_weight(model_id, filename="model.h5"):
    """Return a local path to the trained weights, downloading them if needed.

    Args:
        model_id: Google Drive file id passed to ``gdown.download``.
        filename: Local path the weights are cached at. Defaults to
            ``"model.h5"``, the path the original code checked for.

    Returns:
        The local path of the weights file.

    Raises:
        RuntimeError: if ``gdown.download`` reports failure by returning
            ``None``.
    """
    if os.path.exists(filename):
        return filename

    # Pass ``output`` explicitly: without it gdown names the file after the
    # Drive original, so (unless that name happens to be ``filename``) the
    # cache check above never matches and the weights are re-downloaded on
    # every call.
    downloaded = gdown.download(id=model_id, output=filename, quiet=False)
    if downloaded is None:
        raise RuntimeError(
            f"Failed to download model weights (Drive id {model_id!r})."
        )
    return downloaded


@functools.lru_cache(maxsize=1)
def load_model(model_id):
    """Build the model and load its trained weights, memoising the result.

    ``predict_fn`` calls this on every request. Caching means the network is
    built and the weights are read once per process instead of once per
    prediction. ``maxsize=1`` keeps only the most recently requested model
    alive.

    Args:
        model_id: Google Drive file id of the weights, forwarded to
            ``get_model_weight``.

    Returns:
        The model with trained weights loaded. Repeated calls with the same
        argument return the same model object.
    """
    weight_path = get_model_weight(model_id)
    model = get_model()
    model.load_weights(weight_path)
    return model


def image_process(image):
    """Turn a raw image into a single-item float32 batch for the model.

    Args:
        image: array-like image; presumably ``(height, width, channels)`` as
            delivered by the Gradio ``Image`` input — confirm.

    Returns:
        A tuple ``(batch, original_shape)``:
        ``batch`` is the image cast to float32, resized to
        ``image_size`` x ``image_size`` and given a leading batch axis of
        size 1; ``original_shape`` is the shape of the input before resizing.
    """
    as_float = tf.cast(image, dtype=tf.float32)
    shape_before_resize = as_float.shape
    resized = tf.image.resize(as_float, [image_size, image_size])
    return tf.expand_dims(resized, axis=0), shape_before_resize


def predict_fn(image):
    """Classify *image* and return two Grad-CAM overlays; invoked by Gradio.

    Args:
        image: the image from the Gradio ``Image`` input, or ``None`` when no
            image was provided.

    Returns:
        ``[prediction_text, cnn_overlay, transformer_overlay]``, one entry per
        output component of the interface. When *image* is ``None``, the
        text is a prompt to upload an image and both overlays are ``None``.
    """
    if image is None:
        # Nothing to classify; tf.cast(None) below would raise.
        return ["Please upload an image.", None, None]

    loaded_model = load_model(model_id="1y6tseN0194T6d-4iIh5wo7RL9ttQERe0")
    loaded_image, original_shape = image_process(image)

    heatmap_a, heatmap_b, preds = make_gradcam_heatmap(loaded_image, loaded_model)
    int_label = tf.argmax(preds, axis=-1).numpy()[0]
    str_label = str_labels[int_label]

    # Overlay each heatmap on the resized input. ``original_shape[:2]`` is
    # presumably (height, width) of the upload, used to size the output.
    overlay_a = save_and_display_gradcam(
        loaded_image[0], heatmap_a, image_shape=original_shape[:2]
    )
    overlay_b = save_and_display_gradcam(
        loaded_image[0], heatmap_b, image_shape=original_shape[:2]
    )

    return [f"Predicted: {str_label}", overlay_a, overlay_b]


# Gradio UI: one image in; the predicted label plus two Grad-CAM overlays out
# (one per branch of the hybrid model, in the order predict_fn returns them).
iface = gr.Interface(
    fn=predict_fn,
    inputs=gr.inputs.Image(label="Input Image"),
    outputs=[
        gr.outputs.Label(label="Prediction"),
        gr.inputs.Image(label="CNN GradCAM"),
        gr.inputs.Image(label="Transformer GradCAM"),
    ],
    title="Hybrid EfficientNet Swin Transformer Demo",
    description="The model is trained on tf_flowers dataset <a href='https://www.kaggle.com/datasets/alxmamaev/flowers-recognition'>Flowers Recognition Dataset</a>. It provides 5 categories, namely: `daisy`, `rose`, `sunflower`, `tulip`, `dandelion`. One example from each class is provided in the Example section.",
    article="<div><center><img src='https://visitor-badge.glitch.me/badge?page_id=hybrid-gradcam' alt='visitor badge'></center></div>",
    examples=[
        ["examples/dandelion.jpg"],
        ["examples/sunflower.jpg"],
        ["examples/tulip.jpg"],
        ["examples/daisy.jpg"],
        ["examples/rose.jpg"],
    ],
)

# Only start the server when run as a script, so importing this module
# (e.g. from tests or another app) does not block in launch().
if __name__ == "__main__":
    iface.launch()