Spaces:
Sleeping
Sleeping
# Gradio App Code (based on paste.txt) with Triton Integration and Fallback | |
import psutil | |
import gradio as gr | |
import numpy as np | |
import os | |
import cv2 | |
import matplotlib.pyplot as plt | |
from huggingface_hub import snapshot_download | |
import rasterio | |
from rasterio.enums import Resampling | |
from rasterio.plot import reshape_as_image | |
import sys | |
import time # For potential timeouts/delays | |
# --- Triton Client Imports ---
# tritonclient is optional: without it the module still loads, and code below
# checks TRITON_CLIENT_AVAILABLE before touching any Triton API.
try:
    import tritonclient.http as httpclient
    import tritonclient.utils as triton_utils  # provides InferenceServerException
except ImportError:
    print("WARNING: tritonclient is not installed. Triton inference will not be available.")
    print("Install using: pip install tritonclient[all]")
    # Placeholders so later references to these names never raise NameError.
    httpclient = triton_utils = None
    TRITON_CLIENT_AVAILABLE = False
else:
    TRITON_CLIENT_AVAILABLE = True
# --- Configuration ---
# Download the entire Hugging Face repository: it provides the omnicloudmask
# package used for the local fallback and the example JP2 files (jp2s/).
repo_id = "truthdotphd/cloud-detection"
repo_subdir = "."
print(f"Downloading/Checking Hugging Face repo '{repo_id}'...")
# NOTE(review): local_dir="." places the repo files in the current working
# directory, so same-named local files may be replaced - confirm this is intended.
# Recent huggingface_hub releases deprecate `local_dir_use_symlinks`; it is kept
# for compatibility with older pinned versions.
repo_dir = snapshot_download(repo_id=repo_id, local_dir=repo_subdir, local_dir_use_symlinks=False)
print(f"Repo downloaded/cached at: {repo_dir}")
# Make modules shipped in the repository importable.
sys.path.append(repo_dir)

# Import the entry point of the LOCAL fallback model.
# Any failure disables the fallback instead of crashing the whole app. Importing
# a package can fail with more than ImportError (e.g. OSError/RuntimeError from a
# broken native dependency), so the handler is deliberately broad; the error is
# still printed.
try:
    omnicloudmask_path = os.path.join(repo_dir, "omnicloudmask")
    if os.path.isdir(omnicloudmask_path):
        sys.path.append(omnicloudmask_path)  # also expose the subfolder itself if present
    from omnicloudmask import predict_from_array
    LOCAL_MODEL_AVAILABLE = True
    print("Local omnicloudmask module loaded successfully.")
except Exception as e:
    print(f"ERROR: Could not import local 'predict_from_array' from omnicloudmask: {e}")
    print("Local fallback will not be available.")
    LOCAL_MODEL_AVAILABLE = False
    predict_from_array = None  # placeholder so later references don't raise NameError
# --- Triton Server Configuration ---
# The server host can be overridden with the TRITON_IP environment variable
# (e.g. a Space secret) without editing code. An unset or empty value falls back
# to the original hard-coded public IP.
TRITON_IP = os.environ.get("TRITON_IP") or "206.123.129.87"
HTTP_TRITON_URL = f"{TRITON_IP}:8000"  # HTTP endpoint (gRPC on :8001 is not used by this app)
TRITON_MODEL_NAME = "cloud-detection"  # Ensure this matches your deployed model name
TRITON_INPUT_NAME = "input_jp2_bytes"  # Ensure this matches your model's config.pbtxt
TRITON_OUTPUT_NAME = "output_mask"  # Ensure this matches your model's config.pbtxt
TRITON_TIMEOUT_SECONDS = 300  # 5 minutes; timeout budget (in seconds) for Triton HTTP calls
# --- Utility Functions (mostly from paste.txt) ---
def visualize_rgb(red_file, green_file, blue_file):
    """
    Build a uint8 RGB preview from separate red, green and blue band files.

    The red band defines the output grid. Green and blue are read at native size
    and re-read with bilinear resampling onto the red grid only if their size
    differs. Each band is scaled by the 98th percentile of its positive values
    (1.0 if it has none), clipped to [0, 1] and gamma-corrected with gamma 1.8.
    Unlike prepare_input_array, no NIR band is needed.

    Returns:
        np.ndarray: uint8 array of shape (H, W, 3), or None if any file argument
        is empty or loading/processing fails (the error is printed).
    """
    if not all([red_file, green_file, blue_file]):
        return None
    try:
        red_data = load_band(red_file)
        # The red band's own shape is the target grid (no separate open needed).
        target_height, target_width = red_data.shape

        def _load_on_red_grid(path):
            """Load a band, resampling it onto the red grid only when its size differs."""
            band = load_band(path)
            if band.shape != (target_height, target_width):
                band = load_band(
                    path,
                    resample=True,
                    target_height=target_height,
                    target_width=target_width,
                )
            return band

        green_data = _load_on_red_grid(green_file)
        blue_data = _load_on_red_grid(blue_file)

        epsilon = 1e-10  # avoids division by zero for a zero percentile
        rgb_image = np.zeros((target_height, target_width, 3), dtype=np.float32)
        for channel, band in enumerate((red_data, green_data, blue_data)):
            positive = band[band > 0]
            band_max = np.percentile(positive, 98) if positive.size else 1.0
            rgb_image[:, :, channel] = np.clip(band / (band_max + epsilon), 0, 1)

        # Simple brightness adjustment (gamma correction), then convert for display.
        gamma = 1.8
        rgb_image_enhanced = np.power(rgb_image, 1 / gamma)
        return (rgb_image_enhanced * 255).astype(np.uint8)
    except Exception as e:
        print(f"Error generating RGB preview: {e}")
        import traceback
        traceback.print_exc()
        return None
def visualize_jp2(file_path):
    """
    Render the first band of a raster file as a viridis-coloured uint8 image.

    The band is min-max normalised before the colormap is applied. A band that is
    all zeros or constant cannot be normalised, so a black (H, W, 3) image of the
    raster's size is returned instead (with a printed warning).

    Returns:
        np.ndarray: uint8 array of shape (H, W, 3), or None if the file cannot be
        opened/read (the error is printed).
    """
    try:
        with rasterio.open(file_path) as dataset:
            band = dataset.read(1)
            # Constant data (including all zeros) has no range to normalise over.
            if np.all(band == 0) or np.ptp(band) == 0:
                print(f"Warning: Data in {file_path} is constant or zero. Cannot normalize.")
                return np.zeros((dataset.height, dataset.width, 3), dtype=np.uint8)
            low = np.min(band)
            value_range = np.max(band) - low
            normalized = (band - low) / value_range
            rgba = plt.get_cmap('viridis')(normalized)
            # Drop the alpha channel and scale to 8-bit for display.
            return (rgba[:, :, :3] * 255).astype(np.uint8)
    except Exception as e:
        print(f"Error visualizing JP2 file {file_path}: {e}")
        return None
def load_band(file_path, resample=False, target_height=None, target_width=None):
    """
    Load the first band of a raster file as a 2-D float32 array.

    Args:
        file_path: path of the raster file to open with rasterio.
        resample: when True AND both target dimensions are given, the band is read
            onto a (target_height, target_width) grid with bilinear resampling.
            If either dimension is None the flag is ignored and the band is read
            at its native size.
        target_height, target_width: output grid size used when resampling.

    Returns:
        np.ndarray: float32 array of shape (rows, cols).

    Raises:
        Any error from opening/reading the file, re-raised after printing.
    """
    try:
        with rasterio.open(file_path) as src:
            if resample and target_height is not None and target_width is not None:
                # Read ONLY band 1. A bare src.read(out_shape=(1, h, w)) would read
                # every band and fail on multi-band files, whose band count does
                # not match the single-channel out_shape.
                band_data = src.read(
                    [1],
                    out_shape=(1, target_height, target_width),
                    resampling=Resampling.bilinear,
                )[0].astype(np.float32)
            else:
                band_data = src.read(1).astype(np.float32)
        return band_data
    except Exception as e:
        print(f"Error loading band {file_path}: {e}")
        raise  # let the caller decide how to report the failure
def prepare_input_array(red_file, green_file, blue_file, nir_file):
    """
    Load the four bands and build both the local-model input and a display image.

    The red band defines the output grid. NIR is read with bilinear resampling
    onto that grid. Red, green and blue are read at native size and must already
    match it.

    Returns:
        prediction_array (np.ndarray): float32 stack of (red, green, nir) in
            channel-first (3, H, W) layout for the local model, or None on error.
        rgb_image_enhanced (np.ndarray): (H, W, 3) float image in [0, 1] with a
            per-band 98th-percentile stretch and gamma 1.8, or None on error.
        Both values are None together whenever anything fails; the cause is printed.
    """
    try:
        # Dimensions of the red band are the target grid for resampling NIR.
        with rasterio.open(red_file) as src:
            target_height = src.height
            target_width = src.width

        blue_data = load_band(blue_file)  # only needed for the RGB visualization
        green_data = load_band(green_file)
        red_data = load_band(red_file)
        nir_data = load_band(
            nir_file,
            resample=True,
            target_height=target_height,
            target_width=target_width,
        )

        # Validate every band BEFORE any array arithmetic. A mismatch then yields
        # this explicit diagnostic instead of an opaque broadcasting error from
        # the RGB composition below.
        if not (red_data.shape == green_data.shape == blue_data.shape == nir_data.shape):
            print("ERROR: Band shapes mismatch after loading/resampling!")
            print(
                f"Shapes - Red: {red_data.shape}, Green: {green_data.shape}, "
                f"Blue: {blue_data.shape}, NIR: {nir_data.shape}"
            )
            return None, None  # Indicate error

        # --- RGB image for visualization (float, 0-1) ---
        epsilon = 1e-10  # avoids division by zero for a zero percentile
        rgb_image = np.zeros((target_height, target_width, 3), dtype=np.float32)
        for channel, band in enumerate((red_data, green_data, blue_data)):
            positive = band[band > 0]
            band_max = np.percentile(positive, 98) if positive.size else 1.0
            rgb_image[:, :, channel] = np.clip(band / (band_max + epsilon), 0, 1)
        gamma = 1.8
        rgb_image_enhanced = np.power(rgb_image, 1 / gamma)

        # Channel-first (red, green, nir) stack for the local cloud-mask model.
        prediction_array = np.stack([red_data, green_data, nir_data], axis=0)
        print(f"Local prediction array shape: {prediction_array.shape}")
        print(f"RGB visualization image shape: {rgb_image_enhanced.shape}")
        return prediction_array, rgb_image_enhanced
    except Exception as e:
        print(f"Error during input preparation: {e}")
        import traceback
        traceback.print_exc()
        return None, None  # Indicate error
def visualize_cloud_mask(rgb_image, pred_mask):
    """
    Overlay a colour-coded cloud mask on an RGB image and append a legend.

    Args:
        rgb_image: (H, W, 3) float image with values in [0, 1]. It is clipped to
            [0, 1] and scaled by 255 before blending, so a 0-255 uint8 image
            would be rendered incorrectly.
        pred_mask: class mask of shape (H', W') or (1, H', W'); other shapes are
            squeezed if that yields 2-D. Classes: 0 clear, 1 thick cloud,
            2 thin cloud, 3 cloud shadow. Other values get no overlay colour.
            If (H', W') != (H, W), the mask is nearest-neighbour resized to (H, W).

    Returns:
        np.ndarray: uint8 image of shape (H + 100, W, 3), with the blended overlay
        stacked above a 100-pixel legend. Returns None if an input is missing or
        anything fails (the error is printed).
    """
    if rgb_image is None or pred_mask is None:
        print("Cannot visualize cloud mask: Missing RGB image or prediction mask.")
        return None
    try:
        pred_mask = np.asarray(pred_mask)
        # Reduce the mask to 2-D (H, W).
        if pred_mask.ndim == 3 and pred_mask.shape[0] == 1:
            pred_mask = np.squeeze(pred_mask, axis=0)
        elif pred_mask.ndim != 2:
            print(f"ERROR: Unexpected prediction mask dimension: {pred_mask.ndim}, shape: {pred_mask.shape}")
            try:
                pred_mask = np.squeeze(pred_mask)
                if pred_mask.ndim != 2:
                    raise ValueError("Still not 2D after squeeze")
            except Exception as sq_err:
                print(f"Could not convert mask to 2D: {sq_err}")
                return None  # Cannot visualize

        print(f"Visualization - RGB image shape: {rgb_image.shape}, Pred mask shape: {pred_mask.shape}")

        # Bring the mask onto the image grid if needed.
        target_height, target_width = rgb_image.shape[:2]
        if pred_mask.shape != (target_height, target_width):
            print(f"Warning: Resizing prediction mask from {pred_mask.shape} to {rgb_image.shape[:2]} for visualization.")
            if not np.issubdtype(pred_mask.dtype, np.integer):
                print("Warning: Prediction mask is not integer type, casting to uint8 for resize.")
                pred_mask = pred_mask.astype(np.uint8)
            # Nearest-neighbour resize via index arrays. It uses the same
            # floor(dst_index * src_size / dst_size) mapping as cv2.INTER_NEAREST,
            # but works for every integer dtype; cv2's NumPy bridge rejects some,
            # e.g. uint32. Class labels are preserved exactly.
            src_height, src_width = pred_mask.shape
            row_idx = np.arange(target_height) * src_height // target_height
            col_idx = np.arange(target_width) * src_width // target_width
            pred_mask = pred_mask[row_idx[:, np.newaxis], col_idx]
            print(f"Resized mask shape: {pred_mask.shape}")

        # Colours for each class (RGB order, matching rgb_image).
        colors = {
            0: [0, 255, 0],    # Clear - Green
            1: [255, 0, 0],    # Thick Cloud - Red
            2: [255, 255, 0],  # Thin Cloud - Yellow
            3: [0, 0, 255],    # Cloud Shadow - Blue
        }
        # Colour-coded mask; pixels with any other value stay 0 (no colour).
        mask_vis = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
        for class_idx, color in colors.items():
            mask_vis[pred_mask == class_idx] = color

        # Blend the mask over the image.
        alpha = 0.4  # Transparency of the mask overlay
        rgb_uint8 = (np.clip(rgb_image, 0, 1) * 255).astype(np.uint8)
        blended = cv2.addWeighted(rgb_uint8, 1 - alpha, mask_vis, alpha, 0)

        # --- Legend: white strip as wide as the image, one row per class ---
        legend_height = 100
        legend_width = blended.shape[1]
        legend = np.ones((legend_height, legend_width, 3), dtype=np.uint8) * 255
        legend_text = ["Clear", "Thick Cloud", "Thin Cloud", "Cloud Shadow"]
        legend_colors = [colors.get(i, [0, 0, 0]) for i in range(4)]
        box_size = 15
        text_offset_x = 40
        start_y = 15
        padding_y = 20
        for i, (text, color) in enumerate(zip(legend_text, legend_colors)):
            row_y = start_y + i * padding_y
            # Colour box
            cv2.rectangle(legend,
                          (10, row_y - box_size // 2),
                          (10 + box_size, row_y + box_size // 2),
                          color, -1)
            # Label, nudged down to sit level with the box
            cv2.putText(legend, text,
                        (text_offset_x, row_y + box_size // 4),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

        # Image above legend.
        final_output = np.vstack([blended, legend])
        return final_output
    except Exception as e:
        print(f"Error during visualization: {e}")
        import traceback
        traceback.print_exc()
        return None  # Return None if visualization fails
# --- Triton Client Functions (Adapted from paste-2.txt) ---
def is_triton_server_healthy(url=HTTP_TRITON_URL):
    """
    Check whether the Triton Inference Server at `url` reports itself live.

    Uses a short (10 s) connection timeout. This checks liveness only;
    `is_server_ready()` / `is_model_ready()` would be stricter checks if needed.
    The client is always closed afterwards so its connection pool is released.

    Returns:
        bool: True if the server is live. False if it is not live, the client
        library is missing, or the server cannot be reached (the error is printed).
    """
    if not TRITON_CLIENT_AVAILABLE:
        return False
    triton_client = None
    try:
        triton_client = httpclient.InferenceServerClient(url=url, connection_timeout=10.0)
        server_live = triton_client.is_server_live()
        if server_live:
            print(f"Triton server at {url} is live.")
        else:
            print(f"Triton server at {url} is not live.")
        return server_live
    except Exception as e:
        print(f"Could not connect to Triton server at {url}: {e}")
        return False
    finally:
        if triton_client is not None:
            try:
                triton_client.close()
            except Exception as close_err:
                # Cleanup is best-effort; the health result above already stands.
                print(f"Warning: failed to close Triton health-check client: {close_err}")
def get_jp2_bytes_for_triton(red_file_path, green_file_path, nir_file_path):
    """
    Read the raw, undecoded bytes of the Red, Green and NIR band files.

    The band order (Red, Green, NIR) must match what the Triton model expects.

    Returns:
        np.ndarray: object-dtype array of shape (3,), one bytes object per band.

    Raises:
        Whatever opening/reading a file raises (e.g. FileNotFoundError), re-raised
        after printing a message.
    """
    band_sources = (
        ("Red", red_file_path),
        ("Green", green_file_path),
        ("NIR", nir_file_path),
    )
    payloads = []
    for band_name, path in band_sources:
        try:
            with open(path, "rb") as handle:
                raw = handle.read()
            payloads.append(raw)
            print(f"Read {len(raw)} bytes for {band_name} band from {os.path.basename(path)}")
        except FileNotFoundError:
            print(f"ERROR: File not found: {path}")
            raise
        except Exception as e:
            print(f"ERROR: Could not read file {path}: {e}")
            raise

    # dtype=object keeps each bytes payload as a single element -> shape (3,).
    input_byte_array = np.array(payloads, dtype=object)
    print(f"Prepared Triton input byte array with shape: {input_byte_array.shape} and dtype: {input_byte_array.dtype}")
    return input_byte_array
def run_inference_triton_http(input_byte_array):
    """
    Send raw JP2 band bytes to the Triton model over HTTP and return its output.

    Args:
        input_byte_array: object-dtype NumPy array of raw band bytes. It is sent
            unchanged as a BYTES tensor whose shape is the array's shape.

    Returns:
        np.ndarray: the TRITON_OUTPUT_NAME tensor from the response.

    Raises:
        RuntimeError: the tritonclient library is not installed.
        ValueError: the response did not contain TRITON_OUTPUT_NAME.
        Any client/server error (e.g. InferenceServerException) is printed and
        re-raised so the caller can fall back to the local model.
    """
    if not TRITON_CLIENT_AVAILABLE:
        raise RuntimeError("Triton client library not available.")
    import uuid  # local import: only used to tag this request

    print("Attempting inference using Triton HTTP client...")
    try:
        client = httpclient.InferenceServerClient(
            url=HTTP_TRITON_URL,
            verbose=False,
            connection_timeout=TRITON_TIMEOUT_SECONDS,
            network_timeout=TRITON_TIMEOUT_SECONDS,
        )
    except Exception as e:
        print(f"ERROR: Couldn't create Triton HTTP client: {e}")
        raise  # Propagate error

    try:
        # BYTES input; binary_data=True sends the payloads in the binary extension.
        inputs = [httpclient.InferInput(TRITON_INPUT_NAME, list(input_byte_array.shape), "BYTES")]
        inputs[0].set_data_from_numpy(input_byte_array, binary_data=True)
        outputs = [httpclient.InferRequestedOutput(TRITON_OUTPUT_NAME, binary_data=True)]

        print(f"Sending inference request to Triton model '{TRITON_MODEL_NAME}' at {HTTP_TRITON_URL}...")
        response = client.infer(
            model_name=TRITON_MODEL_NAME,
            inputs=inputs,
            outputs=outputs,
            # Unique per request; the process id was identical for all concurrent requests.
            request_id=uuid.uuid4().hex,
            # tritonclient's `timeout` is in MICROSECONDS (and only honoured by
            # models using dynamic batching). Passing the seconds value directly
            # asked for a 300 us deadline.
            timeout=TRITON_TIMEOUT_SECONDS * 1_000_000,
        )
        print("Triton inference request successful.")
        mask = response.as_numpy(TRITON_OUTPUT_NAME)
        if mask is None:
            # as_numpy returns None when the named output is absent from the response.
            raise ValueError(f"Triton response did not contain output '{TRITON_OUTPUT_NAME}'")
        print(f"Received output mask from Triton with shape: {mask.shape}, dtype: {mask.dtype}")
        return mask
    except triton_utils.InferenceServerException as e:
        print(f"ERROR: Triton server failed inference: Status code {e.status()}, message: {e.message()}")
        print(f"Debug details: {e.debug_details()}")
        raise  # Propagate error to trigger fallback
    except Exception as e:
        print(f"ERROR: An unexpected error occurred during Triton HTTP inference: {e}")
        import traceback
        traceback.print_exc()
        raise  # Propagate error to trigger fallback
    finally:
        # Release the client's connection pool on every path.
        try:
            client.close()
        except Exception as close_err:
            print(f"Warning: failed to close Triton HTTP client: {close_err}")
# --- Main Processing Function with Fallback Logic ---
def _gradio_file_path(file_value):
    """Return the filesystem path of a Gradio file value (a path string or an object with `.name`)."""
    return file_value if isinstance(file_value, str) else file_value.name


def _format_mask_stats(flat_mask, source_label):
    """Return the per-class pixel-count report for a non-empty 2-D mask with classes 0-3."""
    total_pixels = flat_mask.size
    clear_pixels, thick_cloud_pixels, thin_cloud_pixels, cloud_shadow_pixels = (
        np.sum(flat_mask == class_idx) for class_idx in range(4)
    )

    def pct(count):
        return count / total_pixels * 100

    return (
        f"\nCloud Mask Statistics ({source_label}):\n"
        f"- Clear: {clear_pixels} pixels ({pct(clear_pixels):.2f}%)\n"
        f"- Thick Cloud: {thick_cloud_pixels} pixels ({pct(thick_cloud_pixels):.2f}%)\n"
        f"- Thin Cloud: {thin_cloud_pixels} pixels ({pct(thin_cloud_pixels):.2f}%)\n"
        f"- Cloud Shadow: {cloud_shadow_pixels} pixels ({pct(cloud_shadow_pixels):.2f}%)\n"
        f"- Total Cloud Cover (Thick+Thin): {pct(thick_cloud_pixels + thin_cloud_pixels):.2f}%\n"
    )


def process_satellite_images(red_file, green_file, blue_file, nir_file, batch_size, patch_size, patch_overlap):
    """
    Run cloud detection on four uploaded bands: Triton first, then the local model.

    Steps:
      1. Load the bands (RGB preview + local-model input).
      2. If the Triton client is installed and the server is live, send the raw
         Red/Green/NIR bytes to Triton.
      3. If Triton is unavailable, fails, or returns no mask, run the local
         omnicloudmask model with the given batch/patch parameters.
      4. Compute per-class statistics and an overlay visualization.

    Returns:
        tuple: (rgb_display_image or None, visualization or None, status text).
        The status text records which path was used and any failures.
    """
    if not all([red_file, green_file, blue_file, nir_file]):
        return None, None, "ERROR: Please upload all four channel files (Red, Green, Blue, NIR)"

    red_file_path = _gradio_file_path(red_file)
    green_file_path = _gradio_file_path(green_file)
    blue_file_path = _gradio_file_path(blue_file)
    nir_file_path = _gradio_file_path(nir_file)
    print("\n--- Starting Cloud Detection Process ---")
    print(f"Input files: R={os.path.basename(red_file_path)}, G={os.path.basename(green_file_path)}, B={os.path.basename(blue_file_path)}, N={os.path.basename(nir_file_path)}")

    pred_mask = None
    mask_source = None  # "Triton" or "Local": set only by the path that actually produced the mask
    status_message = ""

    # 1. Visualization image (always needed) and local input array (for fallback).
    print("Preparing visualization image and local model input array...")
    local_input_array, rgb_float_image = prepare_input_array(red_file_path, green_file_path, blue_file_path, nir_file_path)
    if rgb_float_image is None:
        print("ERROR: Failed to create RGB visualization image.")
        # Failure here most likely means the band files could not be loaded.
        return None, None, "ERROR: Failed to load or process input band files."
    rgb_display_image = (np.clip(rgb_float_image, 0, 1) * 255).astype(np.uint8)

    # 2. Triton availability.
    use_triton = False
    if TRITON_CLIENT_AVAILABLE:
        print(f"Checking Triton server health at {HTTP_TRITON_URL}...")
        if is_triton_server_healthy(HTTP_TRITON_URL):
            use_triton = True
        else:
            print("Triton server is not healthy or unavailable.")
            status_message += "Triton server unavailable. "
    else:
        print("Triton client library not installed. Skipping Triton check.")
        status_message += "Triton client not installed. "

    # 3. Triton inference.
    if use_triton:
        try:
            print("Preparing JP2 bytes for Triton...")
            triton_byte_input = get_jp2_bytes_for_triton(red_file_path, green_file_path, nir_file_path)
            pred_mask = run_inference_triton_http(triton_byte_input)
            if pred_mask is None:
                # Treat a missing output as a failure so the stats are never
                # labelled "Triton" for a mask the local model produced.
                raise ValueError("Triton returned no output mask")
            status_message += "Inference performed using Triton Server. "
            print("Triton inference successful.")
            mask_source = "Triton"
        except Exception as e:
            print(f"Triton inference failed: {e}. Falling back to local model.")
            status_message += f"Triton inference failed ({type(e).__name__}). "
            pred_mask = None  # ensure the fallback runs

    # 4. Local fallback.
    if pred_mask is None:
        status_message += "Falling back to local inference. "
        if not LOCAL_MODEL_AVAILABLE:
            status_message += "Local model not available. Cannot perform inference."
            print("ERROR: Local model could not be loaded.")
        elif local_input_array is None:
            status_message += "Local input data preparation failed. Cannot perform local inference."
            print("ERROR: Failed to prepare input array for local model.")
        else:
            print("Running local inference using omnicloudmask...")
            try:
                # Slider values may arrive as floats; the model parameters are counts/sizes.
                pred_mask = predict_from_array(
                    local_input_array,
                    batch_size=int(batch_size),
                    patch_size=int(patch_size),
                    patch_overlap=int(patch_overlap),
                )
                print(f"Local prediction successful. Output mask shape: {pred_mask.shape}, dtype: {pred_mask.dtype}")
                status_message += "Local inference successful."
                mask_source = "Local"
            except Exception as e:
                print(f"ERROR: Local inference failed: {e}")
                import traceback
                traceback.print_exc()
                status_message += f"Local inference FAILED: {e}"
                pred_mask = None

    if pred_mask is None:
        # Inference failed both ways.
        print("--- Cloud Detection Process Failed ---")
        return rgb_display_image, None, status_message + "\nERROR: Could not generate cloud mask."

    # 5. Normalise the mask to 2-D (H, W); reject other or empty shapes.
    pred_mask = np.asarray(pred_mask)
    if pred_mask.ndim == 3 and pred_mask.shape[0] == 1:
        flat_mask = np.squeeze(pred_mask, axis=0)
    elif pred_mask.ndim == 2:
        flat_mask = pred_mask
    else:
        flat_mask = None
    if flat_mask is None or flat_mask.size == 0:  # an empty mask would divide by zero in the stats
        print(f"ERROR: Unexpected mask shape after inference: {pred_mask.shape}")
        status_message += " ERROR: Invalid mask shape received."
        return rgb_display_image, None, status_message + "\nERROR: Could not process prediction mask."

    stats = _format_mask_stats(flat_mask, mask_source)
    status_message += f"\nMask stats calculated. Total pixels: {flat_mask.size}."

    print("Generating final visualization...")
    visualization = visualize_cloud_mask(rgb_float_image, flat_mask)
    if visualization is None:
        status_message += " ERROR: Failed to generate visualization."
    print("--- Cloud Detection Process Finished ---")
    return rgb_display_image, visualization, status_message + "\n" + stats
# --- Gradio Interface (from paste.txt) ---
def check_cpu_usage():
    """
    Return the current system-wide CPU usage as a display string.

    psutil.cpu_percent() without an interval compares against the previous call.
    The first reading is therefore a meaningless 0.0, and later ones average over
    the time between clicks. A short blocking sample gives an actual current value.
    """
    return f"CPU Usage: {psutil.cpu_percent(interval=0.5)}%"
# --- Build Gradio App ---
# Layout: band uploads + fallback-model parameters on the left, results on the
# right. Below them are a collapsible CPU monitor and, if the files exist on
# disk, a bundled example.
print("Building Gradio interface...")
with gr.Blocks(title="Satellite Cloud Detection (Triton/Local)") as demo:
    gr.Markdown("""
# Satellite Cloud Detection (with Triton Fallback)
Upload separate JP2 files for Red (e.g., B04), Green (e.g., B03), Blue (e.g., B02), and NIR (e.g., B8A) channels.
The application will **first attempt** to use a remote Triton Inference Server. If the server is unavailable or inference fails,
it will **fall back** to using the local OmniCloudMask model.
**Pixel Classification:**
- Clear (Green)
- Thick Cloud (Red)
- Thin Cloud (Yellow)
- Cloud Shadow (Blue)
The model works best with imagery at 10-50m resolution.
""")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input Bands (JP2)")
            # type="filepath" hands the handler plain path strings. Both the
            # rasterio loaders and the raw-byte reader for Triton need paths.
            red_input = gr.File(label="Red Channel (e.g., B04)", type="filepath")
            green_input = gr.File(label="Green Channel (e.g., B03)", type="filepath")
            blue_input = gr.File(label="Blue Channel (e.g., B02)", type="filepath")
            nir_input = gr.File(label="NIR Channel (e.g., B8A)", type="filepath")

            gr.Markdown("### Local Model Parameters (Used for Fallback)")
            batch_size = gr.Slider(
                minimum=1, maximum=32, value=4, step=1,
                label="Batch Size",
                info="Memory usage/speed for local model",
            )
            patch_size = gr.Slider(
                minimum=256, maximum=2048, value=1024, step=128,
                label="Patch Size",
                info="Patch size for local model processing",
            )
            patch_overlap = gr.Slider(
                minimum=64, maximum=512, value=256, step=64,
                label="Patch Overlap",
                info="Overlap for local model processing",
            )
            process_btn = gr.Button("Process Cloud Detection", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("### Results")
            rgb_output = gr.Image(label="Original RGB Image (Approx. True Color)", type="numpy")
            cloud_output = gr.Image(label="Cloud Detection Visualization (Mask Overlay)", type="numpy")
            stats_output = gr.Textbox(label="Processing Status & Statistics", lines=10)

    # Optional CPU monitor.
    with gr.Accordion("System Monitoring", open=False):
        cpu_button = gr.Button("Check CPU Usage")
        cpu_output = gr.Textbox(label="Current CPU Usage")
        cpu_button.click(fn=check_cpu_usage, inputs=None, outputs=cpu_output)

    # The example row and the main button drive the same handler. Its parameter
    # order (R, G, B, NIR, batch, patch, overlap) fixes the component order here.
    detection_inputs = [red_input, green_input, blue_input, nir_input, batch_size, patch_size, patch_overlap]
    detection_outputs = [rgb_output, cloud_output, stats_output]

    # Example bands shipped in the downloaded repository's jp2s/ folder (R, G, B, NIR).
    example_base = os.path.join(repo_dir, "jp2s")
    example_files = [
        os.path.join(example_base, band_file)
        for band_file in ("B04.jp2", "B03.jp2", "B02.jp2", "B8A.jp2")
    ]
    if not all(os.path.exists(path) for path in example_files):
        print(f"WARN: Example JP2 files not found in '{example_base}'. Skipping examples.")
    else:
        print("Adding examples...")
        gr.Examples(
            examples=[example_files + [4, 1024, 256]],  # band files + slider defaults
            inputs=detection_inputs,
            outputs=detection_outputs,
            fn=process_satellite_images,
            cache_examples=False,
        )

    process_btn.click(
        fn=process_satellite_images,
        inputs=detection_inputs,
        outputs=detection_outputs,
    )
# --- Launch the App ---
print("Launching Gradio app...")
# Enable request queueing with a default concurrency limit of 4 for event
# listeners, then serve. share=True would additionally create a public link.
demo.queue(default_concurrency_limit=4)
demo.launch(debug=True, share=False)