akridge commited on
Commit
96309db
·
verified ·
1 Parent(s): 4de2b8d

Upload 5 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ coral_images/00_example.png filter=lfs diff=lfs merge=lfs -text
37
+ coral_images/01_example.png filter=lfs diff=lfs merge=lfs -text
38
+ coral_images/02_example.png filter=lfs diff=lfs merge=lfs -text
39
+ coral_images/03_example.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import ViTForImageClassification, AutoImageProcessor
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import random
6
+ import os
7
+
8
+ # Load model and processor from Hugging Face
9
+ model_name = "akridge/noaa-esd-coral-bleaching-vit-classifier-v1"
10
+ model = ViTForImageClassification.from_pretrained(model_name)
11
+ processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
12
+
13
+ # Ensure id2label keys are integers and labels are uppercase
14
+ id2label = {int(k): v.upper() for k, v in model.config.id2label.items()}
15
+
16
+ # Label colors (RGBA)
17
+ LABEL_COLORS = {
18
+ "CORAL": ((0, 0, 255, 80), (0, 0, 200)), # Fill: Blue transparent, Border: Dark blue
19
+ "CORAL_BL": ((255, 255, 255, 150), (150, 150, 150)) # Fill: White transparent, Border: Gray
20
+ }
21
+
22
+ def predict_and_overlay(image, rows=2, cols=5, patch_size=224):
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ model.to(device).eval()
25
+
26
+ # Load image
27
+ image = image.convert('RGB')
28
+ width, height = image.size
29
+ scale_factor = max(width, height) / 800
30
+ font_size = max(12, int(scale_factor * 8))
31
+ border_width = max(3, int(scale_factor * 3))
32
+
33
+ # Create overlay
34
+ overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
35
+ overlay_draw = ImageDraw.Draw(overlay)
36
+
37
+ # Generate sampled points
38
+ cell_width, cell_height = width / cols, height / rows
39
+ sampled_points = [
40
+ (random.randint(int(col * cell_width), int((col + 1) * cell_width - 1)),
41
+ random.randint(int(row * cell_height), int((row + 1) * cell_height - 1)))
42
+ for row in range(rows) for col in range(cols)
43
+ ]
44
+
45
+ predictions = []
46
+
47
+ # Load font
48
+ try:
49
+ font = ImageFont.truetype("arial.ttf", size=font_size)
50
+ except IOError:
51
+ font = ImageFont.load_default()
52
+
53
+ # Predict and draw patches
54
+ for x, y in sampled_points:
55
+ left, upper = max(0, x - patch_size // 2), max(0, y - patch_size // 2)
56
+ right, lower = min(width, left + patch_size), min(height, upper + patch_size)
57
+
58
+ # Predict label
59
+ patch = image.crop((left, upper, right, lower))
60
+ inputs = processor(images=patch, return_tensors="pt").to(device)
61
+ with torch.no_grad():
62
+ pred_id = model(**inputs).logits.argmax(-1).item()
63
+ pred_label = id2label.get(pred_id, "UNKNOWN")
64
+ predictions.append(pred_label)
65
+
66
+ # Fill and border colors
67
+ fill_color, border_color = LABEL_COLORS.get(pred_label, ((200, 200, 200, 100), (100, 100, 100)))
68
+
69
+ # Draw filled rectangle and label
70
+ overlay_draw.rectangle([(left, upper), (right, lower)], fill=fill_color, outline=border_color, width=border_width)
71
+
72
+ label_text = pred_label
73
+ bbox = overlay_draw.textbbox((0, 0), label_text, font=font)
74
+ text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
75
+ text_bg_coords = [(left, upper - text_height - 6), (left + text_width + 6, upper)]
76
+
77
+ overlay_draw.rectangle(text_bg_coords, fill=(0, 0, 0, 200))
78
+ overlay_draw.text((left + 3, upper - text_height - 4), label_text, fill="white", font=font)
79
+
80
+ # Merge overlay with original
81
+ final_image = Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')
82
+
83
+ return final_image, predictions
84
+
85
+ # Function to load example images
86
+ def load_example_image(example_image):
87
+ return Image.open(example_image)
88
+
89
+ # List example images
90
+ example_images = [os.path.join("example_images", img) for img in os.listdir("coral_images")]
91
+
92
+ # Gradio interface
93
+ def gradio_interface(image, rows=2, cols=5):
94
+ final_image, predictions = predict_and_overlay(image, rows, cols)
95
+ return final_image, ", ".join(predictions)
96
+
97
+ iface = gr.Interface(
98
+ fn=gradio_interface,
99
+ inputs=[
100
+ gr.inputs.Image(type="pil", label="Upload Coral Image", optional=True),
101
+ gr.inputs.Dropdown(choices=example_images, label="Or Select an Example Image"),
102
+ gr.inputs.Slider(1, 10, value=2, step=1, label="Rows of Sample Points"),
103
+ gr.inputs.Slider(1, 10, value=5, step=1, label="Columns of Sample Points"),
104
+ ],
105
+ outputs=[
106
+ gr.outputs.Image(type="pil", label="Image with Predictions"),
107
+ gr.outputs.Textbox(label="Predictions")
108
+ ],
109
+ title="NOAA ESD Coral Bleaching ViT Classifier",
110
+ description="Upload an image or select an example to sample points/patches and predict coral bleaching using the ViT classifier model hosted on Hugging Face."
111
+ )
112
+
113
+ iface.launch()
coral_images/00_example.png ADDED

Git LFS Details

  • SHA256: 54f2e33bc5b26e5d3be1595b36d4ac0c10add609d34fe938cd89631eb43f8feb
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
coral_images/01_example.png ADDED

Git LFS Details

  • SHA256: 17efe0c0e6f03e9c062c0b253f86c0fdf071beeaa31b8f2ef2d1bbf31b4b1b8e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.26 MB
coral_images/02_example.png ADDED

Git LFS Details

  • SHA256: 0bb0d2648b71f1ec56930e3110949c60bafe7b010e705da863570a6677976cc8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.42 MB
coral_images/03_example.png ADDED

Git LFS Details

  • SHA256: 07d9e9e68ceab27f4289df8647437ff3202348af648d73695158dd6081a36877
  • Pointer size: 132 Bytes
  • Size of remote file: 1.38 MB