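"""Gradio demo for the NOAA ESD coral bleaching ViT classifier.

Samples one random patch per grid cell of an uploaded coral image, classifies
each patch as CORAL or CORAL_BL (bleached coral), and overlays color-coded
boxes and labels on the image.
"""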
import glob
import random

import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import ViTForImageClassification, AutoImageProcessor

# βœ… Load model and processor from Hugging Face
model_name = "akridge/noaa-esd-coral-bleaching-vit-classifier-v1"
model = ViTForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)  # fast image processor when available

# βœ… Move the model to the GPU (when available) and enter eval mode once, at load time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()

# πŸ—‚οΈ Ensure id2label keys are integers and labels are uppercase
id2label = {int(k): v.upper() for k, v in model.config.id2label.items()}

# 🎨 Label colors (RGBA)
LABEL_COLORS = {
    "CORAL": ((0, 0, 255, 80), (0, 0, 200)),        # Fill: Blue transparent, Border: Dark blue
    "CORAL_BL": ((255, 255, 255, 150), (150, 150, 150)),  # Fill: White transparent, Border: Gray
}
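# Labels not listed here fall back to a neutral gray (see the .get() default in predict_and_overlay)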

def predict_and_overlay(image, rows=2, cols=5, patch_size=224):
    # βœ… Load image
    image = image.convert('RGB')
    width, height = image.size
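    # Scale font size and border width with image resolution so annotations stay legible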
    scale_factor = max(width, height) / 800
    font_size = max(12, int(scale_factor * 8))
    border_width = max(3, int(scale_factor * 3))

    # πŸ–ŒοΈ Create overlay
    overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
    overlay_draw = ImageDraw.Draw(overlay)

    # βœ… Stratified sampling: one random point per grid cell (rows Γ— cols)
    cell_width, cell_height = width / cols, height / rows
    sampled_points = [
        (random.randint(int(col * cell_width), int((col + 1) * cell_width - 1)),
         random.randint(int(row * cell_height), int((row + 1) * cell_height - 1)))
        for row in range(rows) for col in range(cols)
    ]

    predictions = []

    # βœ… Load font
    try:
        font = ImageFont.truetype("arial.ttf", size=font_size)
    except IOError:
        font = ImageFont.load_default()

    # πŸ” Predict and draw patches
    for x, y in sampled_points:
        left, upper = max(0, x - patch_size // 2), max(0, y - patch_size // 2)
        right, lower = min(width, left + patch_size), min(height, upper + patch_size)
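        # Patches at the image border are clamped, so they can be smaller than
        # patch_size; the processor resizes every crop to the model's input size.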

        # 🧠 Predict label
        patch = image.crop((left, upper, right, lower))
        inputs = processor(images=patch, return_tensors="pt").to(device)
        with torch.no_grad():
            pred_id = model(**inputs).logits.argmax(-1).item()
            pred_label = id2label.get(pred_id, "UNKNOWN")
            predictions.append(pred_label)

        # 🎨 Fill and border colors
        fill_color, border_color = LABEL_COLORS.get(pred_label, ((200, 200, 200, 100), (100, 100, 100)))

        # 🟦 Draw filled rectangle and label
        overlay_draw.rectangle([(left, upper), (right, lower)], fill=fill_color, outline=border_color, width=border_width)

        # πŸ”– Draw the label, clamping the background box so top-edge labels stay inside the image
        bbox = overlay_draw.textbbox((0, 0), pred_label, font=font)
        text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
        text_top = max(0, upper - text_height - 6)
        overlay_draw.rectangle([(left, text_top), (left + text_width + 6, text_top + text_height + 6)], fill=(0, 0, 0, 200))
        overlay_draw.text((left + 3, text_top + 2), pred_label, fill="white", font=font)

    # πŸ–ΌοΈ Merge overlay with original
    final_image = Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')

    return final_image, ", ".join(predictions)
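# Minimal standalone usage sketch, outside the Gradio UI (file names are hypothetical):
#   img = Image.open("reef_photo.jpg")
#   annotated, labels = predict_and_overlay(img, rows=2, cols=5)
#   annotated.save("reef_photo_annotated.jpg")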

# Load example images (.jpg / .png) from the coral_images folder, in a stable order
example_images = sorted(glob.glob("coral_images/*.jpg") + glob.glob("coral_images/*.png"))

# πŸš€ Gradio interface
def gradio_interface(image, rows=2, cols=5):
    if image is None:
        return None, "No image uploaded. Please upload an image or select one of the examples."
    final_image, predictions = predict_and_overlay(image, int(rows), int(cols))
    return final_image, predictions

app_title = "🌊 NOAA ESD Coral Bleaching Classifier Demo"
app_description = """
Upload a coral image or pick one of the examples; the app samples points across the image and predicts coral bleaching at each point with the classifier model.

**Model:** [akridge/noaa-esd-coral-bleaching-vit-classifier-v1](https://huggingface.co/akridge/noaa-esd-coral-bleaching-vit-classifier-v1)

**Dataset:** [NOAA-ESD-CORAL-Bleaching-Dataset](https://huggingface.co/datasets/akridge/NOAA-ESD-CORAL-Bleaching-Dataset)
"""

# Custom CSS for improved styling
custom_css = """
.gradio-container h1 {
    font-size: 2.2em;
    text-align: center;
}
.gradio-container p {
    font-size: 1.2em;
}
.gradio-container .gr-button {
    font-size: 1.2em;
}
"""

with gr.Blocks(theme=gr.themes.Ocean(), css=custom_css, title=app_title) as interface:
    gr.Markdown(f"<h1>{app_title}</h1>")
    gr.Markdown(app_description)

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Coral Image")
        result_output = gr.Image(type="pil", label="Predicted Results")

    with gr.Row():
        rows_slider = gr.Slider(1, 10, value=2, step=1, label="Rows of Sample Points")
        cols_slider = gr.Slider(1, 10, value=5, step=1, label="Columns of Sample Points")

    with gr.Row():
        run_button = gr.Button("Run Prediction", variant="primary")
        clear_button = gr.Button("Clear")

    predictions_output = gr.Textbox(label="Predictions")

    # Add example images section (shown only when example files are present)
    if example_images:
        gr.Examples(
            examples=[[img] for img in example_images],
            inputs=[image_input],
            examples_per_page=6,
            label="Example Coral Images"
        )

    # Define button actions
    run_button.click(
        fn=gradio_interface,
        inputs=[image_input, rows_slider, cols_slider],
        outputs=[result_output, predictions_output]
    )

    clear_button.click(
        lambda: (None, None, ""),
        outputs=[image_input, result_output, predictions_output]
    )

if __name__ == "__main__":
    interface.launch()