Spaces:
Running
Running
import psutil | |
import gradio as gr | |
import numpy as np | |
import os | |
import cv2 | |
import matplotlib.pyplot as plt | |
from huggingface_hub import snapshot_download | |
import rasterio | |
from rasterio.enums import Resampling | |
from rasterio.plot import reshape_as_image | |
import sys | |
# Fetch a snapshot of the model repository into the current working directory.
repo_id = "truthdotphd/cloud-detection"
repo_subdir = "."
repo_dir = snapshot_download(repo_id=repo_id, local_dir=repo_subdir)

# Make modules shipped with the snapshot importable.
sys.path.append(repo_dir)

# Prefer a direct import; otherwise retry with the nested "omnicloudmask"
# folder of the snapshot on the path, and fail loudly if it is not there.
try:
    from omnicloudmask import predict_from_array
except ImportError:
    omnicloudmask_dir = os.path.join(repo_dir, "omnicloudmask")
    if not os.path.exists(omnicloudmask_dir):
        raise ImportError("Could not find the omnicloudmask module in the downloaded repository")
    sys.path.append(omnicloudmask_dir)
    from omnicloudmask import predict_from_array
def visualize_rgb(red_file, green_file, blue_file, nir_file):
    """
    Build an 8-bit RGB preview from the uploaded red, green and blue bands.

    A preview is only produced once all four uploads are present; the NIR
    file is checked for presence but is not read here.

    Args:
        red_file: Path of the red band raster.
        green_file: Path of the green band raster.
        blue_file: Path of the blue band raster.
        nir_file: Path of the NIR band raster (presence check only).

    Returns:
        A uint8 array with the three bands stacked on the last axis, or
        None when a file is missing or the preview could not be generated.
    """
    if not all([red_file, green_file, blue_file, nir_file]):
        return None
    try:
        bands = [load_band(path) for path in (red_file, green_file, blue_file)]

        # Scale every channel by its own maximum; epsilon keeps an all-zero
        # band from dividing by zero.
        epsilon = 1e-10
        rgb_image = np.stack(
            [band / (np.max(band) + epsilon) for band in bands],
            axis=-1,
        )
        rgb_image = np.clip(rgb_image, 0, 1)

        # 2-98 percentile contrast stretch. A flat image has p2 == p98, and
        # stretching it would divide by zero, so it is left as-is.
        p2 = np.percentile(rgb_image, 2)
        p98 = np.percentile(rgb_image, 98)
        spread = p98 - p2
        if spread > 0:
            rgb_image = np.clip((rgb_image - p2) / spread, 0, 1)

        # Convert to uint8 for display
        return (rgb_image * 255).astype(np.uint8)
    except Exception as e:
        # Best-effort preview: report the problem and show nothing rather
        # than surfacing an error in the UI.
        print(f"Error generating RGB preview: {e}")
        return None
def visualize_jp2(file_path):
    """
    Render the first band of a raster file as a viridis-coloured 8-bit image.

    Args:
        file_path: Path of the raster file to open with rasterio.

    Returns:
        A uint8 array holding the first three (RGB) channels of the
        colormapped band.
    """
    with rasterio.open(file_path) as src:
        # Work in floating point so the min-max arithmetic below cannot wrap
        # around in the raster's native integer dtype.
        data = src.read(1).astype(np.float64)

    # Min-max normalise for display. A constant band has zero range; map it
    # to 0 instead of producing NaNs from 0/0.
    low = np.min(data)
    span = np.max(data) - low
    if span > 0:
        data = (data - low) / span
    else:
        data = np.zeros_like(data)

    # Apply a colormap and drop the alpha channel
    cmap = plt.get_cmap('viridis')
    colored_image = cmap(data)
    return (colored_image[:, :, :3] * 255).astype(np.uint8)
def load_band(file_path, resample=False, target_height=None, target_width=None):
    """
    Load the first band of a raster file as a float32 array.

    Args:
        file_path: Path of the raster file to open with rasterio.
        resample: When True (and both target dimensions are given), read the
            band at the requested size using bilinear resampling.
        target_height: Output height in pixels for the resampled read.
        target_width: Output width in pixels for the resampled read.

    Returns:
        The first band as a 2-D float32 array. If resampling is requested
        without both target dimensions, the band is read at native size.
    """
    with rasterio.open(file_path) as src:
        if resample and target_height is not None and target_width is not None:
            # Read only band 1 instead of decoding every band of the dataset
            # and discarding all but the first.
            band_data = src.read(
                1,
                out_shape=(1, target_height, target_width),
                resampling=Resampling.bilinear,
            )
        else:
            band_data = src.read(1)
    return band_data.astype(np.float32)
def prepare_input_array(red_file, green_file, blue_file, nir_file):
    """
    Prepare a stacked array of satellite bands for cloud mask prediction.

    The red band defines the pixel grid; the other bands are read onto that
    grid so that all arrays can be stacked even if a file has a different
    size.

    Args:
        red_file: Path of the red band raster (reference grid).
        green_file: Path of the green band raster.
        blue_file: Path of the blue band raster.
        nir_file: Path of the NIR band raster.

    Returns:
        A tuple ``(prediction_array, rgb_image)``: the red, green and NIR
        bands stacked along axis 0, and a contrast-enhanced RGB image with
        values clipped to the 0-1 range.
    """
    # Get dimensions from red band to use for resampling
    with rasterio.open(red_file) as src:
        target_height = src.height
        target_width = src.width

    def load_on_red_grid(path):
        # Reading at the red band's size; for a band already on that grid
        # the requested shape equals its native shape.
        return load_band(
            path,
            resample=True,
            target_height=target_height,
            target_width=target_width,
        )

    blue_data = load_on_red_grid(blue_file)
    green_data = load_on_red_grid(green_file)
    red_data = load_band(red_file)
    nir_data = load_on_red_grid(nir_file)

    # Print band shapes for debugging
    print(f"Band shapes - Blue: {blue_data.shape}, Green: {green_data.shape}, Red: {red_data.shape}, NIR: {nir_data.shape}")

    # Compute max values for each channel for dynamic normalization
    red_max = np.max(red_data)
    green_max = np.max(green_data)
    blue_max = np.max(blue_data)
    print(f"Max values - Red: {red_max}, Green: {green_max}, Blue: {blue_max}")

    # Normalize each channel individually; epsilon avoids division by zero
    # for an all-zero band.
    epsilon = 1e-10
    rgb_image = np.zeros((red_data.shape[0], red_data.shape[1], 3), dtype=np.float32)
    rgb_image[:, :, 0] = red_data / (red_max + epsilon)
    rgb_image[:, :, 1] = green_data / (green_max + epsilon)
    rgb_image[:, :, 2] = blue_data / (blue_max + epsilon)
    rgb_image = np.clip(rgb_image, 0, 1)

    # 2-98 percentile contrast stretch. A flat image has p2 == p98, and
    # stretching it would divide by zero, so it is returned unstretched.
    p2 = np.percentile(rgb_image, 2)
    p98 = np.percentile(rgb_image, 98)
    spread = p98 - p2
    if spread > 0:
        rgb_image_enhanced = np.clip((rgb_image - p2) / spread, 0, 1)
    else:
        rgb_image_enhanced = rgb_image

    # Stack bands in CHW format for cloud mask prediction (red, green, nir)
    prediction_array = np.stack([red_data, green_data, nir_data], axis=0)
    return prediction_array, rgb_image_enhanced
def visualize_cloud_mask(rgb_image, pred_mask):
    """
    Blend a colour-coded class mask over the RGB image and append a legend.

    Args:
        rgb_image: Image array that is scaled by 255 and cast to uint8
            before blending.
        pred_mask: Per-pixel class mask; extra singleton dimensions are
            squeezed away and the mask is resized to the image if needed.

    Returns:
        A uint8 image: the 50/50 blend with a 100-pixel-high legend strip
        stacked underneath.
    """
    # Drop any extra dimensions from the mask
    if pred_mask.ndim > 2:
        pred_mask = np.squeeze(pred_mask)
    print(f"RGB image shape: {rgb_image.shape}, Pred mask shape: {pred_mask.shape}")

    # Bring the mask onto the image's pixel grid (nearest neighbour keeps
    # class labels intact)
    if pred_mask.shape != rgb_image.shape[:2]:
        resized = cv2.resize(
            pred_mask.astype(np.float32),
            (rgb_image.shape[1], rgb_image.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
        pred_mask = resized.astype(np.uint8)
        print(f"Resized mask shape: {pred_mask.shape}")

    # Class index -> (legend label, display colour), in class order
    palette = [
        ("Clear", [0, 255, 0]),               # green
        ("Thick Cloud", [255, 255, 255]),     # white
        ("Thin Cloud", [200, 200, 200]),      # light gray
        ("Cloud Shadow", [100, 100, 100]),    # dark gray
    ]

    # Paint each class in its colour
    mask_vis = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
    for class_idx, (_, color) in enumerate(palette):
        mask_vis[pred_mask == class_idx] = color

    # Equal-weight blend of image and mask
    alpha = 0.5
    base = (rgb_image * 255).astype(np.uint8)
    blended = cv2.addWeighted(base, 1-alpha, mask_vis, alpha, 0)

    # White legend strip as wide as the blended image
    legend = np.full((100, blended.shape[1], 3), 255, dtype=np.uint8)
    for row, (label, color) in enumerate(palette):
        top = 10 + row*20
        cv2.rectangle(legend, (10, top), (30, top + 20), color, -1)
        cv2.putText(legend, label, (40, top + 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

    # Image on top, legend below
    return np.vstack([blended, legend])
def process_satellite_images(red_file, green_file, blue_file, nir_file, batch_size, patch_size, patch_overlap):
    """
    Process the satellite images and detect clouds.

    Args:
        red_file, green_file, blue_file, nir_file: Paths of the four band
            rasters; all must be provided.
        batch_size: Number of patches per prediction batch.
        patch_size: Patch edge length passed to the predictor.
        patch_overlap: Overlap between neighbouring patches; must be smaller
            than ``patch_size``.

    Returns:
        A tuple ``(rgb_display, visualization, stats)``: the uint8 RGB image,
        the mask overlay with legend, and a text summary of the class
        distribution. On invalid input the two images are None and the
        third element carries the message to show to the user.
    """
    if not all([red_file, green_file, blue_file, nir_file]):
        return None, None, "Please upload all four channel files (Red, Green, Blue, NIR)"

    # UI sliders may deliver floats; these settings are whole-number counts
    # and pixel sizes.
    batch_size = int(batch_size)
    patch_size = int(patch_size)
    patch_overlap = int(patch_overlap)

    # An overlap that is not smaller than the patch leaves no stride to tile
    # the image with; report it instead of failing inside the predictor.
    if patch_overlap >= patch_size:
        return None, None, (
            f"Patch overlap ({patch_overlap}) must be smaller than "
            f"patch size ({patch_size})"
        )

    # Prepare input array and RGB image for visualization
    input_array, rgb_image = prepare_input_array(red_file, green_file, blue_file, nir_file)

    # Convert RGB image to format suitable for display
    rgb_display = (rgb_image * 255).astype(np.uint8)

    # Predict cloud mask using omnicloudmask
    pred_mask = predict_from_array(
        input_array,
        batch_size=batch_size,
        patch_size=patch_size,
        patch_overlap=patch_overlap,
    )

    # Drop extra dimensions before counting classes
    flat_mask = np.squeeze(pred_mask) if pred_mask.ndim > 2 else pred_mask

    # Calculate class distribution
    clear_pixels = np.count_nonzero(flat_mask == 0)
    thick_cloud_pixels = np.count_nonzero(flat_mask == 1)
    thin_cloud_pixels = np.count_nonzero(flat_mask == 2)
    cloud_shadow_pixels = np.count_nonzero(flat_mask == 3)
    total_pixels = flat_mask.size

    stats = "\n".join([
        "",
        "Cloud Mask Statistics:",
        f"- Clear: {clear_pixels} pixels ({clear_pixels/total_pixels*100:.2f}%)",
        f"- Thick Cloud: {thick_cloud_pixels} pixels ({thick_cloud_pixels/total_pixels*100:.2f}%)",
        f"- Thin Cloud: {thin_cloud_pixels} pixels ({thin_cloud_pixels/total_pixels*100:.2f}%)",
        f"- Cloud Shadow: {cloud_shadow_pixels} pixels ({cloud_shadow_pixels/total_pixels*100:.2f}%)",
        f"- Total Cloud Cover: {(thick_cloud_pixels + thin_cloud_pixels)/total_pixels*100:.2f}%",
        "",
    ])

    # Visualize the cloud mask on the original image
    visualization = visualize_cloud_mask(rgb_image, flat_mask)
    return rgb_display, visualization, stats
def update_cpu():
    """Return the current system-wide CPU utilisation as a display string."""
    # Sample over a short blocking interval: without an interval the first
    # cpu_percent() call has nothing to compare against and reports 0.0.
    usage = psutil.cpu_percent(interval=0.1)
    return f"CPU Usage: {usage}%"
# Minimal CPU-monitor panel: a read-out textbox plus a button that refreshes it.
# NOTE(review): `demo` is rebound by the main interface built later in this
# file, so this panel is not the one that gets launched.
with gr.Blocks() as demo:
    cpu_text = gr.Textbox(label="CPU Usage")
    check_cpu_btn = gr.Button("Check CPU")
    # Clicking the button writes update_cpu()'s result into the textbox
    check_cpu_btn.click(update_cpu, None, cpu_text)
# CPU check used by the "Check CPU Usage" button
def check_cpu_usage():
    """Check and return the current CPU usage as a display string."""
    # Sample over a short blocking interval: without an interval the first
    # cpu_percent() call has nothing to compare against and reports 0.0.
    usage = psutil.cpu_percent(interval=0.1)
    return f"CPU Usage: {usage}%"
# Assemble the main Gradio interface
with gr.Blocks(title="Satellite Cloud Detection") as demo:
    # Introductory description shown at the top of the page
    gr.Markdown("""
# Satellite Cloud Detection
Upload separate JP2 files for Red, Green, Blue, and NIR channels to detect clouds in satellite imagery.
This application uses the OmniCloudMask model to classify each pixel as:
- Clear (0)
- Thick Cloud (1)
- Thin Cloud (2)
- Cloud Shadow (3)
The model works best with imagery at 10-50m resolution. For higher resolution imagery, downsampling is recommended.
""")

    # Cloud detection: inputs on the left, results on the right
    with gr.Row():
        with gr.Column():
            # One upload per band
            red_input = gr.Image(type="filepath", label="Red Channel (JP2)")
            green_input = gr.Image(type="filepath", label="Green Channel (JP2)")
            blue_input = gr.Image(type="filepath", label="Blue Channel (JP2)")
            nir_input = gr.Image(type="filepath", label="NIR Channel (JP2)")
            # Prediction settings
            batch_size = gr.Slider(
                minimum=1,
                maximum=32,
                value=1,
                step=1,
                label="Batch Size",
                info="Higher values use more memory but process faster",
            )
            patch_size = gr.Slider(
                minimum=500,
                maximum=2000,
                value=1000,
                step=100,
                label="Patch Size",
                info="Size of image patches for processing",
            )
            patch_overlap = gr.Slider(
                minimum=100,
                maximum=500,
                value=300,
                step=50,
                label="Patch Overlap",
                info="Overlap between patches to avoid edge artifacts",
            )
            process_btn = gr.Button("Process Cloud Detection")
        with gr.Column():
            # Result widgets
            rgb_output = gr.Image(label="Original RGB Image")
            cloud_output = gr.Image(label="Cloud Detection Visualization")
            stats_output = gr.Textbox(label="Statistics")

    # System monitoring section
    with gr.Row():
        with gr.Column():
            gr.Markdown("## System Monitoring")
            cpu_button = gr.Button("Check CPU Usage")
            cpu_output = gr.Textbox(label="CPU Usage")

    # The same ordered inputs feed both the processing handler and the examples
    detection_inputs = [
        red_input,
        green_input,
        blue_input,
        nir_input,
        batch_size,
        patch_size,
        patch_overlap,
    ]

    # Event wiring
    process_btn.click(
        fn=process_satellite_images,
        inputs=detection_inputs,
        outputs=[rgb_output, cloud_output, stats_output],
    )
    cpu_button.click(fn=check_cpu_usage, inputs=None, outputs=cpu_output)

    # Example band files with the default settings
    gr.Examples(
        examples=[["jp2s/B04.jp2", "jp2s/B03.jp2", "jp2s/B02.jp2", "jp2s/B8A.jp2", 1, 1000, 300]],
        inputs=detection_inputs,
    )

# Enable queuing, then start the server
demo.queue(default_concurrency_limit=8)
demo.launch(debug=True)