|
import os |
|
import glob |
|
import warnings |
|
import numpy as np |
|
from natsort import natsorted |
|
|
|
from datetime import datetime |
|
|
|
def to_date(string):
    """Parse an ISO ``YYYY-MM-DD`` string into a naive ``datetime`` (midnight)."""
    return datetime.strptime(string, "%Y-%m-%d")


# Sentinel-1 launch date, as a naive datetime.
S1_LAUNCH = to_date("2014-04-03")
|
|
|
|
|
from s2cloudless import S2PixelCloudDetector |
|
|
|
import rasterio |
|
from rasterio.merge import merge |
|
from scipy.ndimage import gaussian_filter |
|
from torch.utils.data import Dataset |
|
|
|
|
|
|
|
from util.detect_cloudshadow import get_cloud_mask, get_shadow_mask |
|
|
|
|
|
|
|
def read_tif(path_IMG):
    """Open the raster at ``path_IMG`` with rasterio and hand back the dataset.

    The dataset is returned still open; closing it (for example with
    ``with read_tif(path) as ds:``) is the caller's responsibility.
    """
    return rasterio.open(path_IMG)
|
|
|
|
|
def read_img(tif):
    """Read the dataset via ``tif.read()`` and return the data as float32."""
    raw = tif.read()
    return raw.astype(np.float32)
|
|
|
|
|
def rescale(img, oldMin, oldMax):
    """Linearly map values from ``[oldMin, oldMax]`` onto ``[0, 1]``.

    No clipping is applied, so inputs outside the range land outside
    ``[0, 1]``. ``oldMax == oldMin`` divides by zero.
    """
    return (img - oldMin) / (oldMax - oldMin)
|
|
|
|
|
def process_MS(img, method):
    """Clip and rescale a multispectral image.

    Args:
        img: array of intensities; values are clipped to [0, 10000].
        method: "default" maps [0, 10000] linearly onto [0, 1];
            "resnet" divides by 2000, mapping [0, 10000] onto [0, 5];
            any other value leaves the intensities unscaled.

    Returns:
        The processed array, with NaN replaced by 0 via ``np.nan_to_num``.
    """
    intensity_min, intensity_max = 0, 10000
    if method == "default":
        img = np.clip(img, intensity_min, intensity_max)
        img = rescale(img, intensity_min, intensity_max)
    elif method == "resnet":
        img = np.clip(img, intensity_min, intensity_max)
        # Out-of-place division: an in-place `/=` raises on integer-typed
        # input because the float result cannot be cast back into the array.
        img = img / 2000
    return np.nan_to_num(img)
|
|
|
|
|
def process_SAR(img, method):
    """Clip and rescale a SAR image given in dB.

    Args:
        img: array whose first axis indexes channels.
        method: "default" clips every value to [-25, 0] dB and maps it onto
            [0, 1]; "resnet" clips channel 0 to [-25, 0] dB and channel 1 to
            [-32.5, 0] dB and maps each onto [0, 2] (only the first two
            channels are kept); any other value leaves the values unchanged.

    Returns:
        The processed array, with NaN replaced by 0 via ``np.nan_to_num``.
    """
    if method == "default":
        lo, hi = -25, 0
        img = rescale(np.clip(img, lo, hi), lo, hi)
    elif method == "resnet":
        # (min, max) dB bounds for channel 0 and channel 1 respectively.
        channel_bounds = ((-25.0, 0), (-32.5, 0))
        scaled = [
            (2 * (np.clip(img[idx], lo, hi) - lo) / (hi - lo))[None, ...]
            for idx, (lo, hi) in enumerate(channel_bounds)
        ]
        img = np.concatenate(scaled, axis=0)
    return np.nan_to_num(img)
|
|
|
|
|
def get_cloud_cloudshadow_mask(img, cloud_threshold=0.2):
    """Return a binary cloud-or-shadow mask for ``img``.

    Args:
        img: image passed unchanged to ``get_cloud_mask`` and
            ``get_shadow_mask``.
        cloud_threshold: threshold forwarded to ``get_cloud_mask``.

    Returns:
        Array with the shape and dtype of the cloud mask: 1 where the cloud
        mask is positive or the shadow mask is negative, 0 elsewhere.
    """
    cloud_mask = get_cloud_mask(img, cloud_threshold, binarize=True)
    shadow_mask = get_shadow_mask(img)

    # Clouds and shadows are merged into one label. Writing 1 directly
    # (instead of a temporary -1 for shadows that is later overwritten)
    # also avoids an out-of-range assignment for unsigned mask dtypes.
    combined = np.zeros_like(cloud_mask)
    combined[(cloud_mask > 0) | (shadow_mask < 0)] = 1
    return combined
|
|
|
|
|
|
|
def iterdict(dictionary, fct):
    """Apply ``fct`` to every non-dict value of a (possibly nested) dict.

    Nested dicts are visited recursively, and every other value ``v`` is
    replaced in place by ``fct(v)``. The same, mutated dict is returned.
    """
    for key in dictionary:
        value = dictionary[key]
        dictionary[key] = (
            iterdict(value, fct) if isinstance(value, dict) else fct(value)
        )
    return dictionary
|
|
|
|
|
def get_cloud_map(img, detector, instance=None):
    """Return a float32 cloud map / cloud mask for a multispectral image.

    Args:
        img: image whose last two axes are spatial (H, W); values are clipped
            to [0, 10000] first. The s2cloudless paths move axis 0 to the end,
            so img is presumably (bands, H, W) -- TODO confirm the band order
            matches what the detector instance expects.
        detector: "cloud_cloudshadow_mask", "s2cloudless_map" or
            "s2cloudless_mask". Any other value emits a warning and yields an
            all-ones (fully cloudy) mask.
        instance: object providing ``get_cloud_probability_maps`` /
            ``get_cloud_masks``; required by the two s2cloudless detectors.

    Returns:
        float32 array of shape (H, W). A blank sample (mean and std both below
        1e-5 after clipping) triggers a warning and an all-ones mask.
    """
    img = np.clip(img, 0, 10000)
    # Fallback "fully cloudy" mask, sized from both spatial axes so that
    # non-square images get a mask of matching shape.
    mask = np.ones(img.shape[-2:])

    if img.mean() < 1e-5 and img.std() < 1e-5:
        warnings.warn("Encountered a blank sample, defaulting to cloudy mask.")
        return mask.astype(np.float32)

    if detector == "cloud_cloudshadow_mask":
        threshold = 0.2
        mask = get_cloud_cloudshadow_mask(img, threshold)
    elif detector == "s2cloudless_map":
        threshold = 0.5
        # Channel axis moved last, batch axis added, values scaled by 1/10000.
        mask = instance.get_cloud_probability_maps(
            np.moveaxis(img / 10000, 0, -1)[None, ...]
        )[0, ...]
        # Zero out low probabilities, then smooth the remaining map.
        mask[mask < threshold] = 0
        mask = gaussian_filter(mask, sigma=2)
    elif detector == "s2cloudless_mask":
        mask = instance.get_cloud_masks(
            np.moveaxis(img / 10000, 0, -1)[None, ...]
        )[0, ...]
    else:
        warnings.warn(f"Method {detector} not yet implemented!")

    return mask.astype(np.float32)
|
|
|
|
|
|
|
def get_pairedS1(patch_list, root_dir, mod=None, time=None):
    """Find, for each patch, its counterpart of another modality / time index.

    Each entry of ``patch_list`` is a relative path
    ``"<seed>/<roi>/<modality>/<time>/<fname>"`` where ``fname`` contains
    ``patch_<n>.tif``. For each entry the function globs
    ``<root_dir>/<seed>/<roi>/<MOD>/<time>/<mod>_<seed>_<roi>_ImgNo_<time>_*_patch_<n>.tif``
    and keeps the first match in sorted order.

    Args:
        patch_list: relative patch paths in the format above.
        root_dir: directory the relative paths are resolved against.
        mod: modality of the paired patches (e.g. "s1"); defaults to each
            patch's own modality.
        time: time index of the paired patches; defaults to each patch's own
            time index.

    Returns:
        List of paths relative to ``root_dir``, one per input patch, in order.

    Raises:
        FileNotFoundError: if no file matches the pattern for some patch.
    """
    paired_list = []
    for patch in patch_list:
        seed, roi, modality, time_number, fname = patch.split("/")
        # Resolve defaults per patch so each entry falls back to its own
        # modality / time index rather than the first entry's.
        cur_time = time_number if time is None else time
        cur_mod = modality if mod is None else mod
        n_patch = fname.split("patch_")[-1].split(".tif")[0]
        paired_dir = os.path.join(seed, roi, cur_mod.upper(), str(cur_time))
        # Escape the literal parts so glob metacharacters in names/paths are
        # matched verbatim; only the date slot is a wildcard.
        pattern = os.path.join(
            glob.escape(os.path.join(root_dir, paired_dir)),
            glob.escape(f"{cur_mod}_{seed}_{roi}_ImgNo_{cur_time}_")
            + "*"
            + glob.escape(f"_patch_{n_patch}.tif"),
        )
        # glob returns matches in arbitrary order; sort for a deterministic pick.
        matches = sorted(glob.glob(pattern))
        if not matches:
            raise FileNotFoundError(f"No paired file matches {pattern}")
        paired_list.append(os.path.join(paired_dir, os.path.basename(matches[0])))
    return paired_list
|
|
|
|
|
|
|
|
|
|
|
|
|
""" SEN12MSCR data loader class, inherits from torch.utils.data.Dataset

IN:
root: str, path to your copy of the SEN12MS-CR data set
split: str, in [all | train | val | test]
region: str, [all | africa | america | asiaEast | asiaWest | europa]; currently only all is implemented (other values raise NotImplementedError)
cloud_masks: str or None, type of cloud mask detector to run on the cloudy optical data, in [None | cloud_cloudshadow_mask | s2cloudless_map | s2cloudless_mask]
sample_type: str, currently only [pretrain] is supported
rescale_method: str, [default | resnet]

OUT:
data_loader: SEN12MSCR instance, implements an iterator that can be traversed via __getitem__(pdx),
which returns the pdx-th dictionary of patch-samples with an "input" entry (S1, cloudy S2, cloud mask, ...) and a "target" entry (cloud-free S2, ...)
"""
|
|
|
|
|
class SEN12MSCR(Dataset):
    """SEN12MS-CR data set of paired S1 / cloud-free S2 / cloudy S2 patches.

    Args:
        root: path to the data set root (layout shown in ``throw_warn``).
        split: one of "all", "train", "val", "test".
        region: only "all" is implemented.
        cloud_masks: None, "cloud_cloudshadow_mask", "s2cloudless_map" or
            "s2cloudless_mask"; detector applied to the cloudy S2 input.
        sample_type: only "pretrain" is accepted.
        rescale_method: forwarded to ``process_SAR`` / ``process_MS``
            (e.g. "default" or "resnet").

    Indexing returns a dict with an "input" entry (S1, cloudy S2, cloud mask,
    mask coverage, file paths, raster bounds) and a "target" entry
    (cloud-free S2, file path, raster bounds).
    """

    def __init__(
        self,
        root,
        split="all",
        region="all",
        cloud_masks="s2cloudless_mask",
        sample_type="pretrain",
        rescale_method="default",
    ):
        self.root_dir = root
        self.region = region
        if self.region != "all":
            raise NotImplementedError(
                f"Region {region!r} is not implemented; only 'all' is supported."
            )

        # ROI ids per seed; not referenced elsewhere in this class.
        self.ROI = {
            "ROIs1158": ["106"],
            "ROIs1868": [
                "17", "36", "56", "73", "85", "100", "114",
                "119", "121", "126", "127", "139", "142", "143",
            ],
            "ROIs1970": [
                "20", "21", "35", "40", "57", "65", "71",
                "82", "83", "91", "112", "116", "119", "128",
                "132", "133", "135", "139", "142", "144", "149",
            ],
            "ROIs2017": [
                "8", "22", "25", "32", "49", "61", "63", "69", "75",
                "103", "108", "115", "116", "117", "130", "140", "146",
            ],
        }

        # Split definitions as "<seed folder>/<roi folder>" relative paths;
        # get_paths matches them exactly against the directories on disk.
        self.splits = {}
        self.splits["train"] = [
            "ROIs1970_fall_s1/s1_3",
            "ROIs1970_fall_s1/s1_22",
            "ROIs1970_fall_s1/s1_148",
            "ROIs1970_fall_s1/s1_107",
            "ROIs1970_fall_s1/s1_1",
            "ROIs1970_fall_s1/s1_114",
            "ROIs1970_fall_s1/s1_135",
            "ROIs1970_fall_s1/s1_40",
            "ROIs1970_fall_s1/s1_42",
            "ROIs1970_fall_s1/s1_31",
            "ROIs1970_fall_s1/s1_149",
            "ROIs1970_fall_s1/s1_64",
            "ROIs1970_fall_s1/s1_28",
            "ROIs1970_fall_s1/s1_144",
            "ROIs1970_fall_s1/s1_57",
            "ROIs1970_fall_s1/s1_35",
            "ROIs1970_fall_s1/s1_133",
            "ROIs1970_fall_s1/s1_30",
            "ROIs1970_fall_s1/s1_134",
            "ROIs1970_fall_s1/s1_141",
            "ROIs1970_fall_s1/s1_112",
            "ROIs1970_fall_s1/s1_116",
            "ROIs1970_fall_s1/s1_37",
            "ROIs1970_fall_s1/s1_26",
            "ROIs1970_fall_s1/s1_77",
            "ROIs1970_fall_s1/s1_100",
            "ROIs1970_fall_s1/s1_83",
            "ROIs1970_fall_s1/s1_71",
            "ROIs1970_fall_s1/s1_93",
            "ROIs1970_fall_s1/s1_119",
            "ROIs1970_fall_s1/s1_104",
            "ROIs1970_fall_s1/s1_136",
            "ROIs1970_fall_s1/s1_6",
            "ROIs1970_fall_s1/s1_41",
            "ROIs1970_fall_s1/s1_125",
            "ROIs1970_fall_s1/s1_91",
            "ROIs1970_fall_s1/s1_131",
            "ROIs1970_fall_s1/s1_120",
            "ROIs1970_fall_s1/s1_110",
            "ROIs1970_fall_s1/s1_19",
            "ROIs1970_fall_s1/s1_14",
            "ROIs1970_fall_s1/s1_81",
            "ROIs1970_fall_s1/s1_39",
            "ROIs1970_fall_s1/s1_109",
            "ROIs1970_fall_s1/s1_33",
            "ROIs1970_fall_s1/s1_88",
            "ROIs1970_fall_s1/s1_11",
            "ROIs1970_fall_s1/s1_128",
            "ROIs1970_fall_s1/s1_142",
            "ROIs1970_fall_s1/s1_122",
            "ROIs1970_fall_s1/s1_4",
            "ROIs1970_fall_s1/s1_27",
            "ROIs1970_fall_s1/s1_147",
            "ROIs1970_fall_s1/s1_85",
            "ROIs1970_fall_s1/s1_82",
            "ROIs1970_fall_s1/s1_105",
            "ROIs1158_spring_s1/s1_9",
            "ROIs1158_spring_s1/s1_1",
            "ROIs1158_spring_s1/s1_124",
            "ROIs1158_spring_s1/s1_40",
            "ROIs1158_spring_s1/s1_101",
            "ROIs1158_spring_s1/s1_21",
            "ROIs1158_spring_s1/s1_134",
            "ROIs1158_spring_s1/s1_145",
            "ROIs1158_spring_s1/s1_141",
            "ROIs1158_spring_s1/s1_66",
            "ROIs1158_spring_s1/s1_8",
            "ROIs1158_spring_s1/s1_26",
            "ROIs1158_spring_s1/s1_77",
            "ROIs1158_spring_s1/s1_113",
            "ROIs1158_spring_s1/s1_100",
            "ROIs1158_spring_s1/s1_117",
            "ROIs1158_spring_s1/s1_119",
            "ROIs1158_spring_s1/s1_6",
            "ROIs1158_spring_s1/s1_58",
            "ROIs1158_spring_s1/s1_120",
            "ROIs1158_spring_s1/s1_110",
            "ROIs1158_spring_s1/s1_126",
            "ROIs1158_spring_s1/s1_115",
            "ROIs1158_spring_s1/s1_121",
            "ROIs1158_spring_s1/s1_39",
            "ROIs1158_spring_s1/s1_109",
            "ROIs1158_spring_s1/s1_63",
            "ROIs1158_spring_s1/s1_75",
            "ROIs1158_spring_s1/s1_132",
            "ROIs1158_spring_s1/s1_128",
            "ROIs1158_spring_s1/s1_142",
            "ROIs1158_spring_s1/s1_15",
            "ROIs1158_spring_s1/s1_45",
            "ROIs1158_spring_s1/s1_97",
            "ROIs1158_spring_s1/s1_147",
            "ROIs1868_summer_s1/s1_90",
            "ROIs1868_summer_s1/s1_87",
            "ROIs1868_summer_s1/s1_25",
            "ROIs1868_summer_s1/s1_124",
            "ROIs1868_summer_s1/s1_114",
            "ROIs1868_summer_s1/s1_135",
            "ROIs1868_summer_s1/s1_40",
            "ROIs1868_summer_s1/s1_101",
            "ROIs1868_summer_s1/s1_42",
            "ROIs1868_summer_s1/s1_31",
            "ROIs1868_summer_s1/s1_36",
            "ROIs1868_summer_s1/s1_139",
            "ROIs1868_summer_s1/s1_56",
            "ROIs1868_summer_s1/s1_133",
            "ROIs1868_summer_s1/s1_55",
            "ROIs1868_summer_s1/s1_43",
            "ROIs1868_summer_s1/s1_113",
            "ROIs1868_summer_s1/s1_76",
            "ROIs1868_summer_s1/s1_123",
            "ROIs1868_summer_s1/s1_143",
            "ROIs1868_summer_s1/s1_93",
            "ROIs1868_summer_s1/s1_125",
            "ROIs1868_summer_s1/s1_89",
            "ROIs1868_summer_s1/s1_120",
            "ROIs1868_summer_s1/s1_126",
            "ROIs1868_summer_s1/s1_72",
            "ROIs1868_summer_s1/s1_115",
            "ROIs1868_summer_s1/s1_121",
            "ROIs1868_summer_s1/s1_146",
            "ROIs1868_summer_s1/s1_140",
            "ROIs1868_summer_s1/s1_95",
            "ROIs1868_summer_s1/s1_102",
            "ROIs1868_summer_s1/s1_7",
            "ROIs1868_summer_s1/s1_11",
            "ROIs1868_summer_s1/s1_132",
            "ROIs1868_summer_s1/s1_15",
            "ROIs1868_summer_s1/s1_137",
            "ROIs1868_summer_s1/s1_4",
            "ROIs1868_summer_s1/s1_27",
            "ROIs1868_summer_s1/s1_147",
            "ROIs1868_summer_s1/s1_86",
            "ROIs1868_summer_s1/s1_47",
            "ROIs2017_winter_s1/s1_68",
            "ROIs2017_winter_s1/s1_25",
            "ROIs2017_winter_s1/s1_62",
            "ROIs2017_winter_s1/s1_135",
            "ROIs2017_winter_s1/s1_42",
            "ROIs2017_winter_s1/s1_64",
            "ROIs2017_winter_s1/s1_21",
            "ROIs2017_winter_s1/s1_55",
            "ROIs2017_winter_s1/s1_112",
            "ROIs2017_winter_s1/s1_116",
            "ROIs2017_winter_s1/s1_8",
            "ROIs2017_winter_s1/s1_59",
            "ROIs2017_winter_s1/s1_49",
            "ROIs2017_winter_s1/s1_104",
            "ROIs2017_winter_s1/s1_81",
            "ROIs2017_winter_s1/s1_146",
            "ROIs2017_winter_s1/s1_75",
            "ROIs2017_winter_s1/s1_94",
            "ROIs2017_winter_s1/s1_102",
            "ROIs2017_winter_s1/s1_61",
            "ROIs2017_winter_s1/s1_47",
            "ROIs1868_summer_s1/s1_100",
        ]
        self.splits["val"] = [
            "ROIs2017_winter_s1/s1_22",
            "ROIs1868_summer_s1/s1_19",
            "ROIs1970_fall_s1/s1_65",
            "ROIs1158_spring_s1/s1_17",
            "ROIs2017_winter_s1/s1_107",
            "ROIs1868_summer_s1/s1_80",
            "ROIs1868_summer_s1/s1_127",
            "ROIs2017_winter_s1/s1_130",
            "ROIs1868_summer_s1/s1_17",
            "ROIs2017_winter_s1/s1_84",
        ]
        self.splits["test"] = [
            "ROIs1158_spring_s1/s1_106",
            "ROIs1158_spring_s1/s1_123",
            "ROIs1158_spring_s1/s1_140",
            "ROIs1158_spring_s1/s1_31",
            "ROIs1158_spring_s1/s1_44",
            "ROIs1868_summer_s1/s1_119",
            "ROIs1868_summer_s1/s1_73",
            "ROIs1970_fall_s1/s1_139",
            "ROIs2017_winter_s1/s1_108",
            "ROIs2017_winter_s1/s1_63",
        ]
        self.splits["all"] = (
            self.splits["train"] + self.splits["test"] + self.splits["val"]
        )
        self.split = split

        assert split in [
            "all",
            "train",
            "val",
            "test",
        ], "Input dataset must be either assigned as all, train, test, or val!"
        assert sample_type in ["pretrain"], "Input data must be pretrain!"
        assert cloud_masks in [
            None,
            "cloud_cloudshadow_mask",
            "s2cloudless_map",
            "s2cloudless_mask",
        ], "Unknown cloud mask type!"

        self.modalities = ["S1", "S2"]
        self.cloud_masks = cloud_masks
        self.sample_type = sample_type

        # One time point per sample.
        self.time_points = range(1)
        self.n_input_t = 1

        if self.cloud_masks in ["s2cloudless_map", "s2cloudless_mask"]:
            self.cloud_detector = S2PixelCloudDetector(
                threshold=0.4, all_bands=True, average_over=4, dilation_size=2
            )
        else:
            self.cloud_detector = None

        self.paths = self.get_paths()
        self.n_samples = len(self.paths)
        if not self.n_samples:
            self.throw_warn()

        self.method = rescale_method

    def _counterpart_path(self, seed, roi, fname, tag):
        """Return the path of the ``tag`` ("s2" / "s2_cloudy") twin of an S1 patch.

        The renaming "/s1" -> "/<tag>" then "_s1" -> "_<tag>" is applied to
        the part below ``self.root_dir`` only. A root path that itself contains
        "/s1" or "_s1" is therefore never rewritten, and the result is joined
        with the platform path separator.
        """
        # The leading "/" mirrors the separator that precedes the seed folder.
        relative = f"/{seed}/{roi}/{fname}"
        relative = relative.replace("/s1", f"/{tag}").replace("_s1", f"_{tag}")
        return os.path.join(self.root_dir, *relative.split("/")[1:])

    def get_paths(
        self,
    ):
        """Collect (S1, S2, S2_cloudy) file paths for the selected split.

        Walks ``<root>/<seed with "_s1">/<roi>/<patch>`` and derives the S2
        and cloudy S2 counterparts with ``_counterpart_path``. A patch is kept
        only if its ``"<seed>/<roi>"`` is listed in ``self.splits[self.split]``
        and all three files exist.

        Returns:
            list of dicts with keys "S1", "S2", "S2_cloudy" (paths joined onto
            ``self.root_dir``).
        """
        print(f"\nProcessing paths for {self.split} split of region {self.region}")

        # Exact membership test: a substring test would let e.g.
        # "ROIs1970_fall_s1/s1_1" also match s1_10 ... s1_149 and leak ROIs
        # across splits.
        split_rois = set(self.splits[self.split])

        paths = []
        seeds_S1 = natsorted(
            [s1dir for s1dir in os.listdir(self.root_dir) if "_s1" in s1dir]
        )
        for seed in seeds_S1:
            for roi in natsorted(os.listdir(os.path.join(self.root_dir, seed))):
                if f"{seed}/{roi}" not in split_rois:
                    continue
                roi_dir = os.path.join(self.root_dir, seed, roi)
                for fname in natsorted(os.listdir(roi_dir)):
                    sample = {
                        "S1": os.path.join(roi_dir, fname),
                        "S2": self._counterpart_path(seed, roi, fname, "s2"),
                        "S2_cloudy": self._counterpart_path(
                            seed, roi, fname, "s2_cloudy"
                        ),
                    }
                    # Omit patches that are potentially unpaired.
                    if all(os.path.isfile(p) for p in sample.values()):
                        paths.append(sample)
        return paths

    def __getitem__(self, pdx):
        """Load, preprocess and return the ``pdx``-th sample.

        Raises IndexError for an out-of-range ``pdx`` (which also ends plain
        ``for`` iteration over the dataset). "masks" and "coverage" are None
        when ``cloud_masks`` is None.
        """
        sample_paths = self.paths[pdx]

        # Context managers close the rasterio datasets as soon as the data has
        # been read instead of leaving the file handles open.
        with read_tif(sample_paths["S1"]) as s1_tif, \
                read_tif(sample_paths["S2"]) as s2_tif, \
                read_tif(sample_paths["S2_cloudy"]) as s2_cloudy_tif:
            coord = list(s2_tif.bounds)
            s1 = process_SAR(read_img(s1_tif), self.method)
            s2 = read_img(s2_tif)
            s2_cloudy = read_img(s2_cloudy_tif)

        mask = (
            None
            if not self.cloud_masks
            else get_cloud_map(s2_cloudy, self.cloud_masks, self.cloud_detector)
        )
        coverage = None if mask is None else np.mean(mask)

        # Paths from get_paths already start with root_dir; re-joining them
        # onto root_dir would duplicate a relative root.
        sample = {
            "input": {
                "S1": s1,
                "S2": process_MS(s2_cloudy, self.method),
                "masks": mask,
                "coverage": coverage,
                "S1 path": sample_paths["S1"],
                "S2 path": sample_paths["S2_cloudy"],
                "coord": coord,
            },
            "target": {
                "S2": process_MS(s2, self.method),
                "S2 path": sample_paths["S2"],
                "coord": coord,
            },
        }
        return sample

    def throw_warn(self):
        """Warn that no samples were found and show the expected directory layout."""
        warnings.warn(
            """No data samples found! Please use the following directory structure:



path/to/your/SEN12MSCR/directory:

├───ROIs1158_spring_s1

| ├─s1_1

| | |...

| | ├─ROIs1158_spring_s1_1_p407.tif

| | |...

| ...

├───ROIs1158_spring_s2

| ├─s2_1

| | |...

| | ├─ROIs1158_spring_s2_1_p407.tif

| | |...

| ...

├───ROIs1158_spring_s2_cloudy

| ├─s2_cloudy_1

| | |...

| | ├─ROIs1158_spring_s2_cloudy_1_p407.tif

| | |...

| ...

...



Note: Please arrange the dataset in a format as e.g. provided by the script dl_data.sh.

"""
        )

    def __len__(self):
        """Return the number of paired samples found for the selected split."""
        return self.n_samples
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the full data set and print the array shapes of the
    # first sample only (nothing is printed if the data set is empty).
    dataset = SEN12MSCR(
        root="data2/SEN12MSCR",
        split="all",
        region="all",
        cloud_masks="s2cloudless_mask",
        sample_type="pretrain",
        rescale_method="default",
    )
    first = next(iter(dataset), None)
    if first is not None:
        for group, key in (
            ("input", "S1"),
            ("input", "S2"),
            ("input", "masks"),
            ("target", "S2"),
        ):
            print(f"{first[group][key].shape}")
|
|