# cloud-detection / app.py
# Last commit: 02ded84 "Update app.py" by truthdotphd (verified).
# Gradio App Code (based on paste.txt) with Triton Integration and Fallback
import psutil
import gradio as gr
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from huggingface_hub import snapshot_download
import rasterio
from rasterio.enums import Resampling
from rasterio.plot import reshape_as_image
import sys
import time # For potential timeouts/delays
# --- Triton Client Imports ---
# The Triton client is optional: without it the app still runs and relies
# solely on the local omnicloudmask model.
try:
    import tritonclient.http as httpclient
    import tritonclient.utils as triton_utils  # provides InferenceServerException
except ImportError:
    print("WARNING: tritonclient is not installed. Triton inference will not be available.")
    print("Install using: pip install tritonclient[all]")
    # Placeholders so these module-level names always exist (avoids NameErrors later).
    httpclient = None
    triton_utils = None
    TRITON_CLIENT_AVAILABLE = False
else:
    TRITON_CLIENT_AVAILABLE = True
# --- Configuration ---
# Download the entire repository for local fallback and utils.
# The downloaded tree is used below for the local omnicloudmask import and
# for the example JP2 files under "jp2s".
repo_id = "truthdotphd/cloud-detection"
# Relative local_dir: files are materialised into the current working directory.
repo_subdir = "."
print(f"Downloading/Checking Hugging Face repo '{repo_id}'...")
# NOTE(review): newer huggingface_hub releases may deprecate or ignore
# `local_dir_use_symlinks`; confirm against the installed version.
repo_dir = snapshot_download(repo_id=repo_id, local_dir=repo_subdir, local_dir_use_symlinks=False) # Use False for symlinks in Gradio/Docker usually
print(f"Repo downloaded/cached at: {repo_dir}")
# Add the repository directory to the Python path for local modules
sys.path.append(repo_dir)
# Import the local omnicloudmask model used as the LOCAL fallback.
# If the repo contains an "omnicloudmask" folder, that folder is also put on
# sys.path before importing.
try:
    omnicloudmask_path = os.path.join(repo_dir, "omnicloudmask")
    if os.path.isdir(omnicloudmask_path):
        sys.path.append(omnicloudmask_path)
    from omnicloudmask import predict_from_array
except ImportError as e:
    print(f"ERROR: Could not import local 'predict_from_array' from omnicloudmask: {e}")
    print("Local fallback will not be available.")
    predict_from_array = None  # placeholder so the name always exists
    LOCAL_MODEL_AVAILABLE = False
else:
    LOCAL_MODEL_AVAILABLE = True
    print("Local omnicloudmask module loaded successfully.")
# --- Triton Server Configuration ---
# The server address and model name can be overridden with environment
# variables, so the app can target another Triton deployment without code
# edits. The defaults are the previously hard-coded values.
TRITON_IP = os.environ.get("TRITON_IP", "206.123.129.87")  # public IP of the default server
HTTP_TRITON_URL = os.environ.get("TRITON_HTTP_URL", f"{TRITON_IP}:8000")  # host:port
# GRPC_TRITON_URL = f"{TRITON_IP}:8001"  # Keep for potential future use
TRITON_MODEL_NAME = os.environ.get("TRITON_MODEL_NAME", "cloud-detection")  # must match the deployed model name
TRITON_INPUT_NAME = "input_jp2_bytes"  # must match the model's config.pbtxt
TRITON_OUTPUT_NAME = "output_mask"  # must match the model's config.pbtxt
TRITON_TIMEOUT_SECONDS = 300  # 5 minutes timeout for connection/network
# --- Utility Functions (mostly from paste.txt) ---
def visualize_rgb(red_file, green_file, blue_file):
    """
    Create an approximate true-colour RGB preview from separate band files.

    The red band defines the output grid. Green and blue bands are read at
    their native size and are re-read with bilinear resampling onto the red
    grid only when that size differs. Same-resolution bands are therefore read
    exactly as before, while mismatched bands no longer break the
    channel assignment. Each channel is scaled as follows:
    - divided by the 98th percentile of its positive values (1.0 if it has none);
    - clipped to [0, 1];
    - brightened with gamma 1.8.
    (Does not need the NIR band.)

    Args:
        red_file, green_file, blue_file: Paths to single-band rasters readable
            by rasterio. Any falsy value short-circuits to None.

    Returns:
        np.ndarray | None: (H, W, 3) uint8 image, or None when an input is
        missing or loading/processing fails (the error is printed).
    """
    if not all([red_file, green_file, blue_file]):
        return None
    try:
        # The red band defines the target grid.
        with rasterio.open(red_file) as src:
            target_height = src.height
            target_width = src.width

        def _on_red_grid(path):
            # Read natively; resample only if the size differs from the red grid.
            band = load_band(path)
            if band.shape != (target_height, target_width):
                band = load_band(path, resample=True,
                                 target_height=target_height, target_width=target_width)
            return band

        blue_data = _on_red_grid(blue_file)
        green_data = _on_red_grid(green_file)
        red_data = _on_red_grid(red_file)

        # Per-channel dynamic normalisation (98th percentile of positive pixels).
        epsilon = 1e-10
        channels = []
        for band in (red_data, green_data, blue_data):
            positive = band[band > 0]
            band_max = np.percentile(positive, 98) if positive.size else 1.0
            channels.append(np.clip(band / (band_max + epsilon), 0, 1))
        rgb_image = np.stack(channels, axis=-1).astype(np.float32)

        # Simple brightness/contrast adjustment (gamma correction).
        gamma = 1.8
        rgb_image_enhanced = np.power(rgb_image, 1 / gamma)

        # Convert to uint8 for display.
        return (rgb_image_enhanced * 255).astype(np.uint8)
    except Exception as e:
        print(f"Error generating RGB preview: {e}")
        import traceback
        traceback.print_exc()
        return None
def visualize_jp2(file_path):
    """
    Render the first band of a JP2 file as a viridis false-colour image.

    The band is min-max stretched to [0, 1] before the colormap is applied.
    A band that is all zeros or constant cannot be stretched. In that case a
    warning is printed and an all-black image of the raster's size is returned.

    Args:
        file_path: Path to a raster readable by rasterio.

    Returns:
        np.ndarray | None: (H, W, 3) uint8 image, or None if reading or
        rendering fails (the error is printed).
    """
    try:
        with rasterio.open(file_path) as src:
            band = src.read(1)
            is_flat = np.all(band == 0) or np.ptp(band) == 0
            if is_flat:
                print(f"Warning: Data in {file_path} is constant or zero. Cannot normalize.")
                return np.zeros((src.height, src.width, 3), dtype=np.uint8)
        lo, hi = np.min(band), np.max(band)
        stretched = (band - lo) / (hi - lo)
        rgba = plt.get_cmap('viridis')(stretched)
        # Drop alpha and scale to 8-bit for display.
        return (rgba[:, :, :3] * 255).astype(np.uint8)
    except Exception as e:
        print(f"Error visualizing JP2 file {file_path}: {e}")
        return None
def load_band(file_path, resample=False, target_height=None, target_width=None):
    """
    Read the first band of a raster as float32, optionally resampled.

    Bilinear resampling to (target_height, target_width) happens only when
    ``resample`` is truthy AND both target dimensions are given. Otherwise
    the band is read at its native size.

    Returns:
        np.ndarray: 2-D float32 array.

    Raises:
        Any exception raised while opening or reading the file is printed and
        re-raised for the calling function to handle.
    """
    try:
        wants_resample = resample and target_height is not None and target_width is not None
        with rasterio.open(file_path) as src:
            if not wants_resample:
                return src.read(1).astype(np.float32)
            # A single output channel; keep only the first band after resampling.
            resampled = src.read(
                out_shape=(1, target_height, target_width),
                resampling=Resampling.bilinear,
            )
            return resampled[0].astype(np.float32)
    except Exception as e:
        print(f"Error loading band {file_path}: {e}")
        raise
def prepare_input_array(red_file, green_file, blue_file, nir_file):
    """
    Prepare the local-model input stack and an RGB preview from band files.

    The red band defines the target grid. Bands are loaded as follows:
    - NIR is always read bilinearly resampled onto that grid.
    - Green and blue are resampled only when their native size differs from red.
    All four bands are checked for matching shapes before any further
    processing, and a mismatch is reported with every band's shape.

    Args:
        red_file, green_file, blue_file, nir_file: Paths to single-band rasters.

    Returns:
        tuple:
            prediction_array (np.ndarray | None): (3, H, W) float32 stack in
                R, G, NIR order (CHW) for the local model, or None on error.
            rgb_image_enhanced (np.ndarray | None): (H, W, 3) float image in
                [0, 1] (gamma 1.8) for visualization, or None on error.
    """
    try:
        # Get dimensions from the red band; it defines the target grid.
        with rasterio.open(red_file) as src:
            target_height = src.height
            target_width = src.width
        target_shape = (target_height, target_width)

        def _on_red_grid(path):
            # Read natively; resample only if the size differs from the red grid.
            band = load_band(path)
            if band.shape != target_shape:
                band = load_band(path, resample=True,
                                 target_height=target_height, target_width=target_width)
            return band

        blue_data = _on_red_grid(blue_file)  # only needed for the RGB preview
        green_data = _on_red_grid(green_file)
        red_data = load_band(red_file)
        # Resample the NIR band to match the 10m resolution of the red grid.
        nir_data = load_band(
            nir_file,
            resample=True,
            target_height=target_height,
            target_width=target_width,
        )

        # Validate every band up front (including blue, used by the preview).
        shapes = {
            "Red": red_data.shape,
            "Green": green_data.shape,
            "Blue": blue_data.shape,
            "NIR": nir_data.shape,
        }
        if len(set(shapes.values())) != 1:
            print("ERROR: Band shapes mismatch after loading/resampling!")
            print("Shapes - " + ", ".join(f"{name}: {shape}" for name, shape in shapes.items()))
            return None, None  # Indicate error

        # --- RGB preview: per-channel 98th-percentile scaling, clipped to [0, 1] ---
        epsilon = 1e-10
        rgb_image = np.zeros((target_height, target_width, 3), dtype=np.float32)
        for channel, band in enumerate((red_data, green_data, blue_data)):
            positive = band[band > 0]
            band_max = np.percentile(positive, 98) if positive.size else 1.0
            rgb_image[:, :, channel] = np.clip(band / (band_max + epsilon), 0, 1)
        # Apply gamma correction for enhancement.
        gamma = 1.8
        rgb_image_enhanced = np.power(rgb_image, 1 / gamma)

        # Stack bands in CHW format for the LOCAL model (red, green, nir).
        prediction_array = np.stack([red_data, green_data, nir_data], axis=0)
        print(f"Local prediction array shape: {prediction_array.shape}")
        print(f"RGB visualization image shape: {rgb_image_enhanced.shape}")
        return prediction_array, rgb_image_enhanced
    except Exception as e:
        print(f"Error during input preparation: {e}")
        import traceback
        traceback.print_exc()
        return None, None  # Indicate error
# Class index -> (legend label, overlay colour in RGB order) for the cloud mask overlay.
_CLOUD_MASK_CLASSES = {
    0: ("Clear", [0, 255, 0]),  # green
    1: ("Thick Cloud", [255, 0, 0]),  # red
    2: ("Thin Cloud", [255, 255, 0]),  # yellow
    3: ("Cloud Shadow", [0, 0, 255]),  # blue
}


def _mask_to_2d(pred_mask):
    """Return pred_mask as an (H, W) array, or None (after printing why) if it cannot be made 2-D."""
    if pred_mask.ndim == 3 and pred_mask.shape[0] == 1:  # drop a leading channel axis
        return np.squeeze(pred_mask, axis=0)
    if pred_mask.ndim == 2:
        return pred_mask
    print(f"ERROR: Unexpected prediction mask dimension: {pred_mask.ndim}, shape: {pred_mask.shape}")
    squeezed = np.squeeze(pred_mask)
    if squeezed.ndim != 2:
        print("Could not convert mask to 2D: Still not 2D after squeeze")
        return None
    return squeezed


def _resize_mask_nearest(pred_mask, height, width):
    """Resize a label mask to (height, width) with nearest-neighbour sampling so labels stay intact."""
    if pred_mask.dtype != np.uint8:
        if not np.issubdtype(pred_mask.dtype, np.integer):
            print("Warning: Prediction mask is not integer type, casting to uint8 for resize.")
        # OpenCV's Python bindings reject some dtypes (e.g. bool, uint32), so
        # labels are normalised to uint8 first. Values outside 0..255
        # (and NaN) become 255. That value matches no class and stays
        # uncoloured, instead of wrapping around onto a real class index.
        in_range = (pred_mask >= 0) & (pred_mask <= 255)
        pred_mask = np.where(in_range, pred_mask, 255).astype(np.uint8)
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(pred_mask, (width, height), interpolation=cv2.INTER_NEAREST)


def _render_cloud_legend(width):
    """Draw a 100-pixel-tall white legend strip of the given width with one colour box per class."""
    legend_height = 100
    box_size = 15
    text_offset_x = 40
    start_y = 15
    padding_y = 20
    legend = np.full((legend_height, width, 3), 255, dtype=np.uint8)
    for row, (label, color) in enumerate(_CLOUD_MASK_CLASSES.values()):
        center_y = start_y + row * padding_y
        cv2.rectangle(legend,
                      (10, center_y - box_size // 2),
                      (10 + box_size, center_y + box_size // 2),
                      color, -1)
        # The box_size // 4 offset roughly centres the text on the colour box.
        cv2.putText(legend, label,
                    (text_offset_x, center_y + box_size // 4),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
    return legend


def visualize_cloud_mask(rgb_image, pred_mask):
    """
    Overlay a colour-coded cloud mask on the RGB image and append a legend.

    Args:
        rgb_image: (H, W, 3) float image with values in [0, 1] (values outside
            that range are clipped).
        pred_mask: Class-label mask, (H, W) or (1, H, W); labels 0-3 are
            coloured (Clear, Thick Cloud, Thin Cloud, Cloud Shadow). A mask
            whose spatial size differs from the image is resized with
            nearest-neighbour sampling.

    Returns:
        np.ndarray | None: uint8 image of shape (H + 100, W, 3). This is the
        40%-opacity overlay stacked above the legend strip. Returns None if an
        input is missing, the mask cannot be made 2-D, or rendering fails
        (errors are printed).
    """
    if rgb_image is None or pred_mask is None:
        print("Cannot visualize cloud mask: Missing RGB image or prediction mask.")
        return None
    try:
        pred_mask = _mask_to_2d(pred_mask)
        if pred_mask is None:
            return None  # Cannot visualize
        print(f"Visualization - RGB image shape: {rgb_image.shape}, Pred mask shape: {pred_mask.shape}")

        height, width = rgb_image.shape[:2]
        if pred_mask.shape != (height, width):
            print(f"Warning: Resizing prediction mask from {pred_mask.shape} to {rgb_image.shape[:2]} for visualization.")
            pred_mask = _resize_mask_nearest(pred_mask, height, width)
            print(f"Resized mask shape: {pred_mask.shape}")

        # Colour each known class; unknown labels stay black.
        mask_vis = np.zeros((height, width, 3), dtype=np.uint8)
        for class_idx, (_, color) in _CLOUD_MASK_CLASSES.items():
            mask_vis[pred_mask == class_idx] = color

        alpha = 0.4  # Transparency of the mask overlay
        rgb_uint8 = (np.clip(rgb_image, 0, 1) * 255).astype(np.uint8)
        blended = cv2.addWeighted(rgb_uint8, 1 - alpha, mask_vis, alpha, 0)

        return np.vstack([blended, _render_cloud_legend(blended.shape[1])])
    except Exception as e:
        print(f"Error during visualization: {e}")
        import traceback
        traceback.print_exc()
        return None  # Return None if visualization fails
# --- Triton Client Functions (Adapted from paste-2.txt) ---
def is_triton_server_healthy(url=HTTP_TRITON_URL):
    """
    Check whether the Triton Inference Server at ``url`` reports itself live.

    The check uses a short connection timeout (10.0). Only liveness is
    checked, not model readiness. The temporary client is always closed
    afterwards, so repeated health checks do not accumulate open clients.

    Returns:
        bool: True if the server is live. False if it is not live, the
        client library is missing, or the server cannot be reached.
    """
    if not TRITON_CLIENT_AVAILABLE:
        return False
    triton_client = None
    try:
        triton_client = httpclient.InferenceServerClient(url=url, connection_timeout=10.0)  # Short timeout for health check
        server_live = triton_client.is_server_live()
        if server_live:
            print(f"Triton server at {url} is live.")
            # Optionally also check readiness via triton_client.is_server_ready().
        else:
            print(f"Triton server at {url} is not live.")
        return server_live
    except Exception as e:
        print(f"Could not connect to Triton server at {url}: {e}")
        return False
    finally:
        if triton_client is not None:
            try:
                triton_client.close()
            except Exception:
                pass  # best-effort cleanup; never let it override the health result
def get_jp2_bytes_for_triton(red_file_path, green_file_path, nir_file_path):
    """
    Load the raw bytes of the Red, Green and NIR JP2 files for Triton.

    The order is Red, Green, NIR and must match the Triton model's input
    expectation.

    Returns:
        np.ndarray: 1-D object array of shape (3,), one ``bytes`` per band.

    Raises:
        FileNotFoundError or any other read error, after printing it.
    """
    payloads = []
    bands = (("Red", red_file_path), ("Green", green_file_path), ("NIR", nir_file_path))
    for band_name, path in bands:
        try:
            with open(path, "rb") as handle:
                payload = handle.read()
            payloads.append(payload)
            print(f"Read {len(payload)} bytes for {band_name} band from {os.path.basename(path)}")
        except FileNotFoundError:
            print(f"ERROR: File not found: {path}")
            raise
        except Exception as e:
            print(f"ERROR: Could not read file {path}: {e}")
            raise
    # dtype=object keeps each bytes value intact as one element.
    input_byte_array = np.array(payloads, dtype=object)
    print(f"Prepared Triton input byte array with shape: {input_byte_array.shape} and dtype: {input_byte_array.dtype}")
    return input_byte_array
def run_inference_triton_http(input_byte_array):
    """
    Run inference on the Triton server over HTTP, sending raw JP2 bytes.

    Args:
        input_byte_array (np.ndarray): Object array of ``bytes`` (built by
            get_jp2_bytes_for_triton in Red, Green, NIR order). It is sent as a
            BYTES tensor named TRITON_INPUT_NAME with the array's own shape.

    Returns:
        np.ndarray: The TRITON_OUTPUT_NAME tensor returned by the server.

    Raises:
        RuntimeError: if the tritonclient library is not installed.
        Any client-creation, request-building, transport or server error is
        printed and re-raised so the caller can fall back to the local model.
    """
    import uuid  # local import: only used to tag requests with a unique id

    if not TRITON_CLIENT_AVAILABLE:
        raise RuntimeError("Triton client library not available.")
    print("Attempting inference using Triton HTTP client...")
    try:
        client = httpclient.InferenceServerClient(
            url=HTTP_TRITON_URL,
            verbose=False,
            connection_timeout=TRITON_TIMEOUT_SECONDS,
            network_timeout=TRITON_TIMEOUT_SECONDS,
        )
    except Exception as e:
        print(f"ERROR: Couldn't create Triton HTTP client: {e}")
        raise  # Propagate error

    try:
        # BYTES input; binary_data=True is important for BYTES tensors.
        inputs = [httpclient.InferInput(TRITON_INPUT_NAME, input_byte_array.shape, "BYTES")]
        inputs[0].set_data_from_numpy(input_byte_array, binary_data=True)
        outputs = [httpclient.InferRequestedOutput(TRITON_OUTPUT_NAME, binary_data=True)]

        print(f"Sending inference request to Triton model '{TRITON_MODEL_NAME}' at {HTTP_TRITON_URL}...")
        response = client.infer(
            model_name=TRITON_MODEL_NAME,
            inputs=inputs,
            outputs=outputs,
            # Unique per request; the process id was identical for concurrent requests.
            request_id=uuid.uuid4().hex,
            # `infer(timeout=...)` is a server-side request timeout in MICROSECONDS
            # (honoured by models using dynamic batching). The client-side limit is
            # network_timeout above, so the seconds value is converted here.
            timeout=TRITON_TIMEOUT_SECONDS * 1_000_000,
        )
        print("Triton inference request successful.")
        mask = response.as_numpy(TRITON_OUTPUT_NAME)
        print(f"Received output mask from Triton with shape: {mask.shape}, dtype: {mask.dtype}")
        return mask
    except triton_utils.InferenceServerException as e:
        print(f"ERROR: Triton server failed inference: Status code {e.status()}, message: {e.message()}")
        print(f"Debug details: {e.debug_details()}")
        raise  # Propagate error to trigger fallback
    except Exception as e:
        print(f"ERROR: An unexpected error occurred during Triton HTTP inference: {e}")
        import traceback
        traceback.print_exc()
        raise  # Propagate error to trigger fallback
    finally:
        try:
            client.close()
        except Exception:
            pass  # best-effort cleanup; must not mask the inference outcome
# --- Main Processing Function with Fallback Logic ---
def _as_path(file_value):
    """Return the filesystem path for a Gradio file value (a path string or an object with ``.name``)."""
    return file_value if isinstance(file_value, str) else file_value.name


def _mask_to_hw(pred_mask):
    """Strip leading singleton axes (e.g. (1, H, W) -> (H, W)); return None if the result is not 2-D."""
    mask = np.asarray(pred_mask)
    while mask.ndim > 2 and mask.shape[0] == 1:
        mask = mask[0]
    return mask if mask.ndim == 2 else None


def _format_mask_stats(flat_mask, source):
    """Format per-class pixel counts and percentages (0=Clear, 1=Thick, 2=Thin, 3=Shadow) as text."""
    total_pixels = flat_mask.size
    labels = ("Clear", "Thick Cloud", "Thin Cloud", "Cloud Shadow")
    counts = [int(np.sum(flat_mask == class_idx)) for class_idx in range(len(labels))]

    def pct(count):
        # Guard against an empty mask instead of dividing by zero.
        return count / total_pixels * 100 if total_pixels else 0.0

    lines = [f"Cloud Mask Statistics ({source}):"]
    lines += [f"- {label}: {count} pixels ({pct(count):.2f}%)" for label, count in zip(labels, counts)]
    lines.append(f"- Total Cloud Cover (Thick+Thin): {pct(counts[1] + counts[2]):.2f}%")
    # Leading/trailing newlines keep the layout of the original triple-quoted text.
    return "\n" + "\n".join(lines) + "\n"


def process_satellite_images(red_file, green_file, blue_file, nir_file, batch_size, patch_size, patch_overlap):
    """
    Run cloud detection on four uploaded bands: Triton first, local model as fallback.

    Steps:
      1. Build the RGB preview and the local-model input stack.
      2. If the Triton client is installed and the server is live, send the raw
         Red/Green/NIR JP2 bytes to Triton.
      3. If Triton was skipped or failed, run the local omnicloudmask
         ``predict_from_array``.
      4. Compute per-class statistics and render the mask overlay.

    Args:
        red_file, green_file, blue_file, nir_file: Gradio file values (a path
            string, or an object exposing ``.name``).
        batch_size, patch_size, patch_overlap: Local-model parameters from the
            UI sliders; only used by the local fallback.

    Returns:
        tuple: (rgb_display_image, visualization, status_text). The images are
        uint8 arrays or None. status_text says which backend ran and lists any
        failures, plus the statistics when a mask was produced.
    """
    if not all([red_file, green_file, blue_file, nir_file]):
        return None, None, "ERROR: Please upload all four channel files (Red, Green, Blue, NIR)"

    red_file_path, green_file_path, blue_file_path, nir_file_path = map(
        _as_path, (red_file, green_file, blue_file, nir_file)
    )
    print("\n--- Starting Cloud Detection Process ---")
    print(f"Input files: R={os.path.basename(red_file_path)}, G={os.path.basename(green_file_path)}, B={os.path.basename(blue_file_path)}, N={os.path.basename(nir_file_path)}")

    status_message = ""

    # 1. Visualization image (always needed) and local input array (for fallback).
    print("Preparing visualization image and local model input array...")
    local_input_array, rgb_float_image = prepare_input_array(red_file_path, green_file_path, blue_file_path, nir_file_path)
    if rgb_float_image is None:
        print("ERROR: Failed to create RGB visualization image.")
        # A failed preview almost always means the band files could not be loaded.
        return None, None, "ERROR: Failed to load or process input band files."
    rgb_display_image = (np.clip(rgb_float_image, 0, 1) * 255).astype(np.uint8)

    # 2. Check Triton server health.
    use_triton = False
    if TRITON_CLIENT_AVAILABLE:
        print(f"Checking Triton server health at {HTTP_TRITON_URL}...")
        if is_triton_server_healthy(HTTP_TRITON_URL):
            use_triton = True
        else:
            print("Triton server is not healthy or unavailable.")
            status_message += "Triton server unavailable. "
    else:
        print("Triton client library not installed. Skipping Triton check.")
        status_message += "Triton client not installed. "

    # 3. Attempt Triton inference if the server is healthy.
    pred_mask = None
    if use_triton:
        try:
            print("Preparing JP2 bytes for Triton...")
            triton_byte_input = get_jp2_bytes_for_triton(red_file_path, green_file_path, nir_file_path)
            pred_mask = run_inference_triton_http(triton_byte_input)
            status_message += "Inference performed using Triton Server. "
            print("Triton inference successful.")
        except Exception as e:
            print(f"Triton inference failed: {e}. Falling back to local model.")
            status_message += f"Triton inference failed ({type(e).__name__}). "
            pred_mask = None
            use_triton = False  # the statistics label must report the local model

    # 4. Fall back to the local model when Triton produced no mask.
    if pred_mask is None:
        status_message += "Falling back to local inference. "
        if not LOCAL_MODEL_AVAILABLE:
            status_message += "Local model not available. Cannot perform inference."
            print("ERROR: Local model could not be loaded.")
        elif local_input_array is None:
            status_message += "Local input data preparation failed. Cannot perform local inference."
            print("ERROR: Failed to prepare input array for local model.")
        else:
            print("Running local inference using omnicloudmask...")
            try:
                # Gradio sliders may deliver floats (e.g. 4.0); the model
                # parameters are counts/sizes, so pass integers.
                pred_mask = predict_from_array(
                    local_input_array,
                    batch_size=int(batch_size),
                    patch_size=int(patch_size),
                    patch_overlap=int(patch_overlap),
                )
                print(f"Local prediction successful. Output mask shape: {pred_mask.shape}, dtype: {pred_mask.dtype}")
                status_message += "Local inference successful."
            except Exception as e:
                print(f"ERROR: Local inference failed: {e}")
                import traceback
                traceback.print_exc()
                status_message += f"Local inference FAILED: {e}"
                pred_mask = None

    # 5. Statistics and visualization.
    if pred_mask is None:
        print("--- Cloud Detection Process Failed ---")
        return rgb_display_image, None, status_message + "\nERROR: Could not generate cloud mask."

    flat_mask = _mask_to_hw(pred_mask)
    if flat_mask is None:
        print(f"ERROR: Unexpected mask shape after inference: {np.shape(pred_mask)}")
        status_message += " ERROR: Invalid mask shape received."
        return rgb_display_image, None, status_message + "\nERROR: Could not process prediction mask."

    stats = _format_mask_stats(flat_mask, 'Triton' if use_triton else 'Local')
    status_message += f"\nMask stats calculated. Total pixels: {flat_mask.size}."

    print("Generating final visualization...")
    visualization = visualize_cloud_mask(rgb_float_image, flat_mask)
    if visualization is None:
        status_message += " ERROR: Failed to generate visualization."
    print("--- Cloud Detection Process Finished ---")
    return rgb_display_image, visualization, status_message + "\n" + stats
# --- Gradio Interface (from paste.txt) ---
def check_cpu_usage():
    """Return a one-line status string with the CPU usage reported by psutil.cpu_percent()."""
    usage = psutil.cpu_percent()
    return "CPU Usage: {}%".format(usage)
# --- Build Gradio App ---
# Component creation order inside the `with` contexts determines the page layout.
# NOTE(review): the nesting below (which components sit in which
# Row/Column/Accordion) is reconstructed; confirm against the intended layout.
print("Building Gradio interface...")
with gr.Blocks(title="Satellite Cloud Detection (Triton/Local)") as demo:
    # Introductory help text shown above the inputs.
    gr.Markdown("""
# Satellite Cloud Detection (with Triton Fallback)
Upload separate JP2 files for Red (e.g., B04), Green (e.g., B03), Blue (e.g., B02), and NIR (e.g., B8A) channels.
The application will **first attempt** to use a remote Triton Inference Server. If the server is unavailable or inference fails,
it will **fall back** to using the local OmniCloudMask model.
**Pixel Classification:**
- Clear (Green)
- Thick Cloud (Red)
- Thin Cloud (Yellow)
- Cloud Shadow (Blue)
The model works best with imagery at 10-50m resolution.
""")
    # Main cloud detection interface: inputs on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input Bands (JP2)")
            # Use filepaths which are needed for both local reading and byte reading
            red_input = gr.File(label="Red Channel (e.g., B04)", type="filepath")
            green_input = gr.File(label="Green Channel (e.g., B03)", type="filepath")
            blue_input = gr.File(label="Blue Channel (e.g., B02)", type="filepath")
            nir_input = gr.File(label="NIR Channel (e.g., B8A)", type="filepath")
            gr.Markdown("### Local Model Parameters (Used for Fallback)")
            # These three values are passed to process_satellite_images and are
            # only used by the local omnicloudmask fallback.
            batch_size = gr.Slider(minimum=1, maximum=32, value=4, step=1,
                                   label="Batch Size",
                                   info="Memory usage/speed for local model")
            patch_size = gr.Slider(minimum=256, maximum=2048, value=1024, step=128,
                                   label="Patch Size",
                                   info="Patch size for local model processing")
            patch_overlap = gr.Slider(minimum=64, maximum=512, value=256, step=64,
                                      label="Patch Overlap",
                                      info="Overlap for local model processing")
            process_btn = gr.Button("Process Cloud Detection", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("### Results")
            # Output components (filled by process_satellite_images, in this order)
            rgb_output = gr.Image(label="Original RGB Image (Approx. True Color)", type="numpy")
            cloud_output = gr.Image(label="Cloud Detection Visualization (Mask Overlay)", type="numpy")
            stats_output = gr.Textbox(label="Processing Status & Statistics", lines=10)
    # CPU usage monitoring section (Optional)
    with gr.Accordion("System Monitoring", open=False):
        cpu_button = gr.Button("Check CPU Usage")
        cpu_output = gr.Textbox(label="Current CPU Usage")
        cpu_button.click(fn=check_cpu_usage, inputs=None, outputs=cpu_output)
    # Examples section
    # Ensure example paths are relative to where the script is run,
    # or absolute if needed. Assumes 'jp2s' folder is present.
    example_base = os.path.join(repo_dir, "jp2s") # Use downloaded repo path
    example_files = [
        os.path.join(example_base, "B04.jp2"), # Red
        os.path.join(example_base, "B03.jp2"), # Green
        os.path.join(example_base, "B02.jp2"), # Blue
        os.path.join(example_base, "B8A.jp2") # NIR
    ]
    # Check if example files actually exist before adding example
    if all(os.path.exists(f) for f in example_files):
        print("Adding examples...")
        gr.Examples(
            # The trailing [4, 1024, 256] fill batch_size, patch_size, patch_overlap.
            examples=[example_files + [4, 1024, 256]], # Corresponds to inputs below
            inputs=[red_input, green_input, blue_input, nir_input, batch_size, patch_size, patch_overlap],
            outputs=[rgb_output, cloud_output, stats_output], # Define outputs for examples too
            fn=process_satellite_images, # Function to run for examples
            cache_examples=False # Maybe disable caching if files change or for debugging
        )
    else:
        print(f"WARN: Example JP2 files not found in '{example_base}'. Skipping examples.")
    # Setup main button click handler
    process_btn.click(
        fn=process_satellite_images,
        inputs=[red_input, green_input, blue_input, nir_input, batch_size, patch_size, patch_overlap],
        outputs=[rgb_output, cloud_output, stats_output]
    )
# --- Launch the App ---
print("Launching Gradio app...")
# Allow queueing and potentially increase workers if needed
demo.queue(default_concurrency_limit=4).launch(debug=True, share=False) # share=True for public link if needed