"""Gradio demo: sample grid points on a coral image and classify a patch around each
with the NOAA ESD coral-bleaching ViT classifier, drawing the results as an overlay."""
import glob
import gradio as gr
import torch
from transformers import ViTForImageClassification, AutoImageProcessor
from PIL import Image, ImageDraw, ImageFont
import random
import os

# ✅ Load model and processor from Hugging Face
model_name = "akridge/noaa-esd-coral-bleaching-vit-classifier-v1"
model = ViTForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)

# Pick the inference device once and move the model there a single time,
# instead of re-running .to(device).eval() on every request.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(DEVICE).eval()

# 🗂️ Ensure id2label keys are integers and labels are uppercase
id2label = {int(k): v.upper() for k, v in model.config.id2label.items()}

# 🎨 Label colors (RGBA): (fill, border)
LABEL_COLORS = {
    "CORAL": ((0, 0, 255, 80), (0, 0, 200)),  # Fill: Blue transparent, Border: Dark blue
    "CORAL_BL": ((255, 255, 255, 150), (150, 150, 150)),  # Fill: White transparent, Border: Gray
}
# Colors used for any label not listed in LABEL_COLORS (e.g. "UNKNOWN").
DEFAULT_COLORS = ((200, 200, 200, 100), (100, 100, 100))


def _sample_points(width, height, rows, cols):
    """Return one random (x, y) point inside each cell of a rows x cols grid.

    Cells are visited row by row, left to right. Each upper bound is clamped to
    its lower bound so images smaller than the grid (cells narrower than 1 px)
    do not make random.randint raise ValueError.
    """
    cell_width, cell_height = width / cols, height / rows
    points = []
    for row in range(rows):
        for col in range(cols):
            x_lo = int(col * cell_width)
            x_hi = max(x_lo, int((col + 1) * cell_width - 1))
            y_lo = int(row * cell_height)
            y_hi = max(y_lo, int((row + 1) * cell_height - 1))
            points.append((random.randint(x_lo, x_hi), random.randint(y_lo, y_hi)))
    return points


def _patch_box(x, y, width, height, patch_size):
    """Return a (left, upper, right, lower) crop box of up to patch_size centred on (x, y).

    The box is shifted (not truncated) to stay inside the image, so patches near
    every edge keep the full patch_size; it is only smaller when the image itself
    is smaller than patch_size along that axis.
    """
    half = patch_size // 2
    left = min(max(0, x - half), max(0, width - patch_size))
    upper = min(max(0, y - half), max(0, height - patch_size))
    return left, upper, min(width, left + patch_size), min(height, upper + patch_size)


def _load_font(font_size):
    """Load Arial at font_size, falling back to Pillow's bundled default font.

    arial.ttf is often missing (e.g. on Linux hosts). Pillow >= 10.1 can scale its
    default font via load_default(size=...); older Pillow rejects the argument
    with TypeError, in which case the fixed-size default is used.
    """
    try:
        return ImageFont.truetype("arial.ttf", size=font_size)
    except OSError:
        try:
            return ImageFont.load_default(size=font_size)
        except (TypeError, OSError, ImportError):
            return ImageFont.load_default()


def _classify_patch(patch):
    """Run the classifier on one PIL patch and return its uppercase label ("UNKNOWN" if unmapped)."""
    inputs = processor(images=patch, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        pred_id = model(**inputs).logits.argmax(-1).item()
    return id2label.get(pred_id, "UNKNOWN")


def _draw_label(draw, left, upper, label_text, font):
    """Draw label_text in white on a dark tab anchored at the patch's top-left corner.

    The tab normally sits just above the patch; when the patch touches the top of
    the image there is no room above, so the tab is drawn just inside it instead.
    """
    bbox = draw.textbbox((0, 0), label_text, font=font)
    text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
    tab_top = upper - text_height - 6
    if tab_top < 0:
        tab_top = upper
    draw.rectangle(
        [(left, tab_top), (left + text_width + 6, tab_top + text_height + 6)],
        fill=(0, 0, 0, 200),
    )
    draw.text((left + 3, tab_top + 2), label_text, fill="white", font=font)


def predict_and_overlay(image, rows=2, cols=5, patch_size=224):
    """Classify patches around randomly sampled grid points and draw them on the image.

    Args:
        image: PIL image (any mode; converted to RGB).
        rows, cols: grid dimensions; one point is sampled per cell. Float values
            (as Gradio sliders may send) are truncated to int; values < 1 become 1.
        patch_size: side length in pixels of the square patch classified per point.

    Returns:
        (final_image, predictions): the RGB image with colored patch overlays and
        labels, and a comma-separated string of predicted labels in sampling order.
    """
    rows, cols = max(1, int(rows)), max(1, int(cols))

    # ✅ Load image
    image = image.convert("RGB")
    width, height = image.size

    # Scale font and border with image size (800 px on the long side = scale 1).
    scale_factor = max(width, height) / 800
    font = _load_font(max(12, int(scale_factor * 8)))
    border_width = max(3, int(scale_factor * 3))

    # 🖌️ Transparent overlay, composited onto the image at the end
    overlay = Image.new("RGBA", image.size, (255, 255, 255, 0))
    overlay_draw = ImageDraw.Draw(overlay)

    predictions = []
    # 🔍 Predict and draw patches
    for x, y in _sample_points(width, height, rows, cols):
        left, upper, right, lower = _patch_box(x, y, width, height, patch_size)
        pred_label = _classify_patch(image.crop((left, upper, right, lower)))
        predictions.append(pred_label)

        fill_color, border_color = LABEL_COLORS.get(pred_label, DEFAULT_COLORS)
        overlay_draw.rectangle(
            [(left, upper), (right, lower)],
            fill=fill_color,
            outline=border_color,
            width=border_width,
        )
        _draw_label(overlay_draw, left, upper, pred_label, font)

    # 🖼️ Merge overlay with original
    final_image = Image.alpha_composite(image.convert("RGBA"), overlay).convert("RGB")
    return final_image, ", ".join(predictions)


# Load example images from coral_images folder (sorted for a stable gallery order)
example_images = sorted(glob.glob("coral_images/*.[jp][pn]g"))


# 🚀 Gradio interface
def gradio_interface(image, rows=2, cols=5):
    """Gradio callback: run predict_and_overlay, or return a hint when no image is given."""
    if image is None:
        return None, "No image uploaded. Please upload an image or select from examples."
    final_image, predictions = predict_and_overlay(image, rows, cols)
    return final_image, predictions


app_title = "🌊 NOAA ESD Coral Bleaching Classifier Demo"
app_description = """
Upload a coral image or select from example images to sample points and predict coral bleaching using the classifier model.

**Model:** [akridge/noaa-esd-coral-bleaching-vit-classifier-v1](https://huggingface.co/akridge/noaa-esd-coral-bleaching-vit-classifier-v1)

**Dataset:** [NOAA-ESD-CORAL-Bleaching-Dataset](https://huggingface.co/datasets/akridge/NOAA-ESD-CORAL-Bleaching-Dataset)
"""

# Custom CSS for improved styling
custom_css = """
.gradio-container h1 { font-size: 2.2em; text-align: center; }
.gradio-container p { font-size: 1.2em; }
.gradio-container .gr-button { font-size: 1.2em; }
"""

with gr.Blocks(theme=gr.themes.Ocean(), css=custom_css, title=app_title) as interface:
    # Markdown "# ..." renders an <h1>, which the custom CSS above styles.
    gr.Markdown(f"# {app_title}")
    gr.Markdown(app_description)

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Coral Image")
        result_output = gr.Image(type="pil", label="Predicted Results")

    with gr.Row():
        # step=1 keeps the grid dimensions whole numbers.
        rows_slider = gr.Slider(1, 10, value=2, step=1, label="Rows of Sample Points")
        cols_slider = gr.Slider(1, 10, value=5, step=1, label="Columns of Sample Points")

    with gr.Row():
        run_button = gr.Button("Run Prediction", variant="primary")
        clear_button = gr.Button("Clear")

    # Add example images section (skipped when the folder has no images)
    if example_images:
        gr.Examples(
            examples=[[img] for img in example_images],
            inputs=[image_input],
            outputs=[result_output],
            examples_per_page=6,
            label="Example Coral Images",
        )

    # Named so the Clear button can reset it as well.
    predictions_output = gr.Textbox(label="Predictions")

    # Define button actions
    run_button.click(
        fn=gradio_interface,
        inputs=[image_input, rows_slider, cols_slider],
        outputs=[result_output, predictions_output],
    )
    # Reset both images to empty and the predictions text to "".
    clear_button.click(
        lambda: (None, None, ""),
        outputs=[image_input, result_output, predictions_output],
    )

interface.launch()