# File size: 7,503 Bytes (revision 551ee08)
"""
Cloud Mask Prediction and Visualization Module

This script predicts a cloud mask for Sentinel-2 satellite imagery using the
omnicloudmask library. It reads the blue, green, red, and near-infrared bands,
resamples the near-infrared band to the dimensions of the red band, stacks the
red, green, and near-infrared bands into the array used for prediction, and
visualizes the predicted cloud mask both on its own and overlaid on the RGB image.
"""
import os

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from matplotlib.colors import ListedColormap
from omnicloudmask import predict_from_array
from rasterio.enums import Resampling
def load_band(file_path, resample=False, target_height=None, target_width=None):
    """
    Load the first band of a raster file as a float32 array.

    Args:
        file_path (str): Path to the raster file
        resample (bool): Whether to resample the band while reading
        target_height (int, optional): Output height in pixels when resampling
        target_width (int, optional): Output width in pixels when resampling

    Returns:
        numpy.ndarray: 2-D float32 array holding band 1 of the file.
            Bilinear resampling is applied only when ``resample`` is true and
            both target dimensions are given; otherwise the band is read at
            its native size.
    """
    do_resample = (
        resample and target_height is not None and target_width is not None
    )
    with rasterio.open(file_path) as src:
        if do_resample:
            # Request band 1 only, so a multi-band file is not decoded and
            # resampled in full just to keep its first band.
            band_data = src.read(
                1,
                out_shape=(1, target_height, target_width),
                resampling=Resampling.bilinear,
            )
        else:
            band_data = src.read(1)
    return band_data.astype(np.float32)
def prepare_input_array(base_path="jp2s/", *, scale_factor=10000.0):
    """
    Prepare a stacked array of satellite bands for cloud mask prediction.

    Loads the blue (B02), green (B03), red (B04) and near-infrared (B8A) band
    files from ``base_path``. The near-infrared band is resampled to the
    dimensions of the red band so that all bands can be stacked.

    Args:
        base_path (str): Directory containing the JP2 band files. A trailing
            path separator is optional.
        scale_factor (float): Divisor used to bring the raw band values into
            the 0-1 range for the RGB preview. Must be positive.

    Returns:
        tuple: (stacked_array, rgb_image)
            - stacked_array: numpy.ndarray with the red, green and
              near-infrared bands stacked in CHW format for prediction
            - rgb_image: numpy.ndarray in HWC format (red, green, blue),
              scaled by ``scale_factor`` and clipped to 0-1, for visualization

    Raises:
        ValueError: If ``scale_factor`` is not positive.
    """
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor}")

    # os.path.join keeps the paths valid whether or not base_path ends with a
    # separator; plain string concatenation would glue "jp2s" and "B02.jp2".
    band_files = {
        'blue': "B02.jp2",   # Blue band (10m)
        'green': "B03.jp2",  # Green band (10m)
        'red': "B04.jp2",    # Red band (10m)
        'nir': "B8A.jp2",    # Near-infrared band (20m)
    }
    band_paths = {
        name: os.path.join(base_path, filename)
        for name, filename in band_files.items()
    }

    blue_data = load_band(band_paths['blue'])
    green_data = load_band(band_paths['green'])
    red_data = load_band(band_paths['red'])

    # The red band defines the target grid; taking its shape from the array
    # already in memory avoids opening the file a second time.
    target_height, target_width = red_data.shape
    nir_data = load_band(
        band_paths['nir'],
        resample=True,
        target_height=target_height,
        target_width=target_width,
    )

    # Print band shapes for debugging
    print(f"Band shapes - Blue: {blue_data.shape}, Green: {green_data.shape}, Red: {red_data.shape}, NIR: {nir_data.shape}")

    # RGB preview in HWC order, scaled and clipped to the 0-1 range
    rgb_image = np.stack([red_data, green_data, blue_data], axis=-1) / scale_factor
    rgb_image = np.clip(rgb_image, 0, 1)

    # Stack bands in CHW format for cloud mask prediction (red, green, nir)
    prediction_array = np.stack([red_data, green_data, nir_data], axis=0)
    return prediction_array, rgb_image
def visualize_cloud_mask(rgb_image, cloud_mask, output_path="cloud_mask_visualization.png"):
    """
    Visualize the cloud mask next to and overlaid on the original RGB image.

    Draws three panels side by side: the RGB image, the class-coloured cloud
    mask, and the RGB image with a semi-transparent mask overlay. The figure
    is saved to ``output_path``, shown, and then closed.

    Args:
        rgb_image (numpy.ndarray): RGB image array (HWC format)
        cloud_mask (numpy.ndarray): Predicted cloud mask; arrays with more
            than two dimensions are squeezed before plotting
        output_path (str): Path to save the visualization
    """
    # Single source of truth for the mask classes, so the colormap, legend
    # and overlay cannot drift apart:
    # (class value, legend label, colour name, overlay RGBA)
    class_styles = (
        (0, 'Clear', 'green', (0, 1, 0, 0.3)),          # low opacity
        (1, 'Thick Cloud', 'red', (1, 0, 0, 0.5)),
        (2, 'Thin Cloud', 'yellow', (1, 1, 0, 0.5)),
        (3, 'Cloud Shadow', 'blue', (0, 0, 1, 0.5)),
    )

    # Fix the cloud mask shape if it has an extra dimension
    if cloud_mask.ndim > 2:
        print(f"Original cloud mask shape: {cloud_mask.shape}")
        cloud_mask = np.squeeze(cloud_mask)
        print(f"Squeezed cloud mask shape: {cloud_mask.shape}")

    # Create figure with three subplots
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
    try:
        # Panel 1: original RGB image
        ax1.imshow(rgb_image)
        ax1.set_title("Original RGB Image")
        ax1.axis('off')

        # Panel 2: cloud mask with one colour per class
        cloud_cmap = ListedColormap([colour for _, _, colour, _ in class_styles])
        ax2.imshow(cloud_mask, cmap=cloud_cmap, vmin=0, vmax=len(class_styles) - 1)
        ax2.set_title("Cloud Mask")
        ax2.axis('off')

        legend_patches = [
            mpatches.Patch(color=colour, label=label)
            for _, label, colour, _ in class_styles
        ]
        ax2.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left')

        # Panel 3: RGB image with a semi-transparent mask overlay. Pixels whose
        # value matches no class keep alpha 0 and stay fully transparent.
        ax3.imshow(rgb_image)
        cloud_mask_rgba = np.zeros((*cloud_mask.shape, 4))
        for value, _, _, rgba in class_styles:
            cloud_mask_rgba[cloud_mask == value] = rgba
        ax3.imshow(cloud_mask_rgba)
        ax3.set_title("RGB with Cloud Mask Overlay")
        ax3.axis('off')
        ax3.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left')

        # Adjust layout and save
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.show()
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    print(f"Visualization saved to {output_path}")
def main():
    """
    Run the workflow end to end: load the bands, predict the cloud mask,
    report summary statistics, and render the visualization.
    """
    # Band stack for the model plus the RGB preview used for plotting
    model_input, preview_rgb = prepare_input_array()

    # Predict cloud mask using omnicloudmask
    pred_mask = predict_from_array(model_input)

    print("Cloud mask prediction results:")
    print(f"Cloud mask shape: {pred_mask.shape}")
    print(f"Unique classes in mask: {np.unique(pred_mask)}")

    # Drop extra dimensions (if any) before counting pixels per class
    flat_mask = np.squeeze(pred_mask) if pred_mask.ndim > 2 else pred_mask
    clear_px, thick_px, thin_px, shadow_px = (
        np.sum(flat_mask == class_value) for class_value in range(4)
    )
    print(f"Class distribution: Clear: {clear_px}, Thick Cloud: {thick_px}, "
          f"Thin Cloud: {thin_px}, Cloud Shadow: {shadow_px}")

    # The unsqueezed prediction is passed on; the plotting function squeezes it itself
    visualize_cloud_mask(preview_rgb, pred_mask)
# Run the workflow only when this file is executed as a script, so that
# importing it as a module has no side effects.
if __name__ == "__main__":
    main()
|