cloud-detection / omnicloudmask /raster_utils.py
Amir Erfan Eshratifar
model checkpoints, sample input, readme
241b6a2
from pathlib import Path
from typing import Optional
import numpy as np
import rasterio as rio
from rasterio.profiles import Profile
from .model_utils import channel_norm
def _slide_window(
    line_has_data: np.ndarray, start: int, stop: int, limit: int
) -> tuple[int, int]:
    """Slide the 1D window ``[start, stop)`` off nodata lines at its edges.

    ``line_has_data[i]`` says whether line ``i`` of the current window (a row
    or a column of the patch) holds any valid value. The window is only moved
    when at least one of its two edge lines holds data: first forward by one
    step per leading empty line (while ``stop < limit``), then backward by one
    step per trailing empty line (while ``start > 0``). The window length never
    changes.
    """
    if not (line_has_data[0] or line_has_data[-1]):
        # Both edge lines are nodata: leave the window where it is.
        return start, stop
    first, last = 0, len(line_has_data) - 1
    # Leading edge empty: slide forward. Stops at the latest on the last line,
    # which is known to hold data in that case.
    while not line_has_data[first] and stop < limit:
        first += 1
        start += 1
        stop += 1
    # Trailing edge empty: slide backward. Stops at the latest on the first
    # line, which is known to hold data in that case.
    while not line_has_data[last] and start > 0:
        last -= 1
        start -= 1
        stop -= 1
    return start, stop


def get_patch(
    input_array: np.ndarray,
    index: tuple,
    no_data_value: Optional[int] = 0,
) -> tuple[Optional[np.ndarray], Optional[tuple[int, int, int, int]]]:
    """Extract a patch from a 3D array, nudge it away from nodata, and normalize it.

    Args:
        input_array: 3D array sliced as ``input_array[:, row, col]``. The first
            axis is presumably bands/channels (TODO confirm against callers).
        index: ``(top, bottom, left, right)`` bounds of the patch, with bottom
            and right exclusive (slice semantics).
        no_data_value: Value marking missing pixels, or ``None`` when the input
            has no nodata value. When ``None``, no nodata handling is done.

    Returns:
        ``(normalized_patch, index)``, where ``index`` holds the (possibly
        shifted and trimmed) bounds actually used. Returns ``(None, None)``
        when the patch holds only zeros or only ``no_data_value``.

    If the patch contains nodata and one of its edge rows (or edge columns)
    holds data, the window is slid, staying inside the array, by the number of
    nodata rows/columns at the opposite edge, so that more valid pixels fall
    inside it.
    """
    assert input_array.ndim == 3, "Input array must have 3 dimensions"
    top, bottom, left, right = index
    patch = input_array[:, top:bottom, left:right].astype(np.float32)

    # All-zero (or empty) patches carry no information. np.any is used instead
    # of ``sum() == 0`` so positive and negative values cannot cancel out.
    if not np.any(patch):
        return None, None

    if no_data_value is not None:
        nodata = patch == no_data_value
        if np.all(nodata):
            return None, None

        if np.any(nodata):
            max_bottom, max_right = input_array.shape[1:3]
            # Per-pixel map: True where at least one band holds a valid value.
            has_data = np.any(~nodata, axis=0)
            top, bottom = _slide_window(
                has_data.any(axis=1), top, bottom, max_bottom
            )
            left, right = _slide_window(
                has_data.any(axis=0), left, right, max_right
            )
            patch = input_array[:, top:bottom, left:right].astype(np.float32)

    # The requested bounds may run past the array edge; report the extent
    # that was actually sliced.
    index = (top, top + patch.shape[1], left, left + patch.shape[2])
    return channel_norm(patch, no_data_value), index
def mask_prediction(
    scene: np.ndarray, pred_tracker_np: np.ndarray, no_data_value: int = 0
) -> np.ndarray:
    """Zero out predictions at pixels where the scene has nodata.

    A pixel is kept only if every band of ``scene`` differs from
    ``no_data_value`` there. At all other pixels, every channel of
    ``pred_tracker_np`` is multiplied by 0.

    Note: ``pred_tracker_np`` is modified in place, and that same array object
    is returned.

    Args:
        scene: 3D array whose first axis is reduced over (presumably bands —
            TODO confirm).
        pred_tracker_np: 3D array with the same trailing two dimensions as
            ``scene``.
        no_data_value: Value marking missing scene pixels.
    """
    assert scene.ndim == 3, "Scene must have 3 dimensions"
    assert pred_tracker_np.ndim == 3, "Prediction tracker must have 3 dimensions"
    same_footprint = scene.shape[1:] == pred_tracker_np.shape[1:]
    assert same_footprint, "Scene and prediction tracker must have the same shape"

    valid_pixels = (scene != no_data_value).all(axis=0)
    # The 2D mask broadcasts across the leading axis of the predictions.
    pred_tracker_np *= valid_pixels.astype(np.uint8)
    return pred_tracker_np
def _window_starts(length: int, patch_size: int, stride: int) -> list[int]:
    """Return unique, ascending window start offsets that cover ``[0, length)``."""
    last_start = length - patch_size
    starts = list(range(0, last_start + 1, stride))
    # The regular grid may stop short of the far edge; add one window flush
    # with it (never a duplicate, unlike clamping every overshooting start).
    if starts[-1] != last_start:
        starts.append(last_start)
    return starts


def make_patch_indexes(
    array_width: int,
    array_height: int,
    patch_size: int = 1000,
    patch_overlap: int = 300,
) -> list[tuple[int, int, int, int]]:
    """Create a list of patch indexes for a given shape and patch size.

    Patches are laid out on a grid with step ``patch_size - patch_overlap``.
    On each axis the last patch is pushed back to sit flush with the array
    edge. Each patch appears exactly once, in row-major order.

    Args:
        array_width: Number of columns; bounds the ``left``/``right`` values.
        array_height: Number of rows; bounds the ``top``/``bottom`` values.
        patch_size: Side length of each square patch.
        patch_overlap: Number of pixels shared by neighbouring patches.

    Returns:
        List of ``(top, bottom, left, right)`` tuples, with bottom and right
        exclusive.
    """
    assert patch_size > patch_overlap, "Patch size must be greater than patch overlap"
    assert patch_overlap >= 0, "Patch overlap must be greater than or equal to 0"
    assert patch_size > 0, "Patch size must be greater than 0"
    assert (
        patch_size <= array_width
    ), "Patch size must be less than or equal to array width"
    assert (
        patch_size <= array_height
    ), "Patch size must be less than or equal to array height"

    stride = patch_size - patch_overlap
    row_starts = _window_starts(array_height, patch_size, stride)
    col_starts = _window_starts(array_width, patch_size, stride)
    return [
        (top, top + patch_size, left, left + patch_size)
        for top in row_starts
        for left in col_starts
    ]
def save_prediction(
    output_path: Path, export_profile: Profile, pred_tracker_np: np.ndarray
) -> None:
    """Write ``pred_tracker_np`` to ``output_path`` as a raster dataset.

    The dataset is opened in write mode ("w"), with ``export_profile``
    expanded as keyword arguments to ``rasterio.open``. The whole array is
    written in a single call. Assumes the array's shape and dtype match the
    profile's count/height/width/dtype — TODO confirm at the call site.
    """
    with rio.open(output_path, "w", **export_profile) as raster_out:
        raster_out.write(pred_tracker_np)