"""Gradio demo for the NOAA ESD coral bleaching ViT classifier.

Random points are sampled on a rows x cols grid over the input image, a
patch around each point is classified by the ViT model, and the predicted
labels are drawn on a semi-transparent overlay.
"""

import os
import random

import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoImageProcessor, ViTForImageClassification

# Load model and processor from Hugging Face
model_name = "akridge/noaa-esd-coral-bleaching-vit-classifier-v1"
model = ViTForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)

# Ensure id2label keys are integers and labels are uppercase
id2label = {int(k): v.upper() for k, v in model.config.id2label.items()}

# Label colors: label -> (RGBA fill, RGB border)
LABEL_COLORS = {
    "CORAL": ((0, 0, 255, 80), (0, 0, 200)),  # Fill: blue transparent, border: dark blue
    "CORAL_BL": ((255, 255, 255, 150), (150, 150, 150)),  # Fill: white transparent, border: gray
}

# Colors used for any label that has no entry in LABEL_COLORS.
_DEFAULT_COLORS = ((200, 200, 200, 100), (100, 100, 100))

# Fonts tried in order before falling back to Pillow's built-in default font.
# NOTE(review): "DejaVuSans.ttf" is assumed to be a commonly installed
# fallback on Linux hosts where "arial.ttf" is missing -- confirm on the
# deployment target.
_FONT_CANDIDATES = ("arial.ttf", "DejaVuSans.ttf")

# Directory holding the example images offered in the UI.
# NOTE(review): the original code listed "coral_images" but built paths under
# "example_images". "coral_images" is kept because it is the directory that
# had to exist for the original script to start -- confirm against the repo.
EXAMPLE_IMAGES_DIR = "coral_images"

# File extensions accepted as example images (compared case-insensitively).
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")


def _load_font(font_size):
    """Return the first loadable font from _FONT_CANDIDATES at *font_size*.

    Falls back to Pillow's default font when none of the candidates can be
    opened.
    """
    for font_name in _FONT_CANDIDATES:
        try:
            return ImageFont.truetype(font_name, size=font_size)
        except OSError:
            continue
    return ImageFont.load_default()


def _sample_points(width, height, rows, cols):
    """Return one random (x, y) point inside each cell of a rows x cols grid.

    Points are produced row by row. The upper bound of each range is clamped
    to the lower bound so that cells smaller than one pixel (image smaller
    than the grid) do not make random.randint raise ValueError.
    """
    cell_width, cell_height = width / cols, height / rows
    points = []
    for row in range(rows):
        y_low = int(row * cell_height)
        y_high = max(y_low, int((row + 1) * cell_height - 1))
        for col in range(cols):
            x_low = int(col * cell_width)
            x_high = max(x_low, int((col + 1) * cell_width - 1))
            points.append((random.randint(x_low, x_high), random.randint(y_low, y_high)))
    return points


def _draw_label(draw, text, left, upper, font):
    """Draw *text* on a dark tag attached to the box corner at (left, upper).

    The tag sits just above the box. When that would place it above the top
    edge of the image (negative y), it is drawn inside the box instead so the
    label stays visible.
    """
    bbox = draw.textbbox((0, 0), text, font=font)
    text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]

    tag_top = upper - text_height - 6
    if tag_top < 0:
        tag_top = upper

    draw.rectangle(
        [(left, tag_top), (left + text_width + 6, tag_top + text_height + 6)],
        fill=(0, 0, 0, 200),
    )
    draw.text((left + 3, tag_top + 2), text, fill="white", font=font)


def predict_and_overlay(image, rows=2, cols=5, patch_size=224):
    """Classify randomly sampled patches of *image* and draw the results.

    Args:
        image: PIL image to analyse; it is converted to RGB internally.
        rows: Number of grid rows; one point is sampled per grid cell.
        cols: Number of grid columns.
        patch_size: Side length in pixels of the patch cropped around each
            sampled point. Patches are clipped at the image borders, so they
            may be smaller than this near an edge.

    Returns:
        A tuple ``(final_image, predictions)``: the RGB image with the
        overlay composited on top, and the list of predicted labels in
        sampling order (row by row).

    Raises:
        ValueError: If rows, cols or patch_size is smaller than 1.
    """
    # UI sliders may deliver floats; range() and randint() need ints.
    rows, cols, patch_size = int(rows), int(cols), int(patch_size)
    if rows < 1 or cols < 1 or patch_size < 1:
        raise ValueError("rows, cols and patch_size must be positive integers")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()

    image = image.convert("RGB")
    width, height = image.size

    # Scale text and border thickness with the longer image side.
    scale_factor = max(width, height) / 800
    font_size = max(12, int(scale_factor * 8))
    border_width = max(3, int(scale_factor * 3))
    font = _load_font(font_size)

    # Fully transparent layer that receives all drawings.
    overlay = Image.new("RGBA", image.size, (255, 255, 255, 0))
    overlay_draw = ImageDraw.Draw(overlay)

    predictions = []
    half_patch = patch_size // 2
    for x, y in _sample_points(width, height, rows, cols):
        left, upper = max(0, x - half_patch), max(0, y - half_patch)
        right, lower = min(width, left + patch_size), min(height, upper + patch_size)

        # Predict label
        patch = image.crop((left, upper, right, lower))
        inputs = processor(images=patch, return_tensors="pt").to(device)
        with torch.no_grad():
            pred_id = model(**inputs).logits.argmax(-1).item()
        pred_label = id2label.get(pred_id, "UNKNOWN")
        predictions.append(pred_label)

        # Draw filled rectangle and label
        fill_color, border_color = LABEL_COLORS.get(pred_label, _DEFAULT_COLORS)
        overlay_draw.rectangle(
            [(left, upper), (right, lower)],
            fill=fill_color,
            outline=border_color,
            width=border_width,
        )
        _draw_label(overlay_draw, pred_label, left, upper, font)

    # Merge overlay with original
    final_image = Image.alpha_composite(image.convert("RGBA"), overlay).convert("RGB")
    return final_image, predictions


def load_example_image(example_image):
    """Open the example image stored at path *example_image* with PIL."""
    return Image.open(example_image)


def _list_example_images(directory):
    """Return sorted paths of the image files found directly in *directory*.

    Returns an empty list when the directory does not exist, so the app can
    still start without bundled examples.
    """
    if not os.path.isdir(directory):
        return []
    return [
        os.path.join(directory, file_name)
        for file_name in sorted(os.listdir(directory))
        if file_name.lower().endswith(_IMAGE_EXTENSIONS)
    ]


# List example images
example_images = _list_example_images(EXAMPLE_IMAGES_DIR)


def gradio_interface(image, rows=2, cols=5, example_image=None):
    """Gradio callback: run the classifier and format its output for the UI.

    Args:
        image: Uploaded PIL image, or None when nothing was uploaded.
        rows: Number of grid rows of sample points.
        cols: Number of grid columns of sample points.
        example_image: Path of the example chosen in the dropdown; only used
            when no image was uploaded.

    Returns:
        A tuple of the annotated image and the predicted labels joined with
        ", ".

    Raises:
        gr.Error: If no image was uploaded and no example was selected.
    """
    # An uploaded image takes precedence over the selected example.
    if image is None and example_image:
        image = load_example_image(example_image)
    if image is None:
        raise gr.Error("Please upload an image or select an example image.")

    final_image, predictions = predict_and_overlay(image, rows, cols)
    return final_image, ", ".join(predictions)


# Gradio interface. The input order must match gradio_interface's positional
# parameters: image, rows, cols, example_image.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Upload Coral Image"),
        gr.Slider(minimum=1, maximum=10, value=2, step=1, label="Rows of Sample Points"),
        gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Columns of Sample Points"),
        gr.Dropdown(choices=example_images, label="Or Select an Example Image"),
    ],
    outputs=[
        gr.Image(type="pil", label="Image with Predictions"),
        gr.Textbox(label="Predictions"),
    ],
    title="NOAA ESD Coral Bleaching ViT Classifier",
    description=(
        "Upload an image or select an example to sample points/patches and "
        "predict coral bleaching using the ViT classifier model hosted on "
        "Hugging Face."
    ),
)

if __name__ == "__main__":
    iface.launch()