# cloud-detection / run.py
# truthdotphd's picture
# Rename model.py to run.py
# 5f5f88f verified
"""
Cloud Mask Prediction and Visualization Module
This script processes Sentinel-2 satellite imagery bands to predict cloud masks
using the omnicloudmask library. It reads blue, red, green, and near-infrared bands,
resamples them as needed, creates a stacked array for prediction, and visualizes
the cloud mask overlaid on the original RGB image.
"""
import os

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from matplotlib.colors import ListedColormap
from omnicloudmask import predict_from_array
from rasterio.enums import Resampling
def load_band(file_path, resample=False, target_height=None, target_width=None):
    """
    Load the first band of a raster file as float32, optionally resampled.

    Only band 1 is requested from the dataset, so a multi-band file is not
    decoded (or resampled) in full just to throw the extra bands away.

    Args:
        file_path (str): Path to the raster file.
        resample (bool): Whether to resample the band while reading.
        target_height (int, optional): Output height in pixels when resampling.
        target_width (int, optional): Output width in pixels when resampling.

    Returns:
        numpy.ndarray: Band 1 as a float32 array.

    Note:
        Resampling (bilinear) is applied only when ``resample`` is true AND
        both target dimensions are given; otherwise the band is returned at
        its native size.
    """
    do_resample = (
        resample and target_height is not None and target_width is not None
    )

    with rasterio.open(file_path) as src:
        if do_resample:
            # Let the reader resample directly to the requested grid.
            band_data = src.read(
                1,
                out_shape=(target_height, target_width),
                resampling=Resampling.bilinear,
            )
        else:
            band_data = src.read(1)

    # Convert after reading so the interpolated values match the source dtype
    # path used previously (read in native dtype, then cast).
    return band_data.astype(np.float32)
def prepare_input_array(base_path="jp2s/", scale_factor=10000.0):
    """
    Prepare a stacked array of satellite bands for cloud mask prediction.

    Loads the blue (B02), green (B03), red (B04) and near-infrared (B8A)
    band files found in ``base_path``. The NIR band is resampled to the
    pixel dimensions of the red band so all bands share one grid.

    Args:
        base_path (str): Directory containing the JP2 band files. A trailing
            path separator is optional.
        scale_factor (float): Divisor used to bring the RGB bands into the
            0-1 range for display. Must be positive.

    Returns:
        tuple: (prediction_array, rgb_image)
            - prediction_array: numpy.ndarray of red, green and NIR bands
              stacked in CHW (channel, height, width) order for prediction
            - rgb_image: numpy.ndarray of red, green and blue bands in HWC
              order, scaled by ``scale_factor`` and clipped to 0-1

    Raises:
        ValueError: If ``scale_factor`` is not greater than zero.
    """
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor!r}")

    # File name of each band inside base_path.
    band_files = {
        'blue': "B02.jp2",   # Blue band (10m)
        'green': "B03.jp2",  # Green band (10m)
        'red': "B04.jp2",    # Red band (10m)
        'nir': "B8A.jp2",    # Near-infrared band (20m)
    }
    # os.path.join copes with base_path given with or without a trailing
    # separator; plain string concatenation produced "jp2sB02.jp2" for "jp2s".
    band_paths = {
        name: os.path.join(base_path, file_name)
        for name, file_name in band_files.items()
    }

    # Load red first: its shape is the reference grid for resampling NIR, so
    # the file does not have to be opened a second time just for its size.
    red_data = load_band(band_paths['red'])
    target_height, target_width = red_data.shape

    blue_data = load_band(band_paths['blue'])
    green_data = load_band(band_paths['green'])
    nir_data = load_band(
        band_paths['nir'],
        resample=True,
        target_height=target_height,
        target_width=target_width,
    )

    # Print band shapes for debugging
    print(f"Band shapes - Blue: {blue_data.shape}, Green: {green_data.shape}, Red: {red_data.shape}, NIR: {nir_data.shape}")

    # RGB composite for visualization, scaled and clipped to the 0-1 range.
    rgb_image = np.stack([red_data, green_data, blue_data], axis=-1) / scale_factor
    rgb_image = np.clip(rgb_image, 0, 1)

    # Unscaled bands in CHW order (red, green, nir) for cloud mask prediction.
    prediction_array = np.stack([red_data, green_data, nir_data], axis=0)

    return prediction_array, rgb_image
def visualize_cloud_mask(rgb_image, cloud_mask, output_path="cloud_mask_visualization.png"):
    """
    Visualize the cloud mask next to and overlaid on the original RGB image.

    Draws three panels side by side: the RGB image, the colour-coded cloud
    mask, and the RGB image with a semi-transparent mask overlay. The figure
    is saved to ``output_path``, shown, and then closed.

    Args:
        rgb_image (numpy.ndarray): RGB image array (HWC format)
        cloud_mask (numpy.ndarray): Predicted cloud mask with class values
            0=Clear, 1=Thick Cloud, 2=Thin Cloud, 3=Cloud Shadow. Arrays with
            more than two dimensions are squeezed first.
        output_path (str): Path to save the visualization
    """
    # Single source of truth per class, indexed by class value:
    # (legend label, colormap/legend colour, RGBA colour used in the overlay).
    class_styles = (
        ("Clear", "green", (0, 1, 0, 0.3)),         # low opacity
        ("Thick Cloud", "red", (1, 0, 0, 0.5)),
        ("Thin Cloud", "yellow", (1, 1, 0, 0.5)),
        ("Cloud Shadow", "blue", (0, 0, 1, 0.5)),
    )

    # Fix the cloud mask shape if it has an extra dimension
    if cloud_mask.ndim > 2:
        print(f"Original cloud mask shape: {cloud_mask.shape}")
        cloud_mask = np.squeeze(cloud_mask)
        print(f"Squeezed cloud mask shape: {cloud_mask.shape}")

    legend_patches = [
        mpatches.Patch(color=color, label=label)
        for label, color, _ in class_styles
    ]
    cloud_cmap = ListedColormap([color for _, color, _ in class_styles])

    # Semi-transparent overlay; pixels whose value matches no class stay
    # fully transparent (all zeros). float32 halves the memory of this
    # full-resolution 4-channel array compared with the float64 default.
    cloud_mask_rgba = np.zeros((*cloud_mask.shape, 4), dtype=np.float32)
    for class_value, (_, _, rgba) in enumerate(class_styles):
        cloud_mask_rgba[cloud_mask == class_value] = rgba

    # Create figure with three subplots
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
    try:
        # Panel 1: original RGB image
        ax1.imshow(rgb_image)
        ax1.set_title("Original RGB Image")
        ax1.axis('off')

        # Panel 2: colour-coded cloud mask
        ax2.imshow(cloud_mask, cmap=cloud_cmap, vmin=0, vmax=len(class_styles) - 1)
        ax2.set_title("Cloud Mask")
        ax2.axis('off')
        ax2.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left')

        # Panel 3: RGB image with the mask overlay on top
        ax3.imshow(rgb_image)
        ax3.imshow(cloud_mask_rgba)
        ax3.set_title("RGB with Cloud Mask Overlay")
        ax3.axis('off')
        ax3.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left')

        # Adjust layout and save
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.show()
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"Visualization saved to {output_path}")
def main():
    """
    Run the full workflow: load bands, predict the cloud mask, report, plot.
    """
    # Band stack for the model plus the RGB composite used for plotting.
    bands, rgb = prepare_input_array()

    # Cloud mask prediction via omnicloudmask.
    mask = predict_from_array(bands)

    # Summary of the raw prediction.
    print("Cloud mask prediction results:")
    print(f"Cloud mask shape: {mask.shape}")
    print(f"Unique classes in mask: {np.unique(mask)}")

    # Per-class pixel counts, computed on the mask with any extra
    # dimensions squeezed away.
    counted = np.squeeze(mask) if mask.ndim > 2 else mask
    clear, thick, thin, shadow = (
        np.sum(counted == class_value) for class_value in range(4)
    )
    print(f"Class distribution: Clear: {clear}, Thick Cloud: {thick}, "
          f"Thin Cloud: {thin}, Cloud Shadow: {shadow}")

    # Plot the mask alongside / on top of the RGB image.
    visualize_cloud_mask(rgb, mask)


# Run the workflow only when executed as a script, not when imported.
if __name__ == "__main__":
    main()