File size: 5,844 Bytes
5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 5dedf78 96309db 09e742d 5dedf78 09e742d 5dedf78 09e742d 5dedf78 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import glob
import gradio as gr
import torch
from transformers import ViTForImageClassification, AutoImageProcessor
from PIL import Image, ImageDraw, ImageFont
import random
import os
# Load the classifier model and its image processor from the Hugging Face Hub.
model_name = "akridge/noaa-esd-coral-bleaching-vit-classifier-v1"
model = ViTForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)

# Ensure id2label keys are ints and labels are upper-case, so predicted labels
# can be looked up directly in LABEL_COLORS below.
id2label = {int(k): v.upper() for k, v in model.config.id2label.items()}

# Per-label overlay colors: (RGBA fill, RGB border).
LABEL_COLORS = {
    "CORAL": ((0, 0, 255, 80), (0, 0, 200)),  # fill: translucent blue, border: dark blue
    "CORAL_BL": ((255, 255, 255, 150), (150, 150, 150)),  # fill: translucent white, border: gray
}
# Fallback colors (RGBA fill, RGB border) for labels missing from LABEL_COLORS.
_DEFAULT_COLORS = ((200, 200, 200, 100), (100, 100, 100))


def _load_font(size):
    """Return a font of roughly *size* px, degrading gracefully.

    ``arial.ttf`` is usually not installed on Linux hosts, so DejaVuSans is
    tried next; the last resort is Pillow's built-in font, scaled when the
    installed Pillow accepts ``load_default(size=...)`` (Pillow >= 10.1).
    """
    for font_name in ("arial.ttf", "DejaVuSans.ttf"):
        try:
            return ImageFont.truetype(font_name, size=size)
        except OSError:
            continue
    try:
        return ImageFont.load_default(size=size)
    except TypeError:  # older Pillow: load_default() takes no arguments
        return ImageFont.load_default()


def _sample_points(width, height, rows, cols):
    """Return one random (x, y) pixel inside each cell of a rows x cols grid."""
    cell_width, cell_height = width / cols, height / rows
    points = []
    for row in range(rows):
        for col in range(cols):
            x_lo = int(col * cell_width)
            y_lo = int(row * cell_height)
            # max() keeps each range non-empty when a cell is under one pixel wide.
            x_hi = max(x_lo, int((col + 1) * cell_width - 1))
            y_hi = max(y_lo, int((row + 1) * cell_height - 1))
            points.append((random.randint(x_lo, x_hi), random.randint(y_lo, y_hi)))
    return points


def _patch_box(x, y, width, height, patch_size):
    """Return a (left, upper, right, lower) crop box of up to patch_size px around (x, y).

    The window is shifted rather than truncated at the image border, so every
    patch is a full patch_size square whenever the image is at least that big.
    """
    half = patch_size // 2
    left = min(max(0, x - half), max(0, width - patch_size))
    upper = min(max(0, y - half), max(0, height - patch_size))
    return left, upper, min(width, left + patch_size), min(height, upper + patch_size)


def _classify_patches(patches, device):
    """Classify all PIL *patches* in one batched forward pass; return upper-case labels."""
    inputs = processor(images=patches, return_tensors="pt").to(device)
    with torch.no_grad():
        pred_ids = model(**inputs).logits.argmax(-1).tolist()
    return [id2label.get(pred_id, "UNKNOWN") for pred_id in pred_ids]


def _draw_label(draw, text, left, upper, font, image_width):
    """Draw *text* on a dark background at the top-left corner of a patch.

    The label sits just above the patch when there is room; for patches touching
    the top edge it is drawn inside the patch so it stays visible.
    """
    bbox = draw.textbbox((0, 0), text, font=font)
    text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
    box_height = text_height + 6
    label_upper = upper - box_height if upper >= box_height else upper
    label_left = max(0, min(left, image_width - text_width - 6))
    draw.rectangle(
        [(label_left, label_upper), (label_left + text_width + 6, label_upper + box_height)],
        fill=(0, 0, 0, 200),
    )
    draw.text((label_left + 3, label_upper + 2), text, fill="white", font=font)


def predict_and_overlay(image, rows=2, cols=5, patch_size=224):
    """Sample a grid of points, classify a patch around each, and draw the results.

    Args:
        image: PIL image to analyze (converted to RGB).
        rows, cols: grid of cells; one random point is sampled per cell. Floats
            (as delivered by UI sliders) are truncated to int; values < 1 become 1.
        patch_size: side length in pixels of the square patch classified per point.

    Returns:
        (RGB image with colored patch boxes and labels,
         comma-separated predicted labels in row-major grid order).
    """
    rows, cols = max(1, int(rows)), max(1, int(cols))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()

    image = image.convert("RGB")
    width, height = image.size
    # Scale text and borders with the image so they stay legible on large photos.
    scale_factor = max(width, height) / 800
    font = _load_font(max(12, int(scale_factor * 8)))
    border_width = max(3, int(scale_factor * 3))

    boxes = [
        _patch_box(x, y, width, height, patch_size)
        for x, y in _sample_points(width, height, rows, cols)
    ]
    predictions = _classify_patches([image.crop(box) for box in boxes], device)

    overlay = Image.new("RGBA", image.size, (255, 255, 255, 0))
    overlay_draw = ImageDraw.Draw(overlay)
    # Draw every patch before any label: filling an RGBA overlay replaces pixels,
    # so a later overlapping patch would otherwise paint over an earlier label.
    for (left, upper, right, lower), label in zip(boxes, predictions):
        fill_color, border_color = LABEL_COLORS.get(label, _DEFAULT_COLORS)
        overlay_draw.rectangle(
            [(left, upper), (right, lower)],
            fill=fill_color,
            outline=border_color,
            width=border_width,
        )
    for (left, upper, _, _), label in zip(boxes, predictions):
        _draw_label(overlay_draw, label, left, upper, font, width)

    final_image = Image.alpha_composite(image.convert("RGBA"), overlay).convert("RGB")
    return final_image, ", ".join(predictions)
# Example images shown in the UI: every .jpg/.jpeg/.png (any letter case) in
# coral_images/, sorted so the gallery order is stable across runs. glob returns
# an empty list when the folder is missing.
_EXAMPLE_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
example_images = sorted(
    path
    for path in glob.glob(os.path.join("coral_images", "*"))
    if path.lower().endswith(_EXAMPLE_IMAGE_EXTENSIONS)
)
# Gradio callback wired to the "Run Prediction" button.
def gradio_interface(image, rows=2, cols=5):
    """Validate UI inputs and run predict_and_overlay.

    Args:
        image: PIL image from the upload widget, or None when nothing is uploaded.
        rows, cols: grid size from the sliders; Gradio may deliver these as floats.

    Returns:
        (overlay image or None, comma-separated predictions or a help message).
    """
    if image is None:
        return None, "No image uploaded. Please upload an image or select from examples."
    # range() inside the predictor needs ints; slider values can arrive as e.g. 3.0.
    return predict_and_overlay(image, int(rows), int(cols))
# Page title; the scraped source had a mis-decoded emoji ("π") here, removed
# because the original character cannot be recovered.
app_title = "NOAA ESD Coral Bleaching Classifier Demo"

# Markdown shown under the title. Blank lines keep Model and Dataset as
# separate paragraphs; adjacent Markdown lines would otherwise merge into one.
app_description = """
Upload a coral image or select from example images to sample points and predict coral bleaching using the classifier model.

**Model:** [akridge/noaa-esd-coral-bleaching-vit-classifier-v1](https://huggingface.co/akridge/noaa-esd-coral-bleaching-vit-classifier-v1)

**Dataset:** [NOAA-ESD-CORAL-Bleaching-Dataset](https://huggingface.co/datasets/akridge/NOAA-ESD-CORAL-Bleaching-Dataset)
"""

# Custom CSS: larger centered heading, larger body text and buttons.
custom_css = """
.gradio-container h1 {
font-size: 2.2em;
text-align: center;
}
.gradio-container p {
font-size: 1.2em;
}
.gradio-container .gr-button {
font-size: 1.2em;
}
"""
# Build the Gradio UI: upload/result images, grid sliders, buttons, examples, predictions.
with gr.Blocks(theme=gr.themes.Ocean(), css=custom_css, title=app_title) as interface:
    gr.Markdown(f"<h1>{app_title}</h1>")
    gr.Markdown(app_description)
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Coral Image")
        result_output = gr.Image(type="pil", label="Predicted Results")
    with gr.Row():
        # step=1: without it Gradio picks a fractional step for a 1-10 range,
        # which would hand non-integer grid sizes to the predictor.
        rows_slider = gr.Slider(1, 10, value=2, step=1, label="Rows of Sample Points")
        cols_slider = gr.Slider(1, 10, value=5, step=1, label="Columns of Sample Points")
    with gr.Row():
        run_button = gr.Button("Run Prediction", variant="primary")
        clear_button = gr.Button("Clear")
    # Clicking an example only fills the input image (no fn is attached);
    # the section is skipped when no example files were found.
    if example_images:
        gr.Examples(
            examples=[[img] for img in example_images],
            inputs=[image_input],
            examples_per_page=6,
            label="Example Coral Images",
        )
    predictions_output = gr.Textbox(label="Predictions")
    run_button.click(
        fn=gradio_interface,
        inputs=[image_input, rows_slider, cols_slider],
        outputs=[result_output, predictions_output],
    )
    # Reset both images to empty and clear the prediction text.
    clear_button.click(
        lambda: (None, None, ""),
        outputs=[image_input, result_output, predictions_output],
    )

interface.launch()