# Source: Hugging Face file viewer - app.py (5.75 kB), uploaded by akridge,
# commit 5dedf78 (verified). Viewer navigation links (raw / history / blame) omitted.
import glob
import gradio as gr
import torch
from transformers import ViTForImageClassification, AutoImageProcessor
from PIL import Image, ImageDraw, ImageFont
import random
import os
# Load the classifier weights and the matching image processor from the
# Hugging Face Hub (downloaded on first use, then served from the local cache).
model_name = "akridge/noaa-esd-coral-bleaching-vit-classifier-v1"
model = ViTForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)

# Normalise the label map from the model config: keys become integers and
# label names are upper-cased so they line up with the LABEL_COLORS keys.
id2label = dict(
    (int(class_id), label_name.upper())
    for class_id, label_name in model.config.id2label.items()
)
# Overlay colours per predicted label, stored as (fill RGBA, border RGB).
LABEL_COLORS = {
    # Translucent blue fill with a dark-blue border.
    "CORAL": (
        (0, 0, 255, 80),
        (0, 0, 200),
    ),
    # Translucent white fill with a grey border.
    "CORAL_BL": (
        (255, 255, 255, 150),
        (150, 150, 150),
    ),
}
def _random_in_cell(index, cell_size, limit):
    """Return a random pixel coordinate inside grid cell *index* along one axis."""
    low = min(int(index * cell_size), limit - 1)
    # When a cell is narrower than one pixel the naive upper bound falls below
    # the lower bound and random.randint would raise; clamp it instead.
    high = min(max(low, int((index + 1) * cell_size - 1)), limit - 1)
    return random.randint(low, high)


def _sample_grid_points(width, height, rows, cols):
    """Return one random (x, y) point per grid cell, in row-major order."""
    cell_width, cell_height = width / cols, height / rows
    return [
        (
            _random_in_cell(col, cell_width, width),
            _random_in_cell(row, cell_height, height),
        )
        for row in range(rows)
        for col in range(cols)
    ]


def _patch_box(x, y, width, height, patch_size):
    """Return the (left, upper, right, lower) crop box for a patch around (x, y)."""
    half = patch_size // 2
    # Slide the window back inside the image at every edge so that patches keep
    # their full size whenever the image is at least patch_size wide/high.
    left = min(max(0, x - half), max(0, width - patch_size))
    upper = min(max(0, y - half), max(0, height - patch_size))
    right = min(width, left + patch_size)
    lower = min(height, upper + patch_size)
    return left, upper, right, lower


def _load_label_font(font_size):
    """Return a font of roughly *font_size* pixels, falling back gracefully."""
    for font_name in ("arial.ttf", "DejaVuSans.ttf"):
        try:
            return ImageFont.truetype(font_name, size=font_size)
        except OSError:
            continue
    try:
        # Pillow >= 10.1 can scale its built-in font; older versions reject
        # the ``size`` argument (TypeError) or lack FreeType support.
        return ImageFont.load_default(size=font_size)
    except (TypeError, OSError, ImportError):
        return ImageFont.load_default()


def _draw_label(draw, text, font, left, upper, image_width, padding=3):
    """Draw *text* on a dark banner attached to the top-left of a patch."""
    bbox = draw.textbbox((0, 0), text, font=font)
    text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
    banner_width = text_width + 2 * padding
    banner_height = text_height + 2 * padding

    # Keep the banner on the canvas: pull it left if it would overflow the
    # right edge, and move it inside the patch if it would go above the top.
    banner_left = max(0, min(left, image_width - banner_width))
    banner_top = upper - banner_height
    if banner_top < 0:
        banner_top = upper

    draw.rectangle(
        [
            (banner_left, banner_top),
            (banner_left + banner_width, banner_top + banner_height),
        ],
        fill=(0, 0, 0, 200),
    )
    # Subtract the bbox origin so the glyphs (not the text anchor) are padded.
    draw.text(
        (banner_left + padding - bbox[0], banner_top + padding - bbox[1]),
        text,
        fill="white",
        font=font,
    )


def predict_and_overlay(image, rows=2, cols=5, patch_size=224):
    """Classify randomly sampled patches of *image* and draw the results on it.

    The image is split into a ``rows`` x ``cols`` grid and one random point is
    picked inside every cell. A square patch of ``patch_size`` pixels around
    each point is cropped and classified with the module-level ``model`` and
    ``processor``. Each patch is drawn as a translucent rectangle coloured via
    ``LABEL_COLORS`` together with its predicted label.

    Args:
        image: PIL image to analyse; it is converted to RGB internally.
        rows: Number of grid rows. Converted with ``int()`` so float values
            coming from UI sliders are accepted.
        cols: Number of grid columns. Converted with ``int()``.
        patch_size: Edge length of each sampled patch in pixels.

    Returns:
        A tuple ``(annotated_image, labels)`` where ``annotated_image`` is an
        RGB PIL image and ``labels`` is a comma-separated string of the
        predicted labels in row-major grid order.

    Raises:
        ValueError: If ``rows``, ``cols`` or ``patch_size`` is smaller than 1.
    """
    rows, cols, patch_size = int(rows), int(cols), int(patch_size)
    if rows < 1 or cols < 1 or patch_size < 1:
        raise ValueError("rows, cols and patch_size must all be >= 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()

    image = image.convert("RGB")
    width, height = image.size

    # Scale text and borders with the image so they stay legible on large photos.
    scale_factor = max(width, height) / 800
    font_size = max(12, int(scale_factor * 8))
    border_width = max(3, int(scale_factor * 3))
    font = _load_label_font(font_size)

    boxes = [
        _patch_box(x, y, width, height, patch_size)
        for x, y in _sample_grid_points(width, height, rows, cols)
    ]

    # Classify every patch in a single batched forward pass; the processor
    # resizes each crop to the model's input size, so crops may differ in size.
    patches = [image.crop(box) for box in boxes]
    inputs = processor(images=patches, return_tensors="pt").to(device)
    with torch.no_grad():
        pred_ids = model(**inputs).logits.argmax(-1).tolist()
    predictions = [id2label.get(pred_id, "UNKNOWN") for pred_id in pred_ids]

    # Draw on a transparent layer so the fills blend with the photo underneath.
    overlay = Image.new("RGBA", image.size, (255, 255, 255, 0))
    overlay_draw = ImageDraw.Draw(overlay)
    for (left, upper, right, lower), pred_label in zip(boxes, predictions):
        fill_color, border_color = LABEL_COLORS.get(
            pred_label, ((200, 200, 200, 100), (100, 100, 100))
        )
        overlay_draw.rectangle(
            [(left, upper), (right, lower)],
            fill=fill_color,
            outline=border_color,
            width=border_width,
        )
        _draw_label(overlay_draw, pred_label, font, left, upper, width)

    final_image = Image.alpha_composite(image.convert("RGBA"), overlay).convert("RGB")
    return final_image, ", ".join(predictions)
# Collect example images from the coral_images folder. Extensions are matched
# case-insensitively (including .jpeg) and the result is sorted so the example
# gallery order does not depend on filesystem listing order.
example_images = sorted(
    path
    for path in glob.glob("coral_images/*")
    if path.lower().endswith((".jpg", ".jpeg", ".png"))
)
# Gradio callback
def gradio_interface(image, rows=2, cols=5):
    """Run the patch classifier for the UI and return its outputs.

    Args:
        image: PIL image supplied by the image input, or ``None`` when nothing
            has been uploaded or selected.
        rows: Number of grid rows of sample points.
        cols: Number of grid columns of sample points.

    Returns:
        A tuple ``(annotated_image, text)``. When ``image`` is ``None`` the
        image slot is ``None`` and ``text`` is a message asking for an upload;
        otherwise it is the result of ``predict_and_overlay``.
    """
    if image is None:
        return None, "No image uploaded. Please upload an image or select from examples."
    # Slider values may arrive as floats; the sampling grid needs integers.
    return predict_and_overlay(image, int(rows), int(cols))
# Page title; also passed to gr.Blocks as the browser-tab title.
# The leading wave emoji was mis-decoded (mojibake) and is restored here.
app_title = "🌊 NOAA ESD Coral Bleaching ViT Classifier"

# Markdown shown under the title.
app_description = """
Upload a coral image or select from example images to sample points and predict coral bleaching using the ViT classifier model hosted on Hugging Face.
**Model:** [akridge/noaa-esd-coral-bleaching-vit-classifier-v1](https://huggingface.co/akridge/noaa-esd-coral-bleaching-vit-classifier-v1)
"""

# Custom CSS for improved styling (larger heading, body text and buttons).
custom_css = """
.gradio-container h1 {
font-size: 2.2em;
text-align: center;
}
.gradio-container p {
font-size: 1.2em;
}
.gradio-container .gr-button {
font-size: 1.2em;
}
"""
# Assemble the Gradio Blocks UI: input/output images, grid sliders, action
# buttons, optional example gallery and the predictions text box.
with gr.Blocks(theme=gr.themes.Ocean(), css=custom_css, title=app_title) as interface:
    gr.Markdown(f"<h1>{app_title}</h1>")
    gr.Markdown(app_description)

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Coral Image")
        result_output = gr.Image(type="pil", label="Predicted Results")

    with gr.Row():
        # step=1 restricts the sliders to whole numbers; without it Gradio
        # derives a fractional step and the grid size arrives as e.g. 2.53.
        rows_slider = gr.Slider(1, 10, value=2, step=1, label="Rows of Sample Points")
        cols_slider = gr.Slider(1, 10, value=5, step=1, label="Columns of Sample Points")

    with gr.Row():
        run_button = gr.Button("Run Prediction", variant="primary")
        clear_button = gr.Button("Clear")

    # Example images section; skipped when the folder has no usable images.
    if example_images:
        gr.Examples(
            examples=[[img] for img in example_images],
            inputs=[image_input],
            outputs=[result_output],
            examples_per_page=6,
            label="Example Coral Images",
        )

    # Named component so both the run and the clear handlers can target it.
    predictions_output = gr.Textbox(label="Predictions")

    # Button actions
    run_button.click(
        fn=gradio_interface,
        inputs=[image_input, rows_slider, cols_slider],
        outputs=[result_output, predictions_output],
    )
    # Reset both images to None (an empty string is not a valid image value)
    # and blank the predictions text.
    clear_button.click(
        fn=lambda: (None, None, ""),
        outputs=[image_input, result_output, predictions_output],
    )

interface.launch()