truthdotphd commited on
Commit
8a539dd
·
verified ·
1 Parent(s): df6a791

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +579 -273
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import psutil
2
  import gradio as gr
3
  import numpy as np
@@ -9,347 +11,651 @@ import rasterio
9
  from rasterio.enums import Resampling
10
  from rasterio.plot import reshape_as_image
11
  import sys
 
12
 
13
- # Download the entire repository to a subdirectory
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  repo_id = "truthdotphd/cloud-detection"
15
  repo_subdir = "."
16
- repo_dir = snapshot_download(repo_id=repo_id, local_dir=repo_subdir)
 
 
17
 
18
- # Add the repository directory to the Python path
19
  sys.path.append(repo_dir)
20
 
21
- # Import the necessary functions from the downloaded modules
22
  try:
 
 
 
 
23
  from omnicloudmask import predict_from_array
24
- except ImportError:
25
- omnicloudmask_dir = os.path.join(repo_dir, "omnicloudmask")
26
- if os.path.exists(omnicloudmask_dir):
27
- sys.path.append(omnicloudmask_dir)
28
- from omnicloudmask import predict_from_array
29
- else:
30
- raise ImportError("Could not find the omnicloudmask module in the downloaded repository")
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- def visualize_rgb(red_file, green_file, blue_file, nir_file):
33
  """
34
  Create and display an RGB visualization immediately after images are uploaded.
 
35
  """
36
- if not all([red_file, green_file, blue_file, nir_file]):
37
  return None
38
-
39
  try:
40
- # Get dimensions from red band to use for resampling
 
41
  with rasterio.open(red_file) as src:
42
  target_height = src.height
43
  target_width = src.width
44
-
45
- # Load bands
46
  blue_data = load_band(blue_file)
47
  green_data = load_band(green_file)
48
  red_data = load_band(red_file)
49
-
50
- # Compute max values for each channel for dynamic normalization
51
- red_max = np.max(red_data)
52
- green_max = np.max(green_data)
53
- blue_max = np.max(blue_data)
54
-
55
  # Create RGB image for visualization with dynamic normalization
56
  rgb_image = np.zeros((red_data.shape[0], red_data.shape[1], 3), dtype=np.float32)
57
-
58
- # Normalize each channel individually
59
  epsilon = 1e-10
60
- rgb_image[:, :, 0] = red_data / (red_max + epsilon)
61
- rgb_image[:, :, 1] = green_data / (green_max + epsilon)
62
- rgb_image[:, :, 2] = blue_data / (blue_max + epsilon)
63
-
64
- # Clip values to 0-1 range
65
- rgb_image = np.clip(rgb_image, 0, 1)
66
-
67
- # Apply contrast enhancement for better visualization
68
- p2 = np.percentile(rgb_image, 2)
69
- p98 = np.percentile(rgb_image, 98)
70
- rgb_image_enhanced = np.clip((rgb_image - p2) / (p98 - p2), 0, 1)
71
-
72
  # Convert to uint8 for display
73
  rgb_display = (rgb_image_enhanced * 255).astype(np.uint8)
74
-
75
  return rgb_display
76
  except Exception as e:
77
  print(f"Error generating RGB preview: {e}")
 
 
78
  return None
79
 
80
 
81
  def visualize_jp2(file_path):
82
  """
83
- Visualize a single JP2 file.
84
  """
85
- with rasterio.open(file_path) as src:
86
- # Read the data
87
- data = src.read(1)
88
-
89
- # Normalize the data for visualization
90
- data = (data - np.min(data)) / (np.max(data) - np.min(data))
91
-
92
- # Apply a colormap for better visualization
93
- cmap = plt.get_cmap('viridis')
94
- colored_image = cmap(data)
95
-
96
- # Convert to 8-bit for display
97
- return (colored_image[:, :, :3] * 255).astype(np.uint8)
 
 
 
 
 
 
 
 
98
 
99
  def load_band(file_path, resample=False, target_height=None, target_width=None):
100
  """
101
- Load a single band from a raster file with optional resampling.
102
  """
103
- with rasterio.open(file_path) as src:
104
- if resample and target_height is not None and target_width is not None:
105
- band_data = src.read(
106
- out_shape=(src.count, target_height, target_width),
107
- resampling=Resampling.bilinear
108
- )[0].astype(np.float32)
109
- else:
110
- band_data = src.read()[0].astype(np.float32)
111
-
112
- return band_data
 
 
 
 
 
 
113
 
114
  def prepare_input_array(red_file, green_file, blue_file, nir_file):
115
  """
116
- Prepare a stacked array of satellite bands for cloud mask prediction.
 
 
 
 
117
  """
118
- # Get dimensions from red band to use for resampling
119
- with rasterio.open(red_file) as src:
120
- target_height = src.height
121
- target_width = src.width
122
-
123
- # Load bands (resample NIR band to match 10m resolution)
124
- blue_data = load_band(blue_file)
125
- green_data = load_band(green_file)
126
- red_data = load_band(red_file)
127
- nir_data = load_band(
128
- nir_file,
129
- resample=True,
130
- target_height=target_height,
131
- target_width=target_width
132
- )
133
-
134
- # Print band shapes for debugging
135
- print(f"Band shapes - Blue: {blue_data.shape}, Green: {green_data.shape}, Red: {red_data.shape}, NIR: {nir_data.shape}")
136
-
137
- # Compute max values for each channel for dynamic normalization
138
- red_max = np.max(red_data)
139
- green_max = np.max(green_data)
140
- blue_max = np.max(blue_data)
141
-
142
- print(f"Max values - Red: {red_max}, Green: {green_max}, Blue: {blue_max}")
143
-
144
- # Create RGB image for visualization with dynamic normalization
145
- rgb_image = np.zeros((red_data.shape[0], red_data.shape[1], 3), dtype=np.float32)
146
-
147
- # Normalize each channel individually
148
- # Add a small epsilon to avoid division by zero
149
- epsilon = 1e-10
150
- rgb_image[:, :, 0] = red_data / (red_max + epsilon)
151
- rgb_image[:, :, 1] = green_data / (green_max + epsilon)
152
- rgb_image[:, :, 2] = blue_data / (blue_max + epsilon)
153
-
154
- # Clip values to 0-1 range
155
- rgb_image = np.clip(rgb_image, 0, 1)
156
-
157
- # Optional: Apply contrast enhancement for better visualization
158
- p2 = np.percentile(rgb_image, 2)
159
- p98 = np.percentile(rgb_image, 98)
160
- rgb_image_enhanced = np.clip((rgb_image - p2) / (p98 - p2), 0, 1)
161
-
162
- # Stack bands in CHW format for cloud mask prediction (red, green, nir)
163
- prediction_array = np.stack([red_data, green_data, nir_data], axis=0)
164
-
165
- return prediction_array, rgb_image_enhanced
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  def visualize_cloud_mask(rgb_image, pred_mask):
169
  """
170
  Create a visualization of the cloud mask overlaid on the RGB image.
 
171
  """
172
- # Ensure pred_mask has the right dimensions
173
- if pred_mask.ndim > 2:
174
- pred_mask = np.squeeze(pred_mask)
175
-
176
- print(f"RGB image shape: {rgb_image.shape}, Pred mask shape: {pred_mask.shape}")
177
-
178
- # Ensure mask has the same spatial dimensions as the image
179
- if pred_mask.shape != rgb_image.shape[:2]:
180
- pred_mask = cv2.resize(
181
- pred_mask.astype(np.float32),
182
- (rgb_image.shape[1], rgb_image.shape[0]),
183
- interpolation=cv2.INTER_NEAREST
184
- ).astype(np.uint8)
185
- print(f"Resized mask shape: {pred_mask.shape}")
186
-
187
- # Define colors for each class
188
- colors = {
189
- 0: [0, 255, 0], # Clear - Green
190
- 1: [255, 255, 255], # Thick Cloud - White
191
- 2: [200, 200, 200], # Thin Cloud - Light Gray
192
- 3: [100, 100, 100] # Cloud Shadow - Dark Gray
193
- }
194
-
195
- # Create a color-coded mask
196
- mask_vis = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
197
- for class_idx, color in colors.items():
198
- mask_vis[pred_mask == class_idx] = color
199
-
200
- # Create a blended visualization
201
- alpha = 0.5
202
- blended = cv2.addWeighted((rgb_image * 255).astype(np.uint8), 1-alpha, mask_vis, alpha, 0)
203
-
204
- # Get the width of the blended image for the legend
205
- image_width = blended.shape[1]
206
-
207
- # Create a legend with the same width as the image
208
- legend = np.ones((100, image_width, 3), dtype=np.uint8) * 255
209
- legend_text = ["Clear", "Thick Cloud", "Thin Cloud", "Cloud Shadow"]
210
- legend_colors = [colors[i] for i in range(4)]
211
-
212
- for i, (text, color) in enumerate(zip(legend_text, legend_colors)):
213
- cv2.rectangle(legend, (10, 10 + i*20), (30, 30 + i*20), color, -1)
214
- cv2.putText(legend, text, (40, 25 + i*20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
215
-
216
- # Combine image and legend
217
- final_output = np.vstack([blended, legend])
218
-
219
- return final_output
220
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  def process_satellite_images(red_file, green_file, blue_file, nir_file, batch_size, patch_size, patch_overlap):
222
  """
223
- Process the satellite images and detect clouds.
224
  """
225
  if not all([red_file, green_file, blue_file, nir_file]):
226
- return None, None, "Please upload all four channel files (Red, Green, Blue, NIR)"
227
-
228
- # Prepare input array and RGB image for visualization
229
- input_array, rgb_image = prepare_input_array(red_file, green_file, blue_file, nir_file)
230
-
231
- # Convert RGB image to format suitable for display
232
- rgb_display = (rgb_image * 255).astype(np.uint8)
233
-
234
- # Predict cloud mask using omnicloudmask
235
- pred_mask = predict_from_array(
236
- input_array,
237
- batch_size=batch_size,
238
- patch_size=patch_size,
239
- patch_overlap=patch_overlap
240
- )
241
-
242
- # Calculate class distribution
243
- if pred_mask.ndim > 2:
244
- flat_mask = np.squeeze(pred_mask)
 
 
 
 
245
  else:
246
- flat_mask = pred_mask
247
-
248
- clear_pixels = np.sum(flat_mask == 0)
249
- thick_cloud_pixels = np.sum(flat_mask == 1)
250
- thin_cloud_pixels = np.sum(flat_mask == 2)
251
- cloud_shadow_pixels = np.sum(flat_mask == 3)
252
- total_pixels = flat_mask.size
253
-
254
- stats = f"""
255
- Cloud Mask Statistics:
256
- - Clear: {clear_pixels} pixels ({clear_pixels/total_pixels*100:.2f}%)
257
- - Thick Cloud: {thick_cloud_pixels} pixels ({thick_cloud_pixels/total_pixels*100:.2f}%)
258
- - Thin Cloud: {thin_cloud_pixels} pixels ({thin_cloud_pixels/total_pixels*100:.2f}%)
259
- - Cloud Shadow: {cloud_shadow_pixels} pixels ({cloud_shadow_pixels/total_pixels*100:.2f}%)
260
- - Total Cloud Cover: {(thick_cloud_pixels + thin_cloud_pixels)/total_pixels*100:.2f}%
261
- """
262
-
263
- # Visualize the cloud mask on the original image
264
- visualization = visualize_cloud_mask(rgb_image, flat_mask)
265
-
266
- return rgb_display, visualization, stats
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
- def update_cpu():
269
- return f"CPU Usage: {psutil.cpu_percent()}%"
270
 
271
- with gr.Blocks() as demo:
272
- cpu_text = gr.Textbox(label="CPU Usage")
273
- check_cpu_btn = gr.Button("Check CPU")
274
-
275
- # Attach the event handler using the click method
276
- check_cpu_btn.click(fn=update_cpu, inputs=None, outputs=cpu_text)
277
-
278
 
279
- # Define the CPU check function
280
  def check_cpu_usage():
281
  """Check and return the current CPU usage."""
282
  return f"CPU Usage: {psutil.cpu_percent()}%"
283
 
284
- # Create the Gradio application with Blocks
285
- with gr.Blocks(title="Satellite Cloud Detection") as demo:
286
- # Add the description
287
  gr.Markdown("""
288
- # Satellite Cloud Detection
289
-
290
- Upload separate JP2 files for Red, Green, Blue, and NIR channels to detect clouds in satellite imagery.
291
-
292
- This application uses the OmniCloudMask model to classify each pixel as:
293
- - Clear (0)
294
- - Thick Cloud (1)
295
- - Thin Cloud (2)
296
- - Cloud Shadow (3)
297
-
298
- The model works best with imagery at 10-50m resolution. For higher resolution imagery, downsampling is recommended.
 
 
299
  """)
300
-
301
  # Main cloud detection interface
302
  with gr.Row():
303
- with gr.Column():
304
- # Input components
305
- red_input = gr.Image(type="filepath", label="Red Channel (JP2)")
306
- green_input = gr.Image(type="filepath", label="Green Channel (JP2)")
307
- blue_input = gr.Image(type="filepath", label="Blue Channel (JP2)")
308
- nir_input = gr.Image(type="filepath", label="NIR Channel (JP2)")
309
-
310
- batch_size = gr.Slider(minimum=1, maximum=32, value=1, step=1,
311
- label="Batch Size",
312
- info="Higher values use more memory but process faster")
313
- patch_size = gr.Slider(minimum=500, maximum=2000, value=1000, step=100,
314
- label="Patch Size",
315
- info="Size of image patches for processing")
316
- patch_overlap = gr.Slider(minimum=100, maximum=500, value=300, step=50,
317
- label="Patch Overlap",
318
- info="Overlap between patches to avoid edge artifacts")
319
-
320
- process_btn = gr.Button("Process Cloud Detection")
321
-
322
- with gr.Column():
 
 
 
323
  # Output components
324
- rgb_output = gr.Image(label="Original RGB Image")
325
- cloud_output = gr.Image(label="Cloud Detection Visualization")
326
- stats_output = gr.Textbox(label="Statistics")
327
-
328
- # CPU usage monitoring section
329
- with gr.Row():
330
- with gr.Column():
331
- gr.Markdown("## System Monitoring")
332
- cpu_button = gr.Button("Check CPU Usage")
333
- cpu_output = gr.Textbox(label="CPU Usage")
334
-
335
- # Set up event handlers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  process_btn.click(
337
  fn=process_satellite_images,
338
  inputs=[red_input, green_input, blue_input, nir_input, batch_size, patch_size, patch_overlap],
339
  outputs=[rgb_output, cloud_output, stats_output]
340
  )
341
-
342
- cpu_button.click(
343
- fn=check_cpu_usage,
344
- inputs=None,
345
- outputs=cpu_output
346
- )
347
-
348
- # Add examples
349
- gr.Examples(
350
- examples=[["jp2s/B04.jp2", "jp2s/B03.jp2", "jp2s/B02.jp2", "jp2s/B8A.jp2", 1, 1000, 300]],
351
- inputs=[red_input, green_input, blue_input, nir_input, batch_size, patch_size, patch_overlap]
352
- )
353
 
354
- # Launch the app
355
- demo.queue(default_concurrency_limit=8).launch(debug=True)
 
 
 
1
+ # Gradio App Code (based on paste.txt) with Triton Integration and Fallback
2
+
3
  import psutil
4
  import gradio as gr
5
  import numpy as np
 
11
  from rasterio.enums import Resampling
12
  from rasterio.plot import reshape_as_image
13
  import sys
14
+ import time # For potential timeouts/delays
15
 
16
+ # --- Triton Client Imports ---
17
+ try:
18
+ import tritonclient.http as httpclient
19
+ import tritonclient.utils as triton_utils # For InferenceServerException
20
+ TRITON_CLIENT_AVAILABLE = True
21
+ except ImportError:
22
+ print("WARNING: tritonclient is not installed. Triton inference will not be available.")
23
+ print("Install using: pip install tritonclient[all]")
24
+ TRITON_CLIENT_AVAILABLE = False
25
+ httpclient = None # Define dummy to avoid NameErrors later
26
+ triton_utils = None
27
+
28
+ # --- Configuration ---
29
+ # Download the entire repository for local fallback and utils
30
  repo_id = "truthdotphd/cloud-detection"
31
  repo_subdir = "."
32
+ print(f"Downloading/Checking Hugging Face repo '{repo_id}'...")
33
+ repo_dir = snapshot_download(repo_id=repo_id, local_dir=repo_subdir, local_dir_use_symlinks=False) # Use False for symlinks in Gradio/Docker usually
34
+ print(f"Repo downloaded/cached at: {repo_dir}")
35
 
36
+ # Add the repository directory to the Python path for local modules
37
  sys.path.append(repo_dir)
38
 
39
+ # Import the necessary functions from the downloaded modules for LOCAL fallback
40
  try:
41
+ # Adjust path if omnicloudmask is inside a subfolder
42
+ omnicloudmask_path = os.path.join(repo_dir, "omnicloudmask")
43
+ if os.path.isdir(omnicloudmask_path):
44
+ sys.path.append(omnicloudmask_path) # Add subfolder if exists
45
  from omnicloudmask import predict_from_array
46
+ LOCAL_MODEL_AVAILABLE = True
47
+ print("Local omnicloudmask module loaded successfully.")
48
+ except ImportError as e:
49
+ print(f"ERROR: Could not import local 'predict_from_array' from omnicloudmask: {e}")
50
+ print("Local fallback will not be available.")
51
+ LOCAL_MODEL_AVAILABLE = False
52
+ predict_from_array = None # Define dummy
53
+
54
+ # --- Triton Server Configuration ---
55
+ TRITON_IP = "206.123.129.87" # Use the public IP provided
56
+ HTTP_TRITON_URL = f"{TRITON_IP}:8000"
57
+ # GRPC_TRITON_URL = f"{TRITON_IP}:8001" # Keep for potential future use
58
+ TRITON_MODEL_NAME = "cloud-detection" # Ensure this matches your deployed model name
59
+ TRITON_INPUT_NAME = "input_jp2_bytes" # Ensure this matches your model's config.pbtxt
60
+ TRITON_OUTPUT_NAME = "output_mask" # Ensure this matches your model's config.pbtxt
61
+ TRITON_TIMEOUT_SECONDS = 300.0 # 5 minutes timeout for connection/network
62
+
63
+
64
+ # --- Utility Functions (mostly from paste.txt) ---
65
 
66
+ def visualize_rgb(red_file, green_file, blue_file):
67
  """
68
  Create and display an RGB visualization immediately after images are uploaded.
69
+ (Modified slightly: doesn't need nir_file)
70
  """
71
+ if not all([red_file, green_file, blue_file]):
72
  return None
73
+
74
  try:
75
+ # Load bands (using load_band utility)
76
+ # Get target shape from red band
77
  with rasterio.open(red_file) as src:
78
  target_height = src.height
79
  target_width = src.width
80
+
 
81
  blue_data = load_band(blue_file)
82
  green_data = load_band(green_file)
83
  red_data = load_band(red_file)
84
+
85
+ # Compute max values for scaling (simple approach)
86
+ red_max = np.percentile(red_data[red_data>0], 98) if np.any(red_data>0) else 1.0
87
+ green_max = np.percentile(green_data[green_data>0], 98) if np.any(green_data>0) else 1.0
88
+ blue_max = np.percentile(blue_data[blue_data>0], 98) if np.any(blue_data>0) else 1.0
89
+
90
  # Create RGB image for visualization with dynamic normalization
91
  rgb_image = np.zeros((red_data.shape[0], red_data.shape[1], 3), dtype=np.float32)
92
+
 
93
  epsilon = 1e-10
94
+ rgb_image[:, :, 0] = np.clip(red_data / (red_max + epsilon), 0, 1)
95
+ rgb_image[:, :, 1] = np.clip(green_data / (green_max + epsilon), 0, 1)
96
+ rgb_image[:, :, 2] = np.clip(blue_data / (blue_max + epsilon), 0, 1)
97
+
98
+ # Simple brightness/contrast adjustment (gamma correction)
99
+ gamma = 1.8
100
+ rgb_image_enhanced = np.power(rgb_image, 1/gamma)
101
+
 
 
 
 
102
  # Convert to uint8 for display
103
  rgb_display = (rgb_image_enhanced * 255).astype(np.uint8)
104
+
105
  return rgb_display
106
  except Exception as e:
107
  print(f"Error generating RGB preview: {e}")
108
+ import traceback
109
+ traceback.print_exc()
110
  return None
111
 
112
 
113
  def visualize_jp2(file_path):
114
  """
115
+ Visualize a single JP2 file. (Unchanged from paste.txt)
116
  """
117
+ try:
118
+ with rasterio.open(file_path) as src:
119
+ data = src.read(1)
120
+ # Check if data is all zero or invalid
121
+ if np.all(data == 0) or np.ptp(data) == 0:
122
+ print(f"Warning: Data in {file_path} is constant or zero. Cannot normalize.")
123
+ # Return a black image or handle as appropriate
124
+ return np.zeros((src.height, src.width, 3), dtype=np.uint8)
125
+
126
+ # Normalize the data for visualization
127
+ data_norm = (data - np.min(data)) / (np.max(data) - np.min(data))
128
+
129
+ # Apply a colormap for better visualization
130
+ cmap = plt.get_cmap('viridis')
131
+ colored_image = cmap(data_norm)
132
+
133
+ # Convert to 8-bit for display
134
+ return (colored_image[:, :, :3] * 255).astype(np.uint8)
135
+ except Exception as e:
136
+ print(f"Error visualizing JP2 file {file_path}: {e}")
137
+ return None
138
 
139
  def load_band(file_path, resample=False, target_height=None, target_width=None):
140
  """
141
+ Load a single band from a raster file with optional resampling. (Unchanged from paste.txt)
142
  """
143
+ try:
144
+ with rasterio.open(file_path) as src:
145
+ if resample and target_height is not None and target_width is not None:
146
+ # Ensure output shape matches target channels (1 for single band)
147
+ out_shape = (1, target_height, target_width)
148
+ band_data = src.read(
149
+ out_shape=out_shape,
150
+ resampling=Resampling.bilinear
151
+ )[0].astype(np.float32) # Read only the first band after resampling
152
+ else:
153
+ band_data = src.read(1).astype(np.float32) # Read only the first band
154
+
155
+ return band_data
156
+ except Exception as e:
157
+ print(f"Error loading band {file_path}: {e}")
158
+ raise # Re-raise error to be caught by calling function
159
 
160
  def prepare_input_array(red_file, green_file, blue_file, nir_file):
161
  """
162
+ Prepare a stacked array (R, G, NIR) for the LOCAL model and an RGB image for visualization.
163
+ (Slightly modified from paste.txt to handle potential loading errors)
164
+ Returns:
165
+ prediction_array (np.ndarray): Stacked array (R,G,NIR) for local model, or None on error.
166
+ rgb_image_enhanced (np.ndarray): RGB image (0-1 float) for visualization, or None on error.
167
  """
168
+ try:
169
+ # Get dimensions from red band to use for resampling
170
+ with rasterio.open(red_file) as src:
171
+ target_height = src.height
172
+ target_width = src.width
173
+
174
+ # Load bands (resample NIR band to match 10m resolution)
175
+ blue_data = load_band(blue_file) # Needed for RGB viz
176
+ green_data = load_band(green_file)
177
+ red_data = load_band(red_file)
178
+ nir_data = load_band(
179
+ nir_file,
180
+ resample=True,
181
+ target_height=target_height,
182
+ target_width=target_width
183
+ )
184
+
185
+ # --- Prepare RGB Image for Visualization (similar to visualize_rgb but returns float array) ---
186
+ red_max = np.percentile(red_data[red_data>0], 98) if np.any(red_data>0) else 1.0
187
+ green_max = np.percentile(green_data[green_data>0], 98) if np.any(green_data>0) else 1.0
188
+ blue_max = np.percentile(blue_data[blue_data>0], 98) if np.any(blue_data>0) else 1.0
189
+ epsilon = 1e-10
190
+
191
+ rgb_image = np.zeros((target_height, target_width, 3), dtype=np.float32)
192
+ rgb_image[:, :, 0] = np.clip(red_data / (red_max + epsilon), 0, 1)
193
+ rgb_image[:, :, 1] = np.clip(green_data / (green_max + epsilon), 0, 1)
194
+ rgb_image[:, :, 2] = np.clip(blue_data / (blue_max + epsilon), 0, 1)
195
+
196
+ # Apply gamma correction for enhancement
197
+ gamma = 1.8
198
+ rgb_image_enhanced = np.power(rgb_image, 1/gamma)
199
+ # --- End RGB Image Preparation ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
+ # Stack bands in CHW format for LOCAL cloud mask prediction (red, green, nir)
202
+ # Ensure all bands have the same shape before stacking
203
+ if not (red_data.shape == green_data.shape == nir_data.shape):
204
+ print("ERROR: Band shapes mismatch after loading/resampling!")
205
+ print(f"Shapes - Red: {red_data.shape}, Green: {green_data.shape}, NIR: {nir_data.shape}")
206
+ return None, None # Indicate error
207
+
208
+ prediction_array = np.stack([red_data, green_data, nir_data], axis=0) # CHW format
209
+
210
+ print(f"Local prediction array shape: {prediction_array.shape}")
211
+ print(f"RGB visualization image shape: {rgb_image_enhanced.shape}")
212
+
213
+ return prediction_array, rgb_image_enhanced
214
+
215
+ except Exception as e:
216
+ print(f"Error during input preparation: {e}")
217
+ import traceback
218
+ traceback.print_exc()
219
+ return None, None # Indicate error
220
 
221
  def visualize_cloud_mask(rgb_image, pred_mask):
222
  """
223
  Create a visualization of the cloud mask overlaid on the RGB image.
224
+ (Unchanged from paste.txt, but added error checks)
225
  """
226
+ if rgb_image is None or pred_mask is None:
227
+ print("Cannot visualize cloud mask: Missing RGB image or prediction mask.")
228
+ return None
229
+
230
+ try:
231
+ # Ensure pred_mask has the right dimensions (H, W)
232
+ if pred_mask.ndim == 3 and pred_mask.shape[0] == 1: # Squeeze channel dim if present
233
+ pred_mask = np.squeeze(pred_mask, axis=0)
234
+ elif pred_mask.ndim != 2:
235
+ print(f"ERROR: Unexpected prediction mask dimension: {pred_mask.ndim}, shape: {pred_mask.shape}")
236
+ # Attempt to squeeze if possible, otherwise fail
237
+ try:
238
+ pred_mask = np.squeeze(pred_mask)
239
+ if pred_mask.ndim != 2: raise ValueError("Still not 2D after squeeze")
240
+ except Exception as sq_err:
241
+ print(f"Could not convert mask to 2D: {sq_err}")
242
+ return None # Cannot visualize
243
+
244
+ print(f"Visualization - RGB image shape: {rgb_image.shape}, Pred mask shape: {pred_mask.shape}")
245
+
246
+ # Ensure mask has the same spatial dimensions as the image
247
+ if pred_mask.shape != rgb_image.shape[:2]:
248
+ print(f"Warning: Resizing prediction mask from {pred_mask.shape} to {rgb_image.shape[:2]} for visualization.")
249
+ # Ensure mask is integer type for nearest neighbor interpolation
250
+ if not np.issubdtype(pred_mask.dtype, np.integer):
251
+ print("Warning: Prediction mask is not integer type, casting to uint8 for resize.")
252
+ pred_mask = pred_mask.astype(np.uint8)
253
+
254
+ pred_mask_resized = cv2.resize(
255
+ pred_mask,
256
+ (rgb_image.shape[1], rgb_image.shape[0]), # Target shape (width, height) for cv2.resize
257
+ interpolation=cv2.INTER_NEAREST # Use nearest to preserve class labels
258
+ )
259
+ pred_mask = pred_mask_resized
260
+ print(f"Resized mask shape: {pred_mask.shape}")
261
+
262
+ # Define colors for each class
263
+ colors = {
264
+ 0: [0, 255, 0], # Clear - Green
265
+ 1: [255, 0, 0], # Thick Cloud - Red (Changed from White for better contrast)
266
+ 2: [255, 255, 0], # Thin Cloud - Yellow (Changed from Gray)
267
+ 3: [0, 0, 255] # Cloud Shadow - Blue (Changed from Gray)
268
+ }
269
+
270
+ # Create a color-coded mask visualization
271
+ mask_vis = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
272
+ for class_idx, color in colors.items():
273
+ # Handle potential out-of-bounds class indices in mask
274
+ mask_vis[pred_mask == class_idx] = color
275
+
276
+ # Create a blended visualization
277
+ alpha = 0.4 # Transparency of the mask overlay
278
+ # Ensure rgb_image is uint8 for blending
279
+ rgb_uint8 = (np.clip(rgb_image, 0, 1) * 255).astype(np.uint8)
280
+ blended = cv2.addWeighted(rgb_uint8, 1-alpha, mask_vis, alpha, 0)
281
+
282
+ # --- Create Legend ---
283
+ legend_height = 100
284
+ legend_width = blended.shape[1] # Match image width
285
+ legend = np.ones((legend_height, legend_width, 3), dtype=np.uint8) * 255 # White background
286
+
287
+ legend_text = ["Clear", "Thick Cloud", "Thin Cloud", "Cloud Shadow"]
288
+ legend_colors = [colors.get(i, [0,0,0]) for i in range(4)] # Use .get for safety
289
+
290
+ box_size = 15
291
+ text_offset_x = 40
292
+ start_y = 15
293
+ padding_y = 20
294
+
295
+ for i, (text, color) in enumerate(zip(legend_text, legend_colors)):
296
+ # Draw color box
297
+ cv2.rectangle(legend,
298
+ (10, start_y + i*padding_y - box_size // 2),
299
+ (10 + box_size, start_y + i*padding_y + box_size // 2),
300
+ color, -1)
301
+ # Draw text
302
+ cv2.putText(legend, text,
303
+ (text_offset_x, start_y + i*padding_y + box_size // 4), # Adjust vertical alignment
304
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
305
+ # --- End Legend ---
306
+
307
+ # Combine image and legend
308
+ final_output = np.vstack([blended, legend])
309
+
310
+ return final_output
311
+
312
+ except Exception as e:
313
+ print(f"Error during visualization: {e}")
314
+ import traceback
315
+ traceback.print_exc()
316
+ return None # Return None if visualization fails
317
+
318
+
319
+ # --- Triton Client Functions (Adapted from paste-2.txt) ---
320
+
321
+ def is_triton_server_healthy(url=HTTP_TRITON_URL):
322
+ """Checks if the Triton Inference Server is live."""
323
+ if not TRITON_CLIENT_AVAILABLE:
324
+ return False
325
+ try:
326
+ triton_client = httpclient.InferenceServerClient(url=url, connection_timeout=10.0) # Short timeout for health check
327
+ server_live = triton_client.is_server_live()
328
+ if server_live:
329
+ print(f"Triton server at {url} is live.")
330
+ # Optionally check readiness:
331
+ # server_ready = triton_client.is_server_ready()
332
+ # print(f"Triton server at {url} is ready: {server_ready}")
333
+ # return server_ready
334
+ else:
335
+ print(f"Triton server at {url} is not live.")
336
+ return server_live
337
+ except Exception as e:
338
+ print(f"Could not connect to Triton server at {url}: {e}")
339
+ return False
340
+
341
+ def get_jp2_bytes_for_triton(red_file_path, green_file_path, nir_file_path):
342
+ """
343
+ Reads the raw bytes of Red, Green, and NIR JP2 files for Triton.
344
+ Order: Red, Green, NIR (must match Triton model input expectation)
345
+ """
346
+ byte_list = []
347
+ files_to_read = [red_file_path, green_file_path, nir_file_path]
348
+ band_names = ['Red', 'Green', 'NIR']
349
+
350
+ for file_path, band_name in zip(files_to_read, band_names):
351
+ try:
352
+ with open(file_path, "rb") as f:
353
+ file_bytes = f.read()
354
+ byte_list.append(file_bytes)
355
+ print(f"Read {len(file_bytes)} bytes for {band_name} band from {os.path.basename(file_path)}")
356
+ except FileNotFoundError:
357
+ print(f"ERROR: File not found: {file_path}")
358
+ raise # Propagate error
359
+ except Exception as e:
360
+ print(f"ERROR: Could not read file {file_path}: {e}")
361
+ raise # Propagate error
362
+
363
+ # Create NumPy array of object type to hold bytes
364
+ input_byte_array = np.array(byte_list, dtype=object)
365
+
366
+ # Expected shape is (3,) -> a 1D array containing 3 byte objects
367
+ print(f"Prepared Triton input byte array with shape: {input_byte_array.shape} and dtype: {input_byte_array.dtype}")
368
+ return input_byte_array
369
+
370
+
371
+ def run_inference_triton_http(input_byte_array):
372
+ """
373
+ Run inference using Triton HTTP client with raw JP2 bytes.
374
+ """
375
+ if not TRITON_CLIENT_AVAILABLE:
376
+ raise RuntimeError("Triton client library not available.")
377
+
378
+ print("Attempting inference using Triton HTTP client...")
379
+ try:
380
+ client = httpclient.InferenceServerClient(
381
+ url=HTTP_TRITON_URL,
382
+ verbose=False,
383
+ connection_timeout=TRITON_TIMEOUT_SECONDS,
384
+ network_timeout=TRITON_TIMEOUT_SECONDS
385
+ )
386
+ except Exception as e:
387
+ print(f"ERROR: Couldn't create Triton HTTP client: {e}")
388
+ raise # Propagate error
389
+
390
+ # Prepare input tensor (BYTES type)
391
+ # Shape [3] matches the 1D numpy array holding 3 byte strings
392
+ inputs = [httpclient.InferInput(TRITON_INPUT_NAME, input_byte_array.shape, "BYTES")]
393
+ inputs[0].set_data_from_numpy(input_byte_array, binary_data=True) # binary_data=True is important for BYTES
394
+
395
+ # Prepare output tensor request
396
+ outputs = [httpclient.InferRequestedOutput(TRITON_OUTPUT_NAME, binary_data=True)]
397
+
398
+ # Send inference request
399
+ try:
400
+ print(f"Sending inference request to Triton model '{TRITON_MODEL_NAME}' at {HTTP_TRITON_URL}...")
401
+ response = client.infer(
402
+ model_name=TRITON_MODEL_NAME,
403
+ inputs=inputs,
404
+ outputs=outputs,
405
+ request_id=str(os.getpid()), # Optional request ID
406
+ timeout=TRITON_TIMEOUT_SECONDS + 60.0 # Give extra time for the inference itself
407
+ )
408
+ print("Triton inference request successful.")
409
+ mask = response.as_numpy(TRITON_OUTPUT_NAME)
410
+ print(f"Received output mask from Triton with shape: {mask.shape}, dtype: {mask.dtype}")
411
+ return mask
412
+ except triton_utils.InferenceServerException as e:
413
+ print(f"ERROR: Triton server failed inference: Status code {e.status()}, message: {e.message()}")
414
+ print(f"Debug details: {e.debug_details()}")
415
+ raise # Propagate error to trigger fallback
416
+ except Exception as e:
417
+ print(f"ERROR: An unexpected error occurred during Triton HTTP inference: {e}")
418
+ import traceback
419
+ traceback.print_exc()
420
+ raise # Propagate error to trigger fallback
421
+
422
+
423
+ # --- Main Processing Function with Fallback Logic ---
424
+
425
  def process_satellite_images(red_file, green_file, blue_file, nir_file, batch_size, patch_size, patch_overlap):
426
  """
427
+ Process satellite images: Try Triton first, fallback to local model.
428
  """
429
  if not all([red_file, green_file, blue_file, nir_file]):
430
+ return None, None, "ERROR: Please upload all four channel files (Red, Green, Blue, NIR)"
431
+
432
+ # Store file paths from Gradio Image components
433
+ red_file_path = red_file if isinstance(red_file, str) else red_file.name
434
+ green_file_path = green_file if isinstance(green_file, str) else green_file.name
435
+ blue_file_path = blue_file if isinstance(blue_file, str) else blue_file.name
436
+ nir_file_path = nir_file if isinstance(nir_file, str) else nir_file.name
437
+
438
+ print("\n--- Starting Cloud Detection Process ---")
439
+ print(f"Input files: R={os.path.basename(red_file_path)}, G={os.path.basename(green_file_path)}, B={os.path.basename(blue_file_path)}, N={os.path.basename(nir_file_path)}")
440
+
441
+ pred_mask = None
442
+ status_message = ""
443
+ rgb_display_image = None # For the raw RGB output panel
444
+ rgb_float_image = None # For overlay visualization
445
+
446
+ # 1. Prepare Visualization Image (always needed) & Local Input Array (needed for fallback)
447
+ print("Preparing visualization image and local model input array...")
448
+ local_input_array, rgb_float_image = prepare_input_array(red_file_path, green_file_path, blue_file_path, nir_file_path)
449
+
450
+ if rgb_float_image is not None:
451
+ # Convert float image (0-1) to uint8 (0-255) for the RGB output panel
452
+ rgb_display_image = (np.clip(rgb_float_image, 0, 1) * 255).astype(np.uint8)
453
  else:
454
+ print("ERROR: Failed to create RGB visualization image.")
455
+ # Return early if visualization prep failed, as likely indicates file loading issues
456
+ return None, None, "ERROR: Failed to load or process input band files."
457
+
458
+ # 2. Check Triton Server Health
459
+ use_triton = False
460
+ if TRITON_CLIENT_AVAILABLE:
461
+ print(f"Checking Triton server health at {HTTP_TRITON_URL}...")
462
+ if is_triton_server_healthy(HTTP_TRITON_URL):
463
+ use_triton = True
464
+ else:
465
+ print("Triton server is not healthy or unavailable.")
466
+ status_message += "Triton server unavailable. "
467
+ else:
468
+ print("Triton client library not installed. Skipping Triton check.")
469
+ status_message += "Triton client not installed. "
470
+
471
+ # 3. Attempt Triton Inference if Healthy
472
+ if use_triton:
473
+ try:
474
+ print("Preparing JP2 bytes for Triton...")
475
+ # Use Red, Green, NIR file paths
476
+ triton_byte_input = get_jp2_bytes_for_triton(red_file_path, green_file_path, nir_file_path)
477
+ pred_mask = run_inference_triton_http(triton_byte_input)
478
+ status_message += "Inference performed using Triton Server. "
479
+ print("Triton inference successful.")
480
+ except Exception as e:
481
+ print(f"Triton inference failed: {e}. Falling back to local model.")
482
+ status_message += f"Triton inference failed ({type(e).__name__}). "
483
+ pred_mask = None # Ensure mask is None to trigger fallback
484
+ use_triton = False # Explicitly mark Triton as not used
485
+
486
+ # 4. Fallback to Local Model if Triton failed or wasn't available/healthy
487
+ if pred_mask is None: # Check if mask wasn't obtained from Triton
488
+ status_message += "Falling back to local inference. "
489
+ if LOCAL_MODEL_AVAILABLE and local_input_array is not None:
490
+ print("Running local inference using omnicloudmask...")
491
+ try:
492
+ # Predict cloud mask using local omnicloudmask
493
+ pred_mask = predict_from_array(
494
+ local_input_array,
495
+ batch_size=batch_size,
496
+ patch_size=patch_size,
497
+ patch_overlap=patch_overlap
498
+ )
499
+ print(f"Local prediction successful. Output mask shape: {pred_mask.shape}, dtype: {pred_mask.dtype}")
500
+ status_message += "Local inference successful."
501
+ except Exception as e:
502
+ print(f"ERROR: Local inference failed: {e}")
503
+ import traceback
504
+ traceback.print_exc()
505
+ status_message += f"Local inference FAILED: {e}"
506
+ # Keep pred_mask as None
507
+ elif not LOCAL_MODEL_AVAILABLE:
508
+ status_message += "Local model not available. Cannot perform inference."
509
+ print("ERROR: Local model could not be loaded.")
510
+ elif local_input_array is None:
511
+ status_message += "Local input data preparation failed. Cannot perform local inference."
512
+ print("ERROR: Failed to prepare input array for local model.")
513
+ else:
514
+ status_message += "Unknown state, cannot perform inference." # Should not happen
515
+
516
+ # 5. Process Results (Stats and Visualization) if mask was generated
517
+ if pred_mask is not None:
518
+ # Ensure mask is squeezed to 2D if necessary (local model might return extra dim)
519
+ if pred_mask.ndim == 3 and pred_mask.shape[0] == 1:
520
+ flat_mask = np.squeeze(pred_mask, axis=0)
521
+ elif pred_mask.ndim == 2:
522
+ flat_mask = pred_mask
523
+ else:
524
+ print(f"ERROR: Unexpected mask shape after inference: {pred_mask.shape}")
525
+ status_message += " ERROR: Invalid mask shape received."
526
+ flat_mask = None # Invalidate mask
527
+
528
+ if flat_mask is not None:
529
+ # Calculate class distribution
530
+ clear_pixels = np.sum(flat_mask == 0)
531
+ thick_cloud_pixels = np.sum(flat_mask == 1)
532
+ thin_cloud_pixels = np.sum(flat_mask == 2)
533
+ cloud_shadow_pixels = np.sum(flat_mask == 3)
534
+ total_pixels = flat_mask.size
535
+
536
+ stats = f"""
537
+ Cloud Mask Statistics ({'Triton' if use_triton else 'Local'}):
538
+ - Clear: {clear_pixels} pixels ({clear_pixels/total_pixels*100:.2f}%)
539
+ - Thick Cloud: {thick_cloud_pixels} pixels ({thick_cloud_pixels/total_pixels*100:.2f}%)
540
+ - Thin Cloud: {thin_cloud_pixels} pixels ({thin_cloud_pixels/total_pixels*100:.2f}%)
541
+ - Cloud Shadow: {cloud_shadow_pixels} pixels ({cloud_shadow_pixels/total_pixels*100:.2f}%)
542
+ - Total Cloud Cover (Thick+Thin): {(thick_cloud_pixels + thin_cloud_pixels)/total_pixels*100:.2f}%
543
+ """
544
+ status_message += f"\nMask stats calculated. Total pixels: {total_pixels}."
545
+
546
+ # Visualize the cloud mask on the original image
547
+ print("Generating final visualization...")
548
+ visualization = visualize_cloud_mask(rgb_float_image, flat_mask) # Use float image for viz function
549
+
550
+ if visualization is None:
551
+ status_message += " ERROR: Failed to generate visualization."
552
+
553
+ print("--- Cloud Detection Process Finished ---")
554
+ return rgb_display_image, visualization, status_message + "\n" + stats
555
+ else:
556
+ # Mask had wrong shape
557
+ return rgb_display_image, None, status_message + "\nERROR: Could not process prediction mask."
558
+
559
+ else:
560
+ # Inference failed both ways or initial loading failed
561
+ print("--- Cloud Detection Process Failed ---")
562
+ return rgb_display_image, None, status_message + "\nERROR: Could not generate cloud mask."
563
 
 
 
564
 
565
+ # --- Gradio Interface (from paste.txt) ---
 
 
 
 
 
 
566
 
 
567
  def check_cpu_usage():
568
  """Check and return the current CPU usage."""
569
  return f"CPU Usage: {psutil.cpu_percent()}%"
570
 
571
+ # --- Build Gradio App ---
572
+ print("Building Gradio interface...")
573
+ with gr.Blocks(title="Satellite Cloud Detection (Triton/Local)") as demo:
574
  gr.Markdown("""
575
+ # Satellite Cloud Detection (with Triton Fallback)
576
+
577
+ Upload separate JP2 files for Red (e.g., B04), Green (e.g., B03), Blue (e.g., B02), and NIR (e.g., B8A) channels.
578
+ The application will **first attempt** to use a remote Triton Inference Server. If the server is unavailable or inference fails,
579
+ it will **fall back** to using the local OmniCloudMask model.
580
+
581
+ **Pixel Classification:**
582
+ - Clear (Green)
583
+ - Thick Cloud (Red)
584
+ - Thin Cloud (Yellow)
585
+ - Cloud Shadow (Blue)
586
+
587
+ The model works best with imagery at 10-50m resolution.
588
  """)
589
+
590
  # Main cloud detection interface
591
  with gr.Row():
592
+ with gr.Column(scale=1):
593
+ gr.Markdown("### Input Bands (JP2)")
594
+ # Use filepaths which are needed for both local reading and byte reading
595
+ red_input = gr.File(label="Red Channel (e.g., B04)", type="filepath")
596
+ green_input = gr.File(label="Green Channel (e.g., B03)", type="filepath")
597
+ blue_input = gr.File(label="Blue Channel (e.g., B02)", type="filepath")
598
+ nir_input = gr.File(label="NIR Channel (e.g., B8A)", type="filepath")
599
+
600
+ gr.Markdown("### Local Model Parameters (Used for Fallback)")
601
+ batch_size = gr.Slider(minimum=1, maximum=32, value=4, step=1,
602
+ label="Batch Size",
603
+ info="Memory usage/speed for local model")
604
+ patch_size = gr.Slider(minimum=256, maximum=2048, value=1024, step=128,
605
+ label="Patch Size",
606
+ info="Patch size for local model processing")
607
+ patch_overlap = gr.Slider(minimum=64, maximum=512, value=256, step=64,
608
+ label="Patch Overlap",
609
+ info="Overlap for local model processing")
610
+
611
+ process_btn = gr.Button("Process Cloud Detection", variant="primary")
612
+
613
+ with gr.Column(scale=2):
614
+ gr.Markdown("### Results")
615
  # Output components
616
+ rgb_output = gr.Image(label="Original RGB Image (Approx. True Color)", type="numpy")
617
+ cloud_output = gr.Image(label="Cloud Detection Visualization (Mask Overlay)", type="numpy")
618
+ stats_output = gr.Textbox(label="Processing Status & Statistics", lines=10)
619
+
620
+ # CPU usage monitoring section (Optional)
621
+ with gr.Accordion("System Monitoring", open=False):
622
+ cpu_button = gr.Button("Check CPU Usage")
623
+ cpu_output = gr.Textbox(label="Current CPU Usage")
624
+ cpu_button.click(fn=check_cpu_usage, inputs=None, outputs=cpu_output)
625
+
626
+ # Examples section
627
+ # Ensure example paths are relative to where the script is run,
628
+ # or absolute if needed. Assumes 'jp2s' folder is present.
629
+ example_base = os.path.join(repo_dir, "jp2s") # Use downloaded repo path
630
+ example_files = [
631
+ os.path.join(example_base, "B04.jp2"), # Red
632
+ os.path.join(example_base, "B03.jp2"), # Green
633
+ os.path.join(example_base, "B02.jp2"), # Blue
634
+ os.path.join(example_base, "B8A.jp2") # NIR
635
+ ]
636
+
637
+ # Check if example files actually exist before adding example
638
+ if all(os.path.exists(f) for f in example_files):
639
+ print("Adding examples...")
640
+ gr.Examples(
641
+ examples=[example_files + [4, 1024, 256]], # Corresponds to inputs below
642
+ inputs=[red_input, green_input, blue_input, nir_input, batch_size, patch_size, patch_overlap],
643
+ outputs=[rgb_output, cloud_output, stats_output], # Define outputs for examples too
644
+ fn=process_satellite_images, # Function to run for examples
645
+ cache_examples=False # Maybe disable caching if files change or for debugging
646
+ )
647
+ else:
648
+ print(f"WARN: Example JP2 files not found in '{example_base}'. Skipping examples.")
649
+
650
+
651
+ # Setup main button click handler
652
  process_btn.click(
653
  fn=process_satellite_images,
654
  inputs=[red_input, green_input, blue_input, nir_input, batch_size, patch_size, patch_overlap],
655
  outputs=[rgb_output, cloud_output, stats_output]
656
  )
 
 
 
 
 
 
 
 
 
 
 
 
657
 
658
+ # --- Launch the App ---
659
+ print("Launching Gradio app...")
660
+ # Allow queueing and potentially increase workers if needed
661
+ demo.queue(default_concurrency_limit=4).launch(debug=True, share=False) # share=True for public link if needed