akridge commited on
Commit
5dedf78
Β·
verified Β·
1 Parent(s): 03c2ce2

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -39
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import torch
3
  from transformers import ViTForImageClassification, AutoImageProcessor
@@ -5,36 +6,36 @@ from PIL import Image, ImageDraw, ImageFont
5
  import random
6
  import os
7
 
8
- # Load model and processor from Hugging Face
9
  model_name = "akridge/noaa-esd-coral-bleaching-vit-classifier-v1"
10
  model = ViTForImageClassification.from_pretrained(model_name)
11
  processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
12
 
13
- # Ensure id2label keys are integers and labels are uppercase
14
  id2label = {int(k): v.upper() for k, v in model.config.id2label.items()}
15
 
16
- # Label colors (RGBA)
17
  LABEL_COLORS = {
18
  "CORAL": ((0, 0, 255, 80), (0, 0, 200)), # Fill: Blue transparent, Border: Dark blue
19
- "CORAL_BL": ((255, 255, 255, 150), (150, 150, 150)) # Fill: White transparent, Border: Gray
20
  }
21
 
22
  def predict_and_overlay(image, rows=2, cols=5, patch_size=224):
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  model.to(device).eval()
25
 
26
- # Load image
27
  image = image.convert('RGB')
28
  width, height = image.size
29
  scale_factor = max(width, height) / 800
30
  font_size = max(12, int(scale_factor * 8))
31
  border_width = max(3, int(scale_factor * 3))
32
 
33
- # Create overlay
34
  overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
35
  overlay_draw = ImageDraw.Draw(overlay)
36
 
37
- # Generate sampled points
38
  cell_width, cell_height = width / cols, height / rows
39
  sampled_points = [
40
  (random.randint(int(col * cell_width), int((col + 1) * cell_width - 1)),
@@ -44,18 +45,18 @@ def predict_and_overlay(image, rows=2, cols=5, patch_size=224):
44
 
45
  predictions = []
46
 
47
- # Load font
48
  try:
49
  font = ImageFont.truetype("arial.ttf", size=font_size)
50
  except IOError:
51
  font = ImageFont.load_default()
52
 
53
- # Predict and draw patches
54
  for x, y in sampled_points:
55
  left, upper = max(0, x - patch_size // 2), max(0, y - patch_size // 2)
56
  right, lower = min(width, left + patch_size), min(height, upper + patch_size)
57
 
58
- # Predict label
59
  patch = image.crop((left, upper, right, lower))
60
  inputs = processor(images=patch, return_tensors="pt").to(device)
61
  with torch.no_grad():
@@ -63,10 +64,10 @@ def predict_and_overlay(image, rows=2, cols=5, patch_size=224):
63
  pred_label = id2label.get(pred_id, "UNKNOWN")
64
  predictions.append(pred_label)
65
 
66
- # Fill and border colors
67
  fill_color, border_color = LABEL_COLORS.get(pred_label, ((200, 200, 200, 100), (100, 100, 100)))
68
 
69
- # Draw filled rectangle and label
70
  overlay_draw.rectangle([(left, upper), (right, lower)], fill=fill_color, outline=border_color, width=border_width)
71
 
72
  label_text = pred_label
@@ -77,37 +78,74 @@ def predict_and_overlay(image, rows=2, cols=5, patch_size=224):
77
  overlay_draw.rectangle(text_bg_coords, fill=(0, 0, 0, 200))
78
  overlay_draw.text((left + 3, upper - text_height - 4), label_text, fill="white", font=font)
79
 
80
- # Merge overlay with original
81
  final_image = Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')
82
 
83
- return final_image, predictions
84
-
85
- # Function to load example images
86
- def load_example_image(example_image):
87
- return Image.open(example_image)
88
 
89
- # List example images
90
- example_images = [os.path.join("example_images", img) for img in os.listdir("coral_images")]
91
 
92
- # Gradio interface
93
  def gradio_interface(image, rows=2, cols=5):
 
 
94
  final_image, predictions = predict_and_overlay(image, rows, cols)
95
- return final_image, ", ".join(predictions)
96
 
97
- iface = gr.Interface(
98
- fn=gradio_interface,
99
- inputs=[
100
- gr.inputs.Image(type="pil", label="Upload Coral Image", optional=True),
101
- gr.inputs.Dropdown(choices=example_images, label="Or Select an Example Image"),
102
- gr.inputs.Slider(1, 10, value=2, step=1, label="Rows of Sample Points"),
103
- gr.inputs.Slider(1, 10, value=5, step=1, label="Columns of Sample Points"),
104
- ],
105
- outputs=[
106
- gr.outputs.Image(type="pil", label="Image with Predictions"),
107
- gr.outputs.Textbox(label="Predictions")
108
- ],
109
- title="NOAA ESD Coral Bleaching ViT Classifier",
110
- description="Upload an image or select an example to sample points/patches and predict coral bleaching using the ViT classifier model hosted on Hugging Face."
111
- )
112
-
113
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
  import gradio as gr
3
  import torch
4
  from transformers import ViTForImageClassification, AutoImageProcessor
 
6
  import random
7
  import os
8
 
9
+ # βœ… Load model and processor from Hugging Face
10
  model_name = "akridge/noaa-esd-coral-bleaching-vit-classifier-v1"
11
  model = ViTForImageClassification.from_pretrained(model_name)
12
  processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
13
 
14
+ # πŸ—‚οΈ Ensure id2label keys are integers and labels are uppercase
15
  id2label = {int(k): v.upper() for k, v in model.config.id2label.items()}
16
 
17
+ # 🎨 Label colors (RGBA)
18
  LABEL_COLORS = {
19
  "CORAL": ((0, 0, 255, 80), (0, 0, 200)), # Fill: Blue transparent, Border: Dark blue
20
+ "CORAL_BL": ((255, 255, 255, 150), (150, 150, 150)), # Fill: White transparent, Border: Gray
21
  }
22
 
23
  def predict_and_overlay(image, rows=2, cols=5, patch_size=224):
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  model.to(device).eval()
26
 
27
+ # βœ… Load image
28
  image = image.convert('RGB')
29
  width, height = image.size
30
  scale_factor = max(width, height) / 800
31
  font_size = max(12, int(scale_factor * 8))
32
  border_width = max(3, int(scale_factor * 3))
33
 
34
+ # πŸ–ŒοΈ Create overlay
35
  overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
36
  overlay_draw = ImageDraw.Draw(overlay)
37
 
38
+ # βœ… Generate sampled points
39
  cell_width, cell_height = width / cols, height / rows
40
  sampled_points = [
41
  (random.randint(int(col * cell_width), int((col + 1) * cell_width - 1)),
 
45
 
46
  predictions = []
47
 
48
+ # βœ… Load font
49
  try:
50
  font = ImageFont.truetype("arial.ttf", size=font_size)
51
  except IOError:
52
  font = ImageFont.load_default()
53
 
54
+ # πŸ” Predict and draw patches
55
  for x, y in sampled_points:
56
  left, upper = max(0, x - patch_size // 2), max(0, y - patch_size // 2)
57
  right, lower = min(width, left + patch_size), min(height, upper + patch_size)
58
 
59
+ # 🧠 Predict label
60
  patch = image.crop((left, upper, right, lower))
61
  inputs = processor(images=patch, return_tensors="pt").to(device)
62
  with torch.no_grad():
 
64
  pred_label = id2label.get(pred_id, "UNKNOWN")
65
  predictions.append(pred_label)
66
 
67
+ # 🎨 Fill and border colors
68
  fill_color, border_color = LABEL_COLORS.get(pred_label, ((200, 200, 200, 100), (100, 100, 100)))
69
 
70
+ # 🟦 Draw filled rectangle and label
71
  overlay_draw.rectangle([(left, upper), (right, lower)], fill=fill_color, outline=border_color, width=border_width)
72
 
73
  label_text = pred_label
 
78
  overlay_draw.rectangle(text_bg_coords, fill=(0, 0, 0, 200))
79
  overlay_draw.text((left + 3, upper - text_height - 4), label_text, fill="white", font=font)
80
 
81
+ # πŸ–ΌοΈ Merge overlay with original
82
  final_image = Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')
83
 
84
+ return final_image, ", ".join(predictions)
 
 
 
 
85
 
86
+ # Load example images from coral_images folder
87
+ example_images = glob.glob("coral_images/*.[jp][pn]g")
88
 
89
+ # πŸš€ Gradio interface
90
  def gradio_interface(image, rows=2, cols=5):
91
+ if image is None:
92
+ return None, "No image uploaded. Please upload an image or select from examples."
93
  final_image, predictions = predict_and_overlay(image, rows, cols)
94
+ return final_image, predictions
95
 
96
+ app_title = "🌊 NOAA ESD Coral Bleaching ViT Classifier"
97
+ app_description = """
98
+ Upload a coral image or select from example images to sample points and predict coral bleaching using the ViT classifier model hosted on Hugging Face.
99
+
100
+ **Model:** [akridge/noaa-esd-coral-bleaching-vit-classifier-v1](https://huggingface.co/akridge/noaa-esd-coral-bleaching-vit-classifier-v1)
101
+ """
102
+
103
+ # Custom CSS for improved styling
104
+ custom_css = """
105
+ .gradio-container h1 {
106
+ font-size: 2.2em;
107
+ text-align: center;
108
+ }
109
+ .gradio-container p {
110
+ font-size: 1.2em;
111
+ }
112
+ .gradio-container .gr-button {
113
+ font-size: 1.2em;
114
+ }
115
+ """
116
+
117
+ with gr.Blocks(theme=gr.themes.Ocean(), css=custom_css, title=app_title) as interface:
118
+ gr.Markdown(f"<h1>{app_title}</h1>")
119
+ gr.Markdown(app_description)
120
+
121
+ with gr.Row():
122
+ image_input = gr.Image(type="pil", label="Upload Coral Image")
123
+ result_output = gr.Image(type="pil", label="Predicted Results")
124
+
125
+ with gr.Row():
126
+ rows_slider = gr.Slider(1, 10, value=2, label="Rows of Sample Points")
127
+ cols_slider = gr.Slider(1, 10, value=5, label="Columns of Sample Points")
128
+
129
+ with gr.Row():
130
+ run_button = gr.Button("Run Prediction", variant="primary")
131
+ clear_button = gr.Button("Clear")
132
+
133
+ # Add example images section
134
+ gr.Examples(
135
+ examples=[[img] for img in example_images],
136
+ inputs=[image_input],
137
+ outputs=[result_output],
138
+ examples_per_page=6,
139
+ label="Example Coral Images"
140
+ )
141
+
142
+ # Define button actions
143
+ run_button.click(
144
+ fn=gradio_interface,
145
+ inputs=[image_input, rows_slider, cols_slider],
146
+ outputs=[result_output, gr.Textbox(label="Predictions")]
147
+ )
148
+
149
+ clear_button.click(lambda: (None, ""), outputs=[image_input, result_output])
150
+
151
+ interface.launch()