"""Gradio app that detects clouds in satellite imagery with OmniCloudMask.

The user uploads four single-band raster files (red, green, blue, NIR).
Red, green and NIR are stacked and passed to ``predict_from_array``; red,
green and blue are combined into a contrast-stretched preview on which the
predicted mask is overlaid.
"""
import os
import sys

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from huggingface_hub import snapshot_download
from rasterio.enums import Resampling
from rasterio.plot import reshape_as_image

# Download the entire repository to a subdirectory
repo_id = "truthdotphd/cloud-detection"
repo_subdir = "."
repo_dir = snapshot_download(repo_id=repo_id, local_dir=repo_subdir)

# Add the repository directory to the Python path
sys.path.append(repo_dir)

# Import the necessary functions from the downloaded modules
try:
    from omnicloudmask import predict_from_array
except ImportError:
    # Fall back to an "omnicloudmask" directory inside the downloaded repo.
    omnicloudmask_dir = os.path.join(repo_dir, "omnicloudmask")
    if os.path.exists(omnicloudmask_dir):
        sys.path.append(omnicloudmask_dir)
        from omnicloudmask import predict_from_array
    else:
        raise ImportError("Could not find the omnicloudmask module in the downloaded repository")

# Small constant added to per-band maxima so an all-zero band does not
# cause a division by zero during normalization.
_EPSILON = 1e-10

# Overlay color for each predicted class index.
_CLASS_COLORS = {
    0: [0, 255, 0],  # Clear - Green
    1: [255, 255, 255],  # Thick Cloud - White
    2: [200, 200, 200],  # Thin Cloud - Light Gray
    3: [100, 100, 100],  # Cloud Shadow - Dark Gray
}

# Legend labels, ordered by class index.
_CLASS_LABELS = ["Clear", "Thick Cloud", "Thin Cloud", "Cloud Shadow"]


def _as_path(file_obj):
    """Return something ``rasterio.open`` can take for an uploaded file.

    Strings and path-like objects are returned unchanged. Any other object
    exposing a ``name`` attribute (e.g. a temporary-file wrapper) is
    replaced by that attribute.
    NOTE(review): which of the two forms ``gr.File`` delivers depends on the
    installed Gradio version - confirm against the deployed version.
    """
    if isinstance(file_obj, (str, os.PathLike)):
        return file_obj
    return getattr(file_obj, "name", file_obj)


def _raster_size(file_path):
    """Return ``(height, width)`` of the raster at *file_path*."""
    with rasterio.open(_as_path(file_path)) as src:
        return src.height, src.width


def _stretch_rgb(red_data, green_data, blue_data):
    """Build a contrast-stretched float RGB image with values in [0, 1].

    Each band is divided by its own maximum, the result is clipped to
    [0, 1], and then a 2nd-98th percentile stretch is applied across all
    three channels together.

    Args:
        red_data: 2-D array for the red channel.
        green_data: 2-D array for the green channel (same shape as red).
        blue_data: 2-D array for the blue channel (same shape as red).

    Returns:
        Array of shape ``(H, W, 3)`` and dtype ``float32``.
    """
    rgb_image = np.zeros((red_data.shape[0], red_data.shape[1], 3), dtype=np.float32)
    for channel, band in enumerate((red_data, green_data, blue_data)):
        rgb_image[:, :, channel] = band / (np.max(band) + _EPSILON)
    rgb_image = np.clip(rgb_image, 0, 1)

    p2, p98 = np.percentile(rgb_image, (2, 98))
    spread = p98 - p2
    if spread <= 0:
        # Flat image: stretching would divide by zero and produce NaNs.
        return rgb_image
    return np.clip((rgb_image - p2) / spread, 0, 1).astype(np.float32)


def visualize_rgb(red_file, green_file, blue_file, nir_file):
    """Create an RGB preview as soon as all four band files are uploaded.

    Args:
        red_file: Red band raster file.
        green_file: Green band raster file.
        blue_file: Blue band raster file.
        nir_file: NIR band raster file. Not used for the preview, but the
            preview is only produced once it has been supplied too.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``, or ``None`` when a file is
        missing or the preview could not be generated.
    """
    if not all([red_file, green_file, blue_file, nir_file]):
        return None

    try:
        # Use the red band's grid as the reference so bands stored at a
        # different size can still be combined.
        target_height, target_width = _raster_size(red_file)

        red_data = load_band(red_file)
        green_data = load_band(
            green_file, resample=True,
            target_height=target_height, target_width=target_width,
        )
        blue_data = load_band(
            blue_file, resample=True,
            target_height=target_height, target_width=target_width,
        )

        rgb_image_enhanced = _stretch_rgb(red_data, green_data, blue_data)

        # Convert to uint8 for display
        return (rgb_image_enhanced * 255).astype(np.uint8)
    except Exception as e:
        # Best-effort preview: report the problem and show nothing.
        print(f"Error generating RGB preview: {e}")
        return None


def visualize_jp2(file_path):
    """Render the first band of a raster file with the viridis colormap.

    Args:
        file_path: Raster file to visualize.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``.
    """
    with rasterio.open(_as_path(file_path)) as src:
        data = src.read(1).astype(np.float64)

    lowest = np.min(data)
    span = np.max(data) - lowest
    if span > 0:
        data = (data - lowest) / span
    else:
        # Constant band: avoid 0/0 and map everything to the colormap's low end.
        data = np.zeros_like(data)

    cmap = plt.get_cmap('viridis')
    colored_image = cmap(data)

    # Drop the alpha channel and convert to 8-bit for display
    return (colored_image[:, :, :3] * 255).astype(np.uint8)


def load_band(file_path, resample=False, target_height=None, target_width=None):
    """Load the first band of a raster file, optionally resampled.

    Args:
        file_path: Raster file to read.
        resample: When true and both target dimensions are given, the band
            is read at the target size using bilinear resampling.
        target_height: Output height in pixels when resampling.
        target_width: Output width in pixels when resampling.

    Returns:
        2-D ``float32`` array.
    """
    with rasterio.open(_as_path(file_path)) as src:
        if resample and target_height is not None and target_width is not None:
            # Read only band 1 instead of every band in the dataset.
            band_data = src.read(
                1,
                out_shape=(target_height, target_width),
                resampling=Resampling.bilinear,
            )
        else:
            band_data = src.read(1)

    return band_data.astype(np.float32)


def prepare_input_array(red_file, green_file, blue_file, nir_file):
    """Prepare the model input stack and an RGB preview from four band files.

    Green, blue and NIR are read onto the red band's grid so all arrays
    share one shape.

    Args:
        red_file: Red band raster file (defines the reference grid).
        green_file: Green band raster file.
        blue_file: Blue band raster file (used for the preview only).
        nir_file: NIR band raster file.

    Returns:
        Tuple ``(prediction_array, rgb_image_enhanced)`` where
        ``prediction_array`` has shape ``(3, H, W)`` in the order red,
        green, NIR, and ``rgb_image_enhanced`` is a float ``(H, W, 3)``
        image with values in [0, 1].
    """
    # Get dimensions from red band to use for resampling
    target_height, target_width = _raster_size(red_file)
    grid = {
        "resample": True,
        "target_height": target_height,
        "target_width": target_width,
    }

    red_data = load_band(red_file)
    green_data = load_band(green_file, **grid)
    blue_data = load_band(blue_file, **grid)
    nir_data = load_band(nir_file, **grid)

    # Print band shapes for debugging
    print(f"Band shapes - Blue: {blue_data.shape}, Green: {green_data.shape}, Red: {red_data.shape}, NIR: {nir_data.shape}")
    print(f"Max values - Red: {np.max(red_data)}, Green: {np.max(green_data)}, Blue: {np.max(blue_data)}")

    rgb_image_enhanced = _stretch_rgb(red_data, green_data, blue_data)

    # Stack bands in CHW format for cloud mask prediction (red, green, nir)
    prediction_array = np.stack([red_data, green_data, nir_data], axis=0)

    return prediction_array, rgb_image_enhanced


def visualize_cloud_mask(rgb_image, pred_mask):
    """Overlay a color-coded cloud mask on the RGB image and append a legend.

    Args:
        rgb_image: Float image of shape ``(H, W, 3)`` with values in [0, 1].
        pred_mask: Array of class indices (0-3). Singleton dimensions are
            squeezed away; a mask whose size differs from the image is
            resized with nearest-neighbour interpolation.

    Returns:
        ``uint8`` array containing the blended image with a white legend
        strip stacked underneath.
    """
    # Ensure pred_mask has the right dimensions
    if pred_mask.ndim > 2:
        pred_mask = np.squeeze(pred_mask)

    print(f"RGB image shape: {rgb_image.shape}, Pred mask shape: {pred_mask.shape}")

    # Ensure mask has the same spatial dimensions as the image
    if pred_mask.shape != rgb_image.shape[:2]:
        pred_mask = cv2.resize(
            pred_mask.astype(np.float32),
            (rgb_image.shape[1], rgb_image.shape[0]),  # cv2 expects (width, height)
            interpolation=cv2.INTER_NEAREST,
        ).astype(np.uint8)
        print(f"Resized mask shape: {pred_mask.shape}")

    # Create a color-coded mask
    mask_vis = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
    for class_idx, color in _CLASS_COLORS.items():
        mask_vis[pred_mask == class_idx] = color

    # Create a blended visualization
    alpha = 0.5
    blended = cv2.addWeighted((rgb_image * 255).astype(np.uint8), 1 - alpha, mask_vis, alpha, 0)

    # Legend strip as wide as the image: 20 px per entry plus 20 px padding.
    legend_height = 20 * len(_CLASS_LABELS) + 20
    legend = np.full((legend_height, blended.shape[1], 3), 255, dtype=np.uint8)

    for i, text in enumerate(_CLASS_LABELS):
        top = 10 + i * 20
        cv2.rectangle(legend, (10, top), (30, top + 20), _CLASS_COLORS[i], -1)
        cv2.putText(legend, text, (40, top + 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

    # Combine image and legend
    return np.vstack([blended, legend])


def process_satellite_images(red_file, green_file, blue_file, nir_file, batch_size, patch_size, patch_overlap):
    """Run cloud detection on the uploaded bands and summarise the result.

    Args:
        red_file: Red band raster file.
        green_file: Green band raster file.
        blue_file: Blue band raster file.
        nir_file: NIR band raster file.
        batch_size: Batch size forwarded to ``predict_from_array``.
        patch_size: Patch size forwarded to ``predict_from_array``.
        patch_overlap: Patch overlap forwarded to ``predict_from_array``.

    Returns:
        Tuple ``(rgb_display, visualization, stats)``: the ``uint8`` RGB
        preview, the mask overlay with legend, and a text summary of the
        class distribution. When a file is missing the two images are
        ``None`` and the text asks the user to upload all files.
    """
    if not all([red_file, green_file, blue_file, nir_file]):
        return None, None, "Please upload all four channel files (Red, Green, Blue, NIR)"

    # Prepare input array and RGB image for visualization
    input_array, rgb_image = prepare_input_array(red_file, green_file, blue_file, nir_file)

    # Convert RGB image to format suitable for display
    rgb_display = (rgb_image * 255).astype(np.uint8)

    # Predict cloud mask using omnicloudmask. Slider values are cast to int
    # in case the UI delivers them as floats.
    pred_mask = predict_from_array(
        input_array,
        batch_size=int(batch_size),
        patch_size=int(patch_size),
        patch_overlap=int(patch_overlap),
    )

    # Calculate class distribution
    flat_mask = np.squeeze(pred_mask) if pred_mask.ndim > 2 else pred_mask

    clear_pixels = np.count_nonzero(flat_mask == 0)
    thick_cloud_pixels = np.count_nonzero(flat_mask == 1)
    thin_cloud_pixels = np.count_nonzero(flat_mask == 2)
    cloud_shadow_pixels = np.count_nonzero(flat_mask == 3)
    total_pixels = flat_mask.size

    def percent(count):
        # Guard against an empty mask so the summary never divides by zero.
        return count / total_pixels * 100 if total_pixels else 0.0

    stats = f"""
    Cloud Mask Statistics:
    - Clear: {clear_pixels} pixels ({percent(clear_pixels):.2f}%)
    - Thick Cloud: {thick_cloud_pixels} pixels ({percent(thick_cloud_pixels):.2f}%)
    - Thin Cloud: {thin_cloud_pixels} pixels ({percent(thin_cloud_pixels):.2f}%)
    - Cloud Shadow: {cloud_shadow_pixels} pixels ({percent(cloud_shadow_pixels):.2f}%)
    - Total Cloud Cover: {percent(thick_cloud_pixels + thin_cloud_pixels):.2f}%
    """

    # Visualize the cloud mask on the original image
    visualization = visualize_cloud_mask(rgb_image, flat_mask)

    return rgb_display, visualization, stats


# Create Gradio interface with default examples
demo = gr.Interface(
    fn=process_satellite_images,
    inputs=[
        gr.File(label="Red Channel (JP2)"),
        gr.File(label="Green Channel (JP2)"),
        gr.File(label="Blue Channel (JP2)"),
        gr.File(label="NIR Channel (JP2)"),
        gr.Slider(minimum=1, maximum=32, value=1, step=1, label="Batch Size",
                  info="Higher values use more memory but process faster"),
        gr.Slider(minimum=500, maximum=2000, value=1000, step=100, label="Patch Size",
                  info="Size of image patches for processing"),
        gr.Slider(minimum=100, maximum=500, value=300, step=50, label="Patch Overlap",
                  info="Overlap between patches to avoid edge artifacts"),
    ],
    outputs=[
        gr.Image(label="Original RGB Image"),
        gr.Image(label="Cloud Detection Visualization"),
        gr.Textbox(label="Statistics"),
    ],
    title="Satellite Cloud Detection",
    description="""
    Upload separate JP2 files for Red, Green, Blue, and NIR channels to detect clouds in satellite imagery.

    This application uses the OmniCloudMask model to classify each pixel as:
    - Clear (0)
    - Thick Cloud (1)
    - Thin Cloud (2)
    - Cloud Shadow (3)

    The model works best with imagery at 10-50m resolution. For higher resolution imagery, downsampling is recommended.
    """,
    examples=[
        ["jp2s/B04.jp2", "jp2s/B03.jp2", "jp2s/B02.jp2", "jp2s/B8A.jp2", 1, 1000, 300]
    ],
)

# Launch the app
demo.launch(debug=True)