cloud-detection / app.py
truthdotphd's picture
Update app.py
0cc73e9 verified
raw
history blame
11.3 kB
import gradio as gr
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from huggingface_hub import snapshot_download
import rasterio
from rasterio.enums import Resampling
from rasterio.plot import reshape_as_image
import sys
# Fetch a full snapshot of the model repository into the working directory
repo_id = "truthdotphd/cloud-detection"
repo_subdir = "."
repo_dir = snapshot_download(repo_id=repo_id, local_dir=repo_subdir)

# Make the downloaded sources importable
sys.path.append(repo_dir)

# Pull in the prediction entry point; if the plain import fails, retry with
# the "omnicloudmask" sub-directory of the snapshot added to the search path.
try:
    from omnicloudmask import predict_from_array
except ImportError:
    omnicloudmask_dir = os.path.join(repo_dir, "omnicloudmask")
    if not os.path.exists(omnicloudmask_dir):
        raise ImportError("Could not find the omnicloudmask module in the downloaded repository")
    sys.path.append(omnicloudmask_dir)
    from omnicloudmask import predict_from_array
def visualize_rgb(red_file, green_file, blue_file, nir_file):
    """
    Build an 8-bit RGB preview from the uploaded band files.

    Each band is divided by its own maximum, the result is clipped to
    [0, 1], and a single 2nd-98th percentile stretch (computed over all
    three channels together) is applied before converting to uint8.

    Args:
        red_file: Path of the red band raster.
        green_file: Path of the green band raster.
        blue_file: Path of the blue band raster.
        nir_file: Path of the NIR band raster. It is not read here; it only
            takes part in the "all four inputs provided" check.

    Returns:
        A (height, width, 3) uint8 array, or None when any input is missing
        or the preview could not be generated (the error is printed).
    """
    if not all([red_file, green_file, blue_file, nir_file]):
        return None
    try:
        # Channel order of the output image is red, green, blue
        bands = [load_band(path) for path in (red_file, green_file, blue_file)]
        height, width = bands[0].shape[:2]

        rgb_image = np.zeros((height, width, 3), dtype=np.float32)
        # Small epsilon keeps the per-channel division defined for all-zero bands
        epsilon = 1e-10
        for channel, band in enumerate(bands):
            rgb_image[:, :, channel] = band / (np.max(band) + epsilon)
        rgb_image = np.clip(rgb_image, 0, 1)

        # Contrast stretch between the 2nd and 98th percentiles
        p2 = np.percentile(rgb_image, 2)
        p98 = np.percentile(rgb_image, 98)
        contrast_range = p98 - p2
        if contrast_range > 0:
            rgb_image = np.clip((rgb_image - p2) / contrast_range, 0, 1)
        # Otherwise the image is flat: stretching would divide by zero and
        # fill the preview with NaNs, so the normalized values are kept as-is.

        return (rgb_image * 255).astype(np.uint8)
    except Exception as e:
        # Best-effort preview: report the problem and show nothing
        print(f"Error generating RGB preview: {e}")
        return None
def visualize_jp2(file_path):
    """
    Render the first band of a raster file as a viridis-colored image.

    The band is min-max normalized to [0, 1] and mapped through matplotlib's
    'viridis' colormap; the alpha channel is dropped.

    Args:
        file_path: Path of the raster file to open with rasterio.

    Returns:
        A (height, width, 3) uint8 RGB array.
    """
    with rasterio.open(file_path) as src:
        # Cast to float up front so the normalization below is done in
        # floating point regardless of the stored pixel type
        data = src.read(1).astype(np.float32)

    data_min = np.min(data)
    value_range = np.max(data) - data_min
    if value_range > 0:
        data = (data - data_min) / value_range
    else:
        # Constant band: min-max scaling would be 0/0, so map it to the
        # bottom of the colormap instead of producing NaNs
        data = np.zeros_like(data)

    # Apply a colormap for better visualization
    cmap = plt.get_cmap('viridis')
    colored_image = cmap(data)
    # Keep RGB only and convert to 8-bit for display
    return (colored_image[:, :, :3] * 255).astype(np.uint8)
def load_band(file_path, resample=False, target_height=None, target_width=None):
    """
    Load the first band of a raster file as a float32 array.

    Args:
        file_path: Path of the raster file to open with rasterio.
        resample: When True (and both target dimensions are given), the band
            is read at the requested size using bilinear resampling.
        target_height: Output height in pixels used when resampling.
        target_width: Output width in pixels used when resampling.

    Returns:
        A 2-D float32 array holding the first band. If resampling is
        requested but either target dimension is None, the band is returned
        at its native size.
    """
    with rasterio.open(file_path) as src:
        # Only band 1 is requested, so any further bands in the file are
        # neither decoded nor resampled.
        if resample and target_height is not None and target_width is not None:
            band_data = src.read(
                1,
                out_shape=(target_height, target_width),
                resampling=Resampling.bilinear
            )
        else:
            band_data = src.read(1)
    return band_data.astype(np.float32)
def prepare_input_array(red_file, green_file, blue_file, nir_file):
    """
    Prepare a stacked array of satellite bands for cloud mask prediction.

    Args:
        red_file: Path of the red band raster.
        green_file: Path of the green band raster.
        blue_file: Path of the blue band raster (used for the RGB image only).
        nir_file: Path of the NIR band raster; it is resampled to the red
            band's height and width.

    Returns:
        A tuple (prediction_array, rgb_image_enhanced):
        - prediction_array: the red, green and NIR bands stacked along
          axis 0 (channels first).
        - rgb_image_enhanced: a (height, width, 3) image with values in
          [0, 1], normalized per channel and contrast stretched.
    """
    # Load the visible bands at their native size
    blue_data = load_band(blue_file)
    green_data = load_band(green_file)
    red_data = load_band(red_file)

    # The red band defines the target grid; its array shape already carries
    # the dimensions, so the file does not need to be opened a second time.
    target_height, target_width = red_data.shape[:2]
    nir_data = load_band(
        nir_file,
        resample=True,
        target_height=target_height,
        target_width=target_width
    )

    # Print band shapes for debugging
    print(f"Band shapes - Blue: {blue_data.shape}, Green: {green_data.shape}, Red: {red_data.shape}, NIR: {nir_data.shape}")

    # Per-channel maxima drive the dynamic normalization
    red_max = np.max(red_data)
    green_max = np.max(green_data)
    blue_max = np.max(blue_data)
    print(f"Max values - Red: {red_max}, Green: {green_max}, Blue: {blue_max}")

    rgb_image = np.zeros((target_height, target_width, 3), dtype=np.float32)
    # Small epsilon avoids division by zero for all-zero bands
    epsilon = 1e-10
    rgb_image[:, :, 0] = red_data / (red_max + epsilon)
    rgb_image[:, :, 1] = green_data / (green_max + epsilon)
    rgb_image[:, :, 2] = blue_data / (blue_max + epsilon)
    rgb_image = np.clip(rgb_image, 0, 1)

    # Contrast stretch between the 2nd and 98th percentiles
    p2 = np.percentile(rgb_image, 2)
    p98 = np.percentile(rgb_image, 98)
    contrast_range = p98 - p2
    if contrast_range > 0:
        rgb_image_enhanced = np.clip((rgb_image - p2) / contrast_range, 0, 1)
    else:
        # Flat image: stretching would divide by zero and yield NaNs, so
        # the already-normalized image is returned unchanged.
        rgb_image_enhanced = rgb_image

    # Stack bands in CHW format for cloud mask prediction (red, green, nir)
    prediction_array = np.stack([red_data, green_data, nir_data], axis=0)
    return prediction_array, rgb_image_enhanced
def visualize_cloud_mask(rgb_image, pred_mask):
    """
    Blend a color-coded class mask over the RGB image and append a legend.

    Args:
        rgb_image: Image array that is scaled by 255 and cast to uint8
            before blending.
        pred_mask: Class mask using 0 = clear, 1 = thick cloud,
            2 = thin cloud, 3 = cloud shadow. Extra singleton dimensions are
            squeezed away, and the mask is resized (nearest neighbour) when
            its shape differs from the image's.

    Returns:
        A uint8 image: the 50/50 blend with a 100-pixel-high legend strip
        stacked underneath.
    """
    mask = np.squeeze(pred_mask) if pred_mask.ndim > 2 else pred_mask
    print(f"RGB image shape: {rgb_image.shape}, Pred mask shape: {mask.shape}")

    # Bring the mask onto the image grid if the two disagree
    image_shape = rgb_image.shape[:2]
    if mask.shape != image_shape:
        as_float = mask.astype(np.float32)
        # cv2 expects the target size as (width, height)
        resized = cv2.resize(as_float, image_shape[::-1], interpolation=cv2.INTER_NEAREST)
        mask = resized.astype(np.uint8)
        print(f"Resized mask shape: {mask.shape}")

    # Class index -> (legend label, color); position in the tuple is the index
    class_table = (
        ("Clear", [0, 255, 0]),               # green
        ("Thick Cloud", [255, 255, 255]),     # white
        ("Thin Cloud", [200, 200, 200]),      # light gray
        ("Cloud Shadow", [100, 100, 100]),    # dark gray
    )

    # Paint every known class; pixels with any other value stay black
    overlay = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    for class_idx, (_, color) in enumerate(class_table):
        overlay[mask == class_idx] = color

    # Equal-weight blend of image and overlay
    alpha = 0.5
    base = (rgb_image * 255).astype(np.uint8)
    blended = cv2.addWeighted(base, 1 - alpha, overlay, alpha, 0)

    # White legend strip as wide as the image, one 20-pixel row per class
    legend = np.full((100, blended.shape[1], 3), 255, dtype=np.uint8)
    for row, (label, color) in enumerate(class_table):
        top = 10 + row * 20
        cv2.rectangle(legend, (10, top), (30, top + 20), color, -1)
        cv2.putText(legend, label, (40, top + 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

    return np.vstack([blended, legend])
def process_satellite_images(red_file, green_file, blue_file, nir_file, batch_size, patch_size, patch_overlap):
    """
    Run cloud detection on four band files and summarize the result.

    Args:
        red_file, green_file, blue_file, nir_file: Paths of the band rasters.
        batch_size, patch_size, patch_overlap: Forwarded unchanged to
            predict_from_array.

    Returns:
        A tuple (rgb_display, visualization, stats): the uint8 RGB image,
        the mask overlay with legend, and a text summary of per-class pixel
        counts. When any file is missing the tuple is
        (None, None, <message asking for all four files>).
    """
    uploads = (red_file, green_file, blue_file, nir_file)
    if not all(uploads):
        return None, None, "Please upload all four channel files (Red, Green, Blue, NIR)"

    # Model input plus the enhanced RGB image used for display
    input_array, rgb_image = prepare_input_array(*uploads)
    rgb_display = (rgb_image * 255).astype(np.uint8)

    # Predict cloud mask using omnicloudmask
    pred_mask = predict_from_array(
        input_array,
        batch_size=batch_size,
        patch_size=patch_size,
        patch_overlap=patch_overlap
    )

    # Drop singleton dimensions so the mask is a plain 2-D class map
    flat_mask = np.squeeze(pred_mask) if pred_mask.ndim > 2 else pred_mask

    # Pixel counts for classes 0..3, in class order
    n_clear, n_thick, n_thin, n_shadow = [np.sum(flat_mask == label) for label in range(4)]
    total_pixels = flat_mask.size

    def share(count):
        # Percentage of the whole image covered by `count` pixels
        return count / total_pixels * 100

    stats = f"""
Cloud Mask Statistics:
- Clear: {n_clear} pixels ({share(n_clear):.2f}%)
- Thick Cloud: {n_thick} pixels ({share(n_thick):.2f}%)
- Thin Cloud: {n_thin} pixels ({share(n_thin):.2f}%)
- Cloud Shadow: {n_shadow} pixels ({share(n_shadow):.2f}%)
- Total Cloud Cover: {share(n_thick + n_thin):.2f}%
"""

    # Overlay the mask on the original image
    visualization = visualize_cloud_mask(rgb_image, flat_mask)
    return rgb_display, visualization, stats
# Build the Gradio UI: four band uploads plus three tuning sliders in,
# two images and a statistics text box out.
band_inputs = [
    gr.File(label="Red Channel (JP2)"),
    gr.File(label="Green Channel (JP2)"),
    gr.File(label="Blue Channel (JP2)"),
    gr.File(label="NIR Channel (JP2)"),
]
tuning_inputs = [
    gr.Slider(minimum=1, maximum=32, value=1, step=1, label="Batch Size", info="Higher values use more memory but process faster"),
    gr.Slider(minimum=500, maximum=2000, value=1000, step=100, label="Patch Size", info="Size of image patches for processing"),
    gr.Slider(minimum=100, maximum=500, value=300, step=50, label="Patch Overlap", info="Overlap between patches to avoid edge artifacts"),
]
result_outputs = [
    gr.Image(label="Original RGB Image"),
    gr.Image(label="Cloud Detection Visualization"),
    gr.Textbox(label="Statistics"),
]
app_description = """
Upload separate JP2 files for Red, Green, Blue, and NIR channels to detect clouds in satellite imagery.
This application uses the OmniCloudMask model to classify each pixel as:
- Clear (0)
- Thick Cloud (1)
- Thin Cloud (2)
- Cloud Shadow (3)
The model works best with imagery at 10-50m resolution. For higher resolution imagery, downsampling is recommended.
"""
# One ready-made example row: red, green, blue, NIR files, then slider values
default_example = ["jp2s/B04.jp2", "jp2s/B03.jp2", "jp2s/B02.jp2", "jp2s/B8A.jp2", 1, 1000, 300]

demo = gr.Interface(
    fn=process_satellite_images,
    inputs=band_inputs + tuning_inputs,
    outputs=result_outputs,
    title="Satellite Cloud Detection",
    description=app_description,
    examples=[default_example],
)

# Start the app
demo.launch(debug=True)