cloud-detection / omnicloudmask /data_loaders.py
Amir Erfan Eshratifar
model checkpoints, sample input, readme
241b6a2
import warnings
from pathlib import Path
from typing import Optional, Union
import numpy as np
import rasterio as rio
from rasterio.profiles import Profile
def load_s2(
    input_path: Union[Path, str],
    resolution: float = 10.0,
    required_bands: Optional[list[str]] = None,
) -> tuple[np.ndarray, Profile]:
    """Load a Sentinel-2 (L1C or L2A) image from a SAFE folder containing the bands.

    Args:
        input_path: Path to the SAFE folder. The folder name is used to
            derive the processing level.
        resolution: Target pixel size; must lie in the closed range [10, 50].
        required_bands: Band names to load, in output order. ``None`` (the
            default) selects ``["B04", "B03", "B8A"]``.

    Returns:
        The ``(data, profile)`` pair produced by ``open_s2_bands``.

    Raises:
        ValueError: If ``resolution`` is outside [10, 50]. Errors raised while
            deriving the processing level or locating bands propagate as-is.
    """
    if not 10 <= resolution <= 50:
        raise ValueError("Resolution must be between 10 and 50")
    # A fresh list per call avoids sharing one mutable default between callers.
    if required_bands is None:
        required_bands = ["B04", "B03", "B8A"]
    input_path = Path(input_path)
    processing_level = find_s2_processing_level(input_path)
    return open_s2_bands(input_path, processing_level, resolution, required_bands)
def find_s2_processing_level(
    input_path: Path,
) -> str:
    """Derive the processing level of a Sentinel-2 image from the folder name.

    The level is read from characters 3-5 of the second ``_``-separated token
    of the folder name (e.g. ``MSIL1C`` -> ``L1C``).

    Args:
        input_path: Path whose final component is the scene folder name.

    Returns:
        Either ``"L1C"`` or ``"L2A"``.

    Raises:
        ValueError: If the folder name has no second ``_``-separated token, or
            the extracted level is neither ``L1C`` nor ``L2A``.
    """
    folder_name = Path(input_path).name
    name_parts = folder_name.split("_")
    # Without this guard a name lacking "_" would escape as a bare IndexError.
    if len(name_parts) < 2:
        raise ValueError(
            f"Cannot derive processing level from folder name {folder_name!r}, "
            "expected L1C or L2A"
        )
    processing_level = name_parts[1][3:6]
    if processing_level not in ("L1C", "L2A"):
        raise ValueError(
            f"Processing level {processing_level} not recognized, expected L1C or L2A"
        )
    return processing_level
def open_s2_bands(
    input_path: Path,
    processing_level: str,
    resolution: float,
    required_bands: list[str],
) -> tuple[np.ndarray, Profile]:
    """Read the requested Sentinel-2 bands and stack them at one resolution.

    Args:
        input_path: Folder searched recursively for the band ``.jp2`` files.
        processing_level: ``"L1C"`` selects the ``*IMG_DATA/*<band>.jp2``
            search; any other value uses the ``<band>_<res>m.jp2`` search,
            trying 10, 20 and then 60.
        resolution: Target pixel size; bands whose native resolution differs
            are read with a rescaled ``out_shape``.
        required_bands: Band names to load, in output order.

    Returns:
        ``(data, profile)`` where ``data`` stacks one 2-D array per requested
        band and ``profile`` is the profile of the last band read, with
        ``transform``, ``height`` and ``width`` overwritten to match ``data``.

    Raises:
        ValueError: If ``required_bands`` is empty or a band file is not found.
    """
    if not required_bands:
        # An empty request would otherwise fail later with an unbound profile.
        raise ValueError("At least one band must be requested")

    bands = []
    for band_name in required_bands:
        # next(..., None) stops the recursive search at the first match.
        if processing_level == "L1C":
            band = next(input_path.rglob(f"*IMG_DATA/*{band_name}.jp2"), None)
        else:
            band = None
            for search_resolution in [10, 20, 60]:
                band = next(
                    input_path.rglob(f"*{band_name}_{search_resolution}m.jp2"),
                    None,
                )
                if band is not None:
                    break
        if band is None:
            raise ValueError(f"Band {band_name} not found in {input_path}")

        with rio.open(band) as src:
            profile = src.profile
            native_resolution = int(src.res[0])
            scale_factor = native_resolution / resolution
            if native_resolution == resolution:
                bands.append(src.read(1))
            else:
                bands.append(
                    src.read(
                        1,
                        out_shape=(
                            int(src.height * scale_factor),
                            int(src.width * scale_factor),
                        ),
                    )
                )

    # The profile kept is the last band's; keep its origin, set the new pixel size.
    profile["transform"] = rio.transform.from_origin(  # type: ignore
        profile["transform"][2],
        profile["transform"][5],
        resolution,
        resolution,
    )
    data = np.array(bands)
    profile["height"] = data.shape[1]
    profile["width"] = data.shape[2]
    return data, profile
def load_multiband(
    input_path: Union[Path, str],
    resample_res: Optional[float] = None,
    band_order: Optional[list[int]] = None,
) -> tuple[np.ndarray, Profile]:
    """Load a multiband image and resample it to requested resolution.

    Args:
        input_path: Path of the raster to open.
        resample_res: Target pixel size. ``None`` or ``0`` reads the raster at
            its native size.
        band_order: Band indexes passed to the dataset's ``read``. When
            ``None`` a ``UserWarning`` is issued and ``[1, 2, 3]`` is used.

    Returns:
        ``(data, profile)``. When the output size differs from the source,
        ``height``, ``width`` and ``transform`` in the profile are updated to
        describe ``data``; other profile entries are left as read.

    Raises:
        ValueError: If ``resample_res`` is negative.
    """
    if resample_res is not None and resample_res < 0:
        raise ValueError("resample_res must be positive")
    if band_order is None:
        warnings.warn(
            "No band order provided, using default [1, 2, 3] (RGN)", UserWarning
        )
        band_order = [1, 2, 3]

    input_path = Path(input_path)

    with rio.open(input_path) as src:
        out_height, out_width = src.height, src.width
        if resample_res:
            # rasterio's `res` is (pixel width, pixel height): rows scale with
            # the y resolution, columns with the x resolution.
            x_res, y_res = src.res
            out_height = int(src.height * (y_res / resample_res))
            out_width = int(src.width * (x_res / resample_res))

        data = src.read(
            band_order,
            out_shape=(len(band_order), out_height, out_width),
            resampling=rio.enums.Resampling.nearest,  # type: ignore
        )
        profile = src.profile
        if (out_height, out_width) != (src.height, src.width):
            # Keep georeferencing consistent with the resampled array.
            profile["transform"] = src.transform * src.transform.scale(
                src.width / out_width,
                src.height / out_height,
            )
            profile["height"] = out_height
            profile["width"] = out_width
    return data, profile
def load_ls8(
    input_path: Union[Path, str],
    resolution: int = 30,
    required_bands=None,
) -> tuple[np.ndarray, Profile]:
    """Load a Landsat 8 image from a folder containing the bands.

    Args:
        input_path: Folder searched recursively for ``*<band>.TIF`` files.
        resolution: Must be 30; any other value is rejected.
        required_bands: Band names to load, in output order. ``None`` (the
            default) selects ``["B4", "B3", "B5"]``.

    Returns:
        ``(data, profile)`` where ``data`` stacks one 2-D array per requested
        band and ``profile`` is the profile of the first band opened.

    Raises:
        ValueError: If ``resolution`` is not 30 or a band file is not found.
    """
    if resolution != 30:
        raise ValueError("Resolution must be 30")
    # A fresh list per call avoids sharing one mutable default between callers.
    if required_bands is None:
        required_bands = ["B4", "B3", "B5"]

    input_path = Path(input_path)

    # Resolve every path first so a missing band fails before any raster is read.
    band_files = {}
    for band_name in required_bands:
        band = next(input_path.rglob(f"*{band_name}.TIF"), None)
        if band is None:
            raise ValueError(f"Band {band_name} not found in {input_path}")
        band_files[band_name] = band

    data = []
    profile = Profile()
    for band_name in required_bands:
        with rio.open(band_files[band_name]) as src:
            # An empty Profile is falsy, so only the first band's profile is kept.
            if not profile:
                profile = src.profile
            data.append(src.read(1))

    return np.array(data), profile