truthdotphd commited on
Commit
57355cc
·
verified ·
1 Parent(s): 3bce562

initial commit

Browse files
Files changed (2) hide show
  1. app.py +302 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import os
4
+ import cv2
5
+ import matplotlib.pyplot as plt
6
+ from huggingface_hub import snapshot_download
7
+ import rasterio
8
+ from rasterio.enums import Resampling
9
+ from rasterio.plot import reshape_as_image
10
+ import sys
11
+
12
+ # Download the entire repository to a subdirectory
13
+ repo_id = "truthdotphd/cloud-detection"
14
+ repo_subdir = "."
15
+ repo_dir = snapshot_download(repo_id=repo_id, local_dir=repo_subdir)
16
+
17
+ # Add the repository directory to the Python path
18
+ sys.path.append(repo_dir)
19
+
20
+ # Import the necessary functions from the downloaded modules
21
+ try:
22
+ from omnicloudmask import predict_from_array
23
+ except ImportError:
24
+ omnicloudmask_dir = os.path.join(repo_dir, "omnicloudmask")
25
+ if os.path.exists(omnicloudmask_dir):
26
+ sys.path.append(omnicloudmask_dir)
27
+ from omnicloudmask import predict_from_array
28
+ else:
29
+ raise ImportError("Could not find the omnicloudmask module in the downloaded repository")
30
+
31
+ def visualize_rgb(red_file, green_file, blue_file, nir_file):
32
+ """
33
+ Create and display an RGB visualization immediately after images are uploaded.
34
+ """
35
+ if not all([red_file, green_file, blue_file, nir_file]):
36
+ return None
37
+
38
+ try:
39
+ # Get dimensions from red band to use for resampling
40
+ with rasterio.open(red_file) as src:
41
+ target_height = src.height
42
+ target_width = src.width
43
+
44
+ # Load bands
45
+ blue_data = load_band(blue_file)
46
+ green_data = load_band(green_file)
47
+ red_data = load_band(red_file)
48
+
49
+ # Compute max values for each channel for dynamic normalization
50
+ red_max = np.max(red_data)
51
+ green_max = np.max(green_data)
52
+ blue_max = np.max(blue_data)
53
+
54
+ # Create RGB image for visualization with dynamic normalization
55
+ rgb_image = np.zeros((red_data.shape[0], red_data.shape[1], 3), dtype=np.float32)
56
+
57
+ # Normalize each channel individually
58
+ epsilon = 1e-10
59
+ rgb_image[:, :, 0] = red_data / (red_max + epsilon)
60
+ rgb_image[:, :, 1] = green_data / (green_max + epsilon)
61
+ rgb_image[:, :, 2] = blue_data / (blue_max + epsilon)
62
+
63
+ # Clip values to 0-1 range
64
+ rgb_image = np.clip(rgb_image, 0, 1)
65
+
66
+ # Apply contrast enhancement for better visualization
67
+ p2 = np.percentile(rgb_image, 2)
68
+ p98 = np.percentile(rgb_image, 98)
69
+ rgb_image_enhanced = np.clip((rgb_image - p2) / (p98 - p2), 0, 1)
70
+
71
+ # Convert to uint8 for display
72
+ rgb_display = (rgb_image_enhanced * 255).astype(np.uint8)
73
+
74
+ return rgb_display
75
+ except Exception as e:
76
+ print(f"Error generating RGB preview: {e}")
77
+ return None
78
+
79
+
80
+ def visualize_jp2(file_path):
81
+ """
82
+ Visualize a single JP2 file.
83
+ """
84
+ with rasterio.open(file_path) as src:
85
+ # Read the data
86
+ data = src.read(1)
87
+
88
+ # Normalize the data for visualization
89
+ data = (data - np.min(data)) / (np.max(data) - np.min(data))
90
+
91
+ # Apply a colormap for better visualization
92
+ cmap = plt.get_cmap('viridis')
93
+ colored_image = cmap(data)
94
+
95
+ # Convert to 8-bit for display
96
+ return (colored_image[:, :, :3] * 255).astype(np.uint8)
97
+
98
+ def load_band(file_path, resample=False, target_height=None, target_width=None):
99
+ """
100
+ Load a single band from a raster file with optional resampling.
101
+ """
102
+ with rasterio.open(file_path) as src:
103
+ if resample and target_height is not None and target_width is not None:
104
+ band_data = src.read(
105
+ out_shape=(src.count, target_height, target_width),
106
+ resampling=Resampling.bilinear
107
+ )[0].astype(np.float32)
108
+ else:
109
+ band_data = src.read()[0].astype(np.float32)
110
+
111
+ return band_data
112
+
113
+ def prepare_input_array(red_file, green_file, blue_file, nir_file):
114
+ """
115
+ Prepare a stacked array of satellite bands for cloud mask prediction.
116
+ """
117
+ # Get dimensions from red band to use for resampling
118
+ with rasterio.open(red_file) as src:
119
+ target_height = src.height
120
+ target_width = src.width
121
+
122
+ # Load bands (resample NIR band to match 10m resolution)
123
+ blue_data = load_band(blue_file)
124
+ green_data = load_band(green_file)
125
+ red_data = load_band(red_file)
126
+ nir_data = load_band(
127
+ nir_file,
128
+ resample=True,
129
+ target_height=target_height,
130
+ target_width=target_width
131
+ )
132
+
133
+ # Print band shapes for debugging
134
+ print(f"Band shapes - Blue: {blue_data.shape}, Green: {green_data.shape}, Red: {red_data.shape}, NIR: {nir_data.shape}")
135
+
136
+ # Compute max values for each channel for dynamic normalization
137
+ red_max = np.max(red_data)
138
+ green_max = np.max(green_data)
139
+ blue_max = np.max(blue_data)
140
+
141
+ print(f"Max values - Red: {red_max}, Green: {green_max}, Blue: {blue_max}")
142
+
143
+ # Create RGB image for visualization with dynamic normalization
144
+ rgb_image = np.zeros((red_data.shape[0], red_data.shape[1], 3), dtype=np.float32)
145
+
146
+ # Normalize each channel individually
147
+ # Add a small epsilon to avoid division by zero
148
+ epsilon = 1e-10
149
+ rgb_image[:, :, 0] = red_data / (red_max + epsilon)
150
+ rgb_image[:, :, 1] = green_data / (green_max + epsilon)
151
+ rgb_image[:, :, 2] = blue_data / (blue_max + epsilon)
152
+
153
+ # Clip values to 0-1 range
154
+ rgb_image = np.clip(rgb_image, 0, 1)
155
+
156
+ # Optional: Apply contrast enhancement for better visualization
157
+ p2 = np.percentile(rgb_image, 2)
158
+ p98 = np.percentile(rgb_image, 98)
159
+ rgb_image_enhanced = np.clip((rgb_image - p2) / (p98 - p2), 0, 1)
160
+
161
+ # Stack bands in CHW format for cloud mask prediction (red, green, nir)
162
+ prediction_array = np.stack([red_data, green_data, nir_data], axis=0)
163
+
164
+ return prediction_array, rgb_image_enhanced
165
+
166
+
167
+ def visualize_cloud_mask(rgb_image, pred_mask):
168
+ """
169
+ Create a visualization of the cloud mask overlaid on the RGB image.
170
+ """
171
+ # Ensure pred_mask has the right dimensions
172
+ if pred_mask.ndim > 2:
173
+ pred_mask = np.squeeze(pred_mask)
174
+
175
+ print(f"RGB image shape: {rgb_image.shape}, Pred mask shape: {pred_mask.shape}")
176
+
177
+ # Ensure mask has the same spatial dimensions as the image
178
+ if pred_mask.shape != rgb_image.shape[:2]:
179
+ pred_mask = cv2.resize(
180
+ pred_mask.astype(np.float32),
181
+ (rgb_image.shape[1], rgb_image.shape[0]),
182
+ interpolation=cv2.INTER_NEAREST
183
+ ).astype(np.uint8)
184
+ print(f"Resized mask shape: {pred_mask.shape}")
185
+
186
+ # Define colors for each class
187
+ colors = {
188
+ 0: [0, 255, 0], # Clear - Green
189
+ 1: [255, 255, 255], # Thick Cloud - White
190
+ 2: [200, 200, 200], # Thin Cloud - Light Gray
191
+ 3: [100, 100, 100] # Cloud Shadow - Dark Gray
192
+ }
193
+
194
+ # Create a color-coded mask
195
+ mask_vis = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
196
+ for class_idx, color in colors.items():
197
+ mask_vis[pred_mask == class_idx] = color
198
+
199
+ # Create a blended visualization
200
+ alpha = 0.5
201
+ blended = cv2.addWeighted((rgb_image * 255).astype(np.uint8), 1-alpha, mask_vis, alpha, 0)
202
+
203
+ # Get the width of the blended image for the legend
204
+ image_width = blended.shape[1]
205
+
206
+ # Create a legend with the same width as the image
207
+ legend = np.ones((100, image_width, 3), dtype=np.uint8) * 255
208
+ legend_text = ["Clear", "Thick Cloud", "Thin Cloud", "Cloud Shadow"]
209
+ legend_colors = [colors[i] for i in range(4)]
210
+
211
+ for i, (text, color) in enumerate(zip(legend_text, legend_colors)):
212
+ cv2.rectangle(legend, (10, 10 + i*20), (30, 30 + i*20), color, -1)
213
+ cv2.putText(legend, text, (40, 25 + i*20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
214
+
215
+ # Combine image and legend
216
+ final_output = np.vstack([blended, legend])
217
+
218
+ return final_output
219
+
220
+ def process_satellite_images(red_file, green_file, blue_file, nir_file, batch_size, patch_size, patch_overlap):
221
+ """
222
+ Process the satellite images and detect clouds.
223
+ """
224
+ if not all([red_file, green_file, blue_file, nir_file]):
225
+ return None, None, "Please upload all four channel files (Red, Green, Blue, NIR)"
226
+
227
+ # Prepare input array and RGB image for visualization
228
+ input_array, rgb_image = prepare_input_array(red_file, green_file, blue_file, nir_file)
229
+
230
+ # Convert RGB image to format suitable for display
231
+ rgb_display = (rgb_image * 255).astype(np.uint8)
232
+
233
+ # Predict cloud mask using omnicloudmask
234
+ pred_mask = predict_from_array(
235
+ input_array,
236
+ batch_size=batch_size,
237
+ patch_size=patch_size,
238
+ patch_overlap=patch_overlap
239
+ )
240
+
241
+ # Calculate class distribution
242
+ if pred_mask.ndim > 2:
243
+ flat_mask = np.squeeze(pred_mask)
244
+ else:
245
+ flat_mask = pred_mask
246
+
247
+ clear_pixels = np.sum(flat_mask == 0)
248
+ thick_cloud_pixels = np.sum(flat_mask == 1)
249
+ thin_cloud_pixels = np.sum(flat_mask == 2)
250
+ cloud_shadow_pixels = np.sum(flat_mask == 3)
251
+ total_pixels = flat_mask.size
252
+
253
+ stats = f"""
254
+ Cloud Mask Statistics:
255
+ - Clear: {clear_pixels} pixels ({clear_pixels/total_pixels*100:.2f}%)
256
+ - Thick Cloud: {thick_cloud_pixels} pixels ({thick_cloud_pixels/total_pixels*100:.2f}%)
257
+ - Thin Cloud: {thin_cloud_pixels} pixels ({thin_cloud_pixels/total_pixels*100:.2f}%)
258
+ - Cloud Shadow: {cloud_shadow_pixels} pixels ({cloud_shadow_pixels/total_pixels*100:.2f}%)
259
+ - Total Cloud Cover: {(thick_cloud_pixels + thin_cloud_pixels)/total_pixels*100:.2f}%
260
+ """
261
+
262
+ # Visualize the cloud mask on the original image
263
+ visualization = visualize_cloud_mask(rgb_image, flat_mask)
264
+
265
+ return rgb_display, visualization, stats
266
+
267
+
268
+ # Create Gradio interface
269
+ demo = gr.Interface(
270
+ fn=process_satellite_images,
271
+ inputs=[
272
+ gr.Image(type="filepath", label="Red Channel (JP2)"),
273
+ gr.Image(type="filepath", label="Green Channel (JP2)"),
274
+ gr.Image(type="filepath", label="Blue Channel (JP2)"),
275
+ gr.Image(type="filepath", label="NIR Channel (JP2)"),
276
+ gr.Slider(minimum=1, maximum=32, value=1, step=1, label="Batch Size", info="Higher values use more memory but process faster"),
277
+ gr.Slider(minimum=500, maximum=2000, value=1000, step=100, label="Patch Size", info="Size of image patches for processing"),
278
+ gr.Slider(minimum=100, maximum=500, value=300, step=50, label="Patch Overlap", info="Overlap between patches to avoid edge artifacts")
279
+ ],
280
+ outputs=[
281
+ gr.Image(label="Original RGB Image"),
282
+ gr.Image(label="Cloud Detection Visualization"),
283
+ gr.Textbox(label="Statistics")
284
+ ],
285
+ title="Satellite Cloud Detection",
286
+ description="""
287
+ Upload separate JP2 files for Red, Green, Blue, and NIR channels to detect clouds in satellite imagery.
288
+
289
+ This application uses the OmniCloudMask model to classify each pixel as:
290
+ - Clear (0)
291
+ - Thick Cloud (1)
292
+ - Thin Cloud (2)
293
+ - Cloud Shadow (3)
294
+
295
+ The model works best with imagery at 10-50m resolution. For higher resolution imagery, downsampling is recommended.
296
+ """,
297
+ examples=[
298
+ ["jp2s/B04.jp2", "jp2s/B03.jp2", "jp2s/B02.jp2", "jp2s/B8A.jp2", 1, 1000, 300]
299
+ ]
300
+ )
301
+ # Launch the app
302
+ demo.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ rasterio==1.3.11
2
+ matplotlib==3.7.5
3
+ fastai>=2.7
4
+ timm>=0.9
5
+ tqdm>=4.0
6
+ gdown>=5.1.0
7
+ torch>=2.2