File size: 4,584 Bytes
96309db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
import torch
from transformers import ViTForImageClassification, AutoImageProcessor
from PIL import Image, ImageDraw, ImageFont
import random
import os

# Fetch the pretrained ViT classifier and its matching image processor
# from the Hugging Face Hub (identified by the repository id below).
model_name = "akridge/noaa-esd-coral-bleaching-vit-classifier-v1"
model = ViTForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)

# Normalise the label map from the model config: keys are coerced to int
# and label names are upper-cased so they match the LABEL_COLORS keys.
id2label = dict(
    (int(class_id), class_name.upper())
    for class_id, class_name in model.config.id2label.items()
)

# Drawing colours per predicted label: label -> (RGBA fill, RGB border).
LABEL_COLORS = dict(
    # Fill: translucent blue, border: dark blue
    CORAL=((0, 0, 255, 80), (0, 0, 200)),
    # Fill: translucent white, border: gray
    CORAL_BL=((255, 255, 255, 150), (150, 150, 150)),
)

def _load_label_font(font_size):
    """Return a font for patch labels, preferring a scalable TrueType face."""
    # The bitmap default font ignores font_size, so try more than one
    # TrueType face before falling back to it.
    for font_name in ("arial.ttf", "DejaVuSans.ttf"):
        try:
            return ImageFont.truetype(font_name, size=font_size)
        except OSError:
            continue
    return ImageFont.load_default()


def _sample_points(width, height, rows, cols):
    """Pick one random pixel inside each cell of a rows x cols grid."""
    cell_width, cell_height = width / cols, height / rows
    points = []
    for row in range(rows):
        y_min = int(row * cell_height)
        # Clamp so a cell thinner than one pixel still yields a valid
        # (single-pixel) range instead of making randint() raise.
        y_max = max(y_min, int((row + 1) * cell_height - 1))
        for col in range(cols):
            x_min = int(col * cell_width)
            x_max = max(x_min, int((col + 1) * cell_width - 1))
            points.append(
                (random.randint(x_min, x_max), random.randint(y_min, y_max))
            )
    return points


def predict_and_overlay(image, rows=2, cols=5, patch_size=224):
    """Classify randomly sampled patches of an image and draw the results.

    The image is divided into a ``rows`` x ``cols`` grid and one random point
    is sampled per cell. A patch of up to ``patch_size`` pixels around each
    point is cropped (clipped to the image bounds), classified with the
    module-level ``model``/``processor``, and drawn on a translucent overlay
    as a coloured rectangle with a text label.

    Args:
        image: PIL image to analyse; it is converted to RGB internally.
        rows: Number of grid rows (coerced to int, must be >= 1).
        cols: Number of grid columns (coerced to int, must be >= 1).
        patch_size: Side length in pixels of the patch around each point.

    Returns:
        A ``(final_image, predictions)`` tuple: the RGB image with the overlay
        composited on top, and the list of predicted label strings in
        sampling order (row by row).

    Raises:
        ValueError: If ``rows`` or ``cols`` is smaller than 1.
    """
    # UI sliders may deliver floats; range() below needs real integers.
    rows, cols = int(rows), int(cols)
    if rows < 1 or cols < 1:
        raise ValueError("rows and cols must both be at least 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()

    # Load image
    image = image.convert('RGB')
    width, height = image.size
    # Scale text and borders with the image so they stay legible on large inputs.
    scale_factor = max(width, height) / 800
    font_size = max(12, int(scale_factor * 8))
    border_width = max(3, int(scale_factor * 3))

    # Create overlay (fully transparent until something is drawn on it)
    overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
    overlay_draw = ImageDraw.Draw(overlay)

    font = _load_label_font(font_size)
    half_patch = patch_size // 2
    fallback_colors = ((200, 200, 200, 100), (100, 100, 100))
    predictions = []

    # Predict and draw patches
    for x, y in _sample_points(width, height, rows, cols):
        left, upper = max(0, x - half_patch), max(0, y - half_patch)
        right, lower = min(width, left + patch_size), min(height, upper + patch_size)

        # Predict label
        patch = image.crop((left, upper, right, lower))
        inputs = processor(images=patch, return_tensors="pt").to(device)
        with torch.no_grad():
            pred_id = model(**inputs).logits.argmax(-1).item()
        pred_label = id2label.get(pred_id, "UNKNOWN")
        predictions.append(pred_label)

        # Fill and border colors (neutral gray for labels without an entry)
        fill_color, border_color = LABEL_COLORS.get(pred_label, fallback_colors)

        # Draw filled rectangle
        overlay_draw.rectangle(
            [(left, upper), (right, lower)],
            fill=fill_color,
            outline=border_color,
            width=border_width,
        )

        # Draw the label banner just above the patch; when the patch touches
        # the top edge there is no room above, so put it inside the patch
        # instead of drawing it off-canvas.
        bbox = overlay_draw.textbbox((0, 0), pred_label, font=font)
        text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
        banner_height = text_height + 6
        banner_top = upper - banner_height
        if banner_top < 0:
            banner_top = upper

        overlay_draw.rectangle(
            [(left, banner_top), (left + text_width + 6, banner_top + banner_height)],
            fill=(0, 0, 0, 200),
        )
        overlay_draw.text((left + 3, banner_top + 2), pred_label, fill="white", font=font)

    # Merge overlay with original
    final_image = Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')

    return final_image, predictions

# Open an example image from disk
def load_example_image(example_image):
    """Open the image file at ``example_image`` with PIL and return it."""
    opened_image = Image.open(example_image)
    return opened_image

# List example images
def find_example_images(candidate_dirs=("example_images", "coral_images")):
    """Return sorted paths of the image files in the first usable directory.

    The directories are tried in order; the first one that exists and holds
    at least one file with an image extension wins. Each returned path is
    built from the same directory that was listed, so it always points at a
    real file.

    Args:
        candidate_dirs: Directory paths to try, in priority order.

    Returns:
        A list of file paths sorted by file name, or an empty list when no
        candidate directory exists or none contains image files.
    """
    image_suffixes = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tif", ".tiff", ".webp")
    for directory in candidate_dirs:
        if not os.path.isdir(directory):
            continue
        found = [
            os.path.join(directory, file_name)
            for file_name in sorted(os.listdir(directory))
            # Skip non-image entries (e.g. hidden or metadata files) that
            # could not be opened as images later on.
            if file_name.lower().endswith(image_suffixes)
        ]
        if found:
            return found
    return []


example_images = find_example_images()

# Gradio interface
def gradio_interface(image, rows=2, cols=5, example_image=None):
    """Run patch prediction for the UI and format the results for display.

    Args:
        image: Uploaded PIL image, or None when nothing was uploaded.
        rows: Number of sampling-grid rows.
        cols: Number of sampling-grid columns.
        example_image: Optional path of an example image, used only when
            ``image`` is None.

    Returns:
        A ``(final_image, text)`` tuple: the annotated image and the predicted
        labels joined with ", ".

    Raises:
        gr.Error: If neither an uploaded image nor an example was provided.
    """
    if image is None and example_image:
        image = load_example_image(example_image)
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image or select an example image.")

    # Slider values may arrive as floats; the grid sizes must be integers.
    final_image, predictions = predict_and_overlay(image, int(rows), int(cols))
    return final_image, ", ".join(predictions)


def _run_from_ui(image, example_image, rows, cols):
    """Map the four UI inputs (in display order) onto gradio_interface."""
    return gradio_interface(image, rows, cols, example_image=example_image)


# Components are created from the top-level gradio namespace; the legacy
# gr.inputs / gr.outputs namespaces are no longer available in current Gradio.
iface = gr.Interface(
    fn=_run_from_ui,
    inputs=[
        gr.Image(type="pil", label="Upload Coral Image"),
        gr.Dropdown(choices=example_images, label="Or Select an Example Image"),
        gr.Slider(1, 10, value=2, step=1, label="Rows of Sample Points"),
        gr.Slider(1, 10, value=5, step=1, label="Columns of Sample Points"),
    ],
    outputs=[
        gr.Image(type="pil", label="Image with Predictions"),
        gr.Textbox(label="Predictions"),
    ],
    title="NOAA ESD Coral Bleaching ViT Classifier",
    description="Upload an image or select an example to sample points/patches and predict coral bleaching using the ViT classifier model hosted on Hugging Face."
)

iface.launch()