"""
Cloud Mask Prediction and Visualization Module

This script processes Sentinel-2 satellite imagery bands to predict cloud masks
using the omnicloudmask library. It reads the blue, green, red and
near-infrared bands, resamples the NIR band onto the red band's grid, stacks
red/green/NIR for prediction, and visualizes the cloud mask overlaid on the
original RGB image.
"""

import os

import rasterio
import numpy as np
from rasterio.enums import Resampling
from omnicloudmask import predict_from_array
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches

# Class index -> display label, as assumed throughout this script
# (0=Clear, 1=Thick Cloud, 2=Thin Cloud, 3=Cloud Shadow).
_CLASS_LABELS = ("Clear", "Thick Cloud", "Thin Cloud", "Cloud Shadow")
# Color used for each class in the categorical mask plot and its legend.
_CLASS_COLORS = ("green", "red", "yellow", "blue")
# RGBA overlay color per class; clear pixels get a fainter tint so the
# underlying image stays visible.
_OVERLAY_RGBA = (
    (0, 1, 0, 0.3),  # Clear - green with low opacity
    (1, 0, 0, 0.5),  # Thick Cloud - red
    (1, 1, 0, 0.5),  # Thin Cloud - yellow
    (0, 0, 1, 0.5),  # Cloud Shadow - blue
)


def _to_2d_mask(mask):
    """
    Drop leading singleton axes from a mask, e.g. (1, H, W) -> (H, W).

    Unlike ``np.squeeze``, this never removes the height or width axes, so a
    mask that is only one pixel tall or wide keeps its 2-D shape.

    Args:
        mask (numpy.ndarray): Mask array with two or more dimensions

    Returns:
        numpy.ndarray: The mask with leading length-1 axes removed
    """
    while mask.ndim > 2 and mask.shape[0] == 1:
        mask = mask[0]
    return mask


def load_band(file_path, resample=False, target_height=None, target_width=None):
    """
    Load the first band of a raster file, optionally resampled.

    Only band 1 is read. When ``resample`` is True and both target dimensions
    are given, it is read directly onto a (target_height, target_width) grid
    using bilinear resampling. Otherwise it is read at native resolution.

    Args:
        file_path (str): Path to the raster file
        resample (bool): Whether to resample the band
        target_height (int, optional): Target height for resampling
        target_width (int, optional): Target width for resampling

    Returns:
        numpy.ndarray: 2-D band data as float32 array
    """
    with rasterio.open(file_path) as src:
        if resample and target_height is not None and target_width is not None:
            # Read only band 1; reading all bands and keeping [0] wastes I/O.
            band_data = src.read(
                1,
                out_shape=(target_height, target_width),
                resampling=Resampling.bilinear,
            )
        else:
            band_data = src.read(1)
    return band_data.astype(np.float32)


def prepare_input_array(base_path="jp2s/"):
    """
    Prepare a stacked array of satellite bands for cloud mask prediction.

    This function loads the blue, green, red and near-infrared bands from
    Sentinel-2 imagery. It resamples the NIR band (B8A) onto the red band's
    grid and stacks red, green and NIR in CHW (channel, height, width) format
    for cloud mask prediction.

    Args:
        base_path (str): Base directory containing the JP2 band files. It is
            joined with os.path.join, so a trailing separator is optional.

    Returns:
        tuple: (prediction_array, rgb_image)
            - prediction_array: numpy.ndarray with red, green, NIR stacked in
              CHW format for prediction
            - rgb_image: numpy.ndarray (H, W, 3) RGB image in [0, 1] for
              visualization

    Raises:
        ValueError: If the blue or green band grid does not match the red band.
    """
    # Define paths to band files
    band_paths = {
        'blue': os.path.join(base_path, "B02.jp2"),   # Blue band (10m)
        'green': os.path.join(base_path, "B03.jp2"),  # Green band (10m)
        'red': os.path.join(base_path, "B04.jp2"),    # Red band (10m)
        'nir': os.path.join(base_path, "B8A.jp2"),    # Near-infrared band (20m)
    }

    # Load the 10m bands; the red band's grid is the resampling reference.
    blue_data = load_band(band_paths['blue'])
    green_data = load_band(band_paths['green'])
    red_data = load_band(band_paths['red'])
    target_height, target_width = red_data.shape

    for name, data in (('blue', blue_data), ('green', green_data)):
        if data.shape != red_data.shape:
            raise ValueError(
                f"{name} band shape {data.shape} does not match red band "
                f"shape {red_data.shape}"
            )

    # Resample NIR onto the red band's grid
    nir_data = load_band(
        band_paths['nir'],
        resample=True,
        target_height=target_height,
        target_width=target_width
    )

    # Print band shapes for debugging
    print(f"Band shapes - Blue: {blue_data.shape}, Green: {green_data.shape}, Red: {red_data.shape}, NIR: {nir_data.shape}")

    # Create RGB image for visualization (scale to 0-1 range).
    # Adjust the scaling factor to the data's reflectance encoding
    # (10000 is the usual Sentinel-2 quantification value).
    scale_factor = 10000.0
    rgb_image = np.stack([
        red_data / scale_factor,
        green_data / scale_factor,
        blue_data / scale_factor
    ], axis=-1)

    # Clip values to 0-1 range
    rgb_image = np.clip(rgb_image, 0, 1)

    # Stack bands in CHW format for cloud mask prediction (red, green, nir)
    prediction_array = np.stack([red_data, green_data, nir_data], axis=0)

    return prediction_array, rgb_image


def visualize_cloud_mask(rgb_image, cloud_mask, output_path="cloud_mask_visualization.png"):
    """
    Visualize the cloud mask next to, and overlaid on, the original RGB image.

    Draws three panels (RGB image, categorical cloud mask, RGB with a
    semi-transparent mask overlay), saves the figure to ``output_path`` at
    300 dpi and then shows it.

    Args:
        rgb_image (numpy.ndarray): RGB image array (HWC format)
        cloud_mask (numpy.ndarray): Predicted cloud mask; leading singleton
            axes (e.g. shape (1, H, W)) are removed before plotting
        output_path (str): Path to save the visualization
    """
    # Drop a leading channel axis such as (1, H, W) without collapsing H or W
    if cloud_mask.ndim > 2:
        print(f"Original cloud mask shape: {cloud_mask.shape}")
        cloud_mask = _to_2d_mask(cloud_mask)
        print(f"Squeezed cloud mask shape: {cloud_mask.shape}")

    # Create figure with three subplots
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))

    # Plot original RGB image
    ax1.imshow(rgb_image)
    ax1.set_title("Original RGB Image")
    ax1.axis('off')

    # One discrete color per class; vmin/vmax pin class i to color i
    cloud_cmap = ListedColormap(list(_CLASS_COLORS))
    ax2.imshow(cloud_mask, cmap=cloud_cmap, vmin=0, vmax=len(_CLASS_COLORS) - 1)
    ax2.set_title("Cloud Mask")
    ax2.axis('off')

    # Legend patches shared by the mask and overlay panels
    legend_patches = [
        mpatches.Patch(color=color, label=label)
        for color, label in zip(_CLASS_COLORS, _CLASS_LABELS)
    ]
    ax2.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left')

    # Plot RGB with semi-transparent cloud mask overlay
    ax3.imshow(rgb_image)

    # Per-pixel RGBA layer; pixels outside the known classes stay transparent
    cloud_mask_rgba = np.zeros((*cloud_mask.shape, 4))
    for class_value, rgba in enumerate(_OVERLAY_RGBA):
        cloud_mask_rgba[cloud_mask == class_value] = rgba

    ax3.imshow(cloud_mask_rgba)
    ax3.set_title("RGB with Cloud Mask Overlay")
    ax3.axis('off')
    ax3.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left')

    # Adjust layout and save
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Visualization saved to {output_path}")


def main():
    """
    Run the cloud mask prediction and visualization workflow.

    Loads the bands, predicts the mask with omnicloudmask, prints a summary
    (shape, unique classes, per-class pixel counts) and writes the figure.
    """
    # Create input array from satellite bands and get RGB image for visualization
    input_array, rgb_image = prepare_input_array()

    # Predict cloud mask using omnicloudmask
    pred_mask = predict_from_array(input_array)

    # Print prediction results and shape
    print("Cloud mask prediction results:")
    print(f"Cloud mask shape: {pred_mask.shape}")
    print(f"Unique classes in mask: {np.unique(pred_mask)}")

    # Per-class pixel counts; counting is shape-agnostic, so no squeeze needed
    distribution = ", ".join(
        f"{label}: {np.sum(pred_mask == class_value)}"
        for class_value, label in enumerate(_CLASS_LABELS)
    )
    print(f"Class distribution: {distribution}")

    # Visualize the cloud mask on the original image
    visualize_cloud_mask(rgb_image, pred_mask)


if __name__ == "__main__":
    main()