import os import glob import warnings import numpy as np from natsort import natsorted from datetime import datetime to_date = lambda string: datetime.strptime(string, "%Y-%m-%d") S1_LAUNCH = to_date("2014-04-03") # s2cloudless: see https://github.com/sentinel-hub/sentinel2-cloud-detector from s2cloudless import S2PixelCloudDetector import rasterio from rasterio.merge import merge from scipy.ndimage import gaussian_filter from torch.utils.data import Dataset # import sys # sys.path.append(".") from util.detect_cloudshadow import get_cloud_mask, get_shadow_mask # utility functions used in the dataloaders of SEN12MS-CR and SEN12MS-CR-TS def read_tif(path_IMG): tif = rasterio.open(path_IMG) return tif def read_img(tif): return tif.read().astype(np.float32) def rescale(img, oldMin, oldMax): oldRange = oldMax - oldMin img = (img - oldMin) / oldRange return img def process_MS(img, method): if method == "default": intensity_min, intensity_max = ( 0, 10000, ) # define a reasonable range of MS intensities img = np.clip( img, intensity_min, intensity_max ) # intensity clipping to a global unified MS intensity range img = rescale( img, intensity_min, intensity_max ) # project to [0,1], preserve global intensities (across patches), gets mapped to [-1,+1] in wrapper if method == "resnet": intensity_min, intensity_max = ( 0, 10000, ) # define a reasonable range of MS intensities img = np.clip( img, intensity_min, intensity_max ) # intensity clipping to a global unified MS intensity range img /= 2000 # project to [0,5], preserve global intensities (across patches) img = np.nan_to_num(img) return img def process_SAR(img, method): if method == "default": dB_min, dB_max = -25, 0 # define a reasonable range of SAR dB img = np.clip( img, dB_min, dB_max ) # intensity clipping to a global unified SAR dB range img = rescale( img, dB_min, dB_max ) # project to [0,1], preserve global intensities (across patches), gets mapped to [-1,+1] in wrapper if method == "resnet": # project SAR to [0, 2] range dB_min, dB_max = [-25.0, -32.5], [0, 0] img = np.concatenate( [ ( 2 * (np.clip(img[0], dB_min[0], dB_max[0]) - dB_min[0]) / (dB_max[0] - dB_min[0]) )[None, ...], ( 2 * (np.clip(img[1], dB_min[1], dB_max[1]) - dB_min[1]) / (dB_max[1] - dB_min[1]) )[None, ...], ], axis=0, ) img = np.nan_to_num(img) return img def get_cloud_cloudshadow_mask(img, cloud_threshold=0.2): cloud_mask = get_cloud_mask(img, cloud_threshold, binarize=True) shadow_mask = get_shadow_mask(img) # encode clouds and shadows as segmentation masks cloud_cloudshadow_mask = np.zeros_like(cloud_mask) cloud_cloudshadow_mask[shadow_mask < 0] = -1 cloud_cloudshadow_mask[cloud_mask > 0] = 1 # label clouds and shadows cloud_cloudshadow_mask[cloud_cloudshadow_mask != 0] = 1 return cloud_cloudshadow_mask # recursively apply function to nested dictionary def iterdict(dictionary, fct): for k, v in dictionary.items(): if isinstance(v, dict): dictionary[k] = iterdict(v, fct) else: dictionary[k] = fct(v) return dictionary def get_cloud_map(img, detector, instance=None): # get cloud masks img = np.clip(img, 0, 10000) mask = np.ones((img.shape[-1], img.shape[-1])) # note: if your model may suffer from dark pixel artifacts, # you may consider adjusting these filtering parameters if not (img.mean() < 1e-5 and img.std() < 1e-5): if detector == "cloud_cloudshadow_mask": threshold = 0.2 # set to e.g. 0.2 or 0.4 mask = get_cloud_cloudshadow_mask(img, threshold) elif detector == "s2cloudless_map": threshold = 0.5 mask = instance.get_cloud_probability_maps( np.moveaxis(img / 10000, 0, -1)[None, ...] )[0, ...] mask[mask < threshold] = 0 mask = gaussian_filter(mask, sigma=2) elif detector == "s2cloudless_mask": mask = instance.get_cloud_masks(np.moveaxis(img / 10000, 0, -1)[None, ...])[ 0, ... ] else: mask = np.ones((img.shape[-1], img.shape[-1])) warnings.warn(f"Method {detector} not yet implemented!") else: warnings.warn(f"Encountered a blank sample, defaulting to cloudy mask.") return mask.astype(np.float32) # function to fetch paired data, which may differ in modalities or dates def get_pairedS1(patch_list, root_dir, mod=None, time=None): paired_list = [] for patch in patch_list: seed, roi, modality, time_number, fname = patch.split("/") time = time_number if time is None else time # unless overwriting, ... mod = ( modality if mod is None else mod ) # keep the patch list's original time and modality n_patch = fname.split("patch_")[-1].split(".tif")[0] paired_dir = os.path.join(seed, roi, mod.upper(), str(time)) candidates = os.path.join( root_dir, paired_dir, f"{mod}_{seed}_{roi}_ImgNo_{time}_*_patch_{n_patch}.tif", ) paired_list.append( os.path.join(paired_dir, os.path.basename(glob.glob(candidates)[0])) ) return paired_list """ SEN12MSCR data loader class, inherits from torch.utils.data.Dataset IN: root: str, path to your copy of the SEN12MS-CR-TS data set split: str, in [all | train | val | test] region: str, [all | africa | america | asiaEast | asiaWest | europa] cloud_masks: str, type of cloud mask detector to run on optical data, in [] sample_type: str, [generic | cloudy_cloudfree] n_input_samples: int, number of input samples in time series rescale_method: str, [default | resnet] OUT: data_loader: SEN12MSCRTS instance, implements an iterator that can be traversed via __getitem__(pdx), which returns the pdx-th dictionary of patch-samples (whose structure depends on sample_type) """ class SEN12MSCR(Dataset): def __init__( self, root, split="all", region="all", cloud_masks="s2cloudless_mask", sample_type="pretrain", rescale_method="default", ): self.root_dir = root # set root directory which contains all ROI self.region = region # region according to which the ROI are selected if self.region != "all": raise NotImplementedError # TODO: currently only supporting 'all' self.ROI = { "ROIs1158": ["106"], "ROIs1868": [ "17", "36", "56", "73", "85", "100", "114", "119", "121", "126", "127", "139", "142", "143", ], "ROIs1970": [ "20", "21", "35", "40", "57", "65", "71", "82", "83", "91", "112", "116", "119", "128", "132", "133", "135", "139", "142", "144", "149", ], "ROIs2017": [ "8", "22", "25", "32", "49", "61", "63", "69", "75", "103", "108", "115", "116", "117", "130", "140", "146", ], } # define splits conform with SEN12MS-CR-TS self.splits = {} self.splits["train"] = [ "ROIs1970_fall_s1/s1_3", "ROIs1970_fall_s1/s1_22", "ROIs1970_fall_s1/s1_148", "ROIs1970_fall_s1/s1_107", "ROIs1970_fall_s1/s1_1", "ROIs1970_fall_s1/s1_114", "ROIs1970_fall_s1/s1_135", "ROIs1970_fall_s1/s1_40", "ROIs1970_fall_s1/s1_42", "ROIs1970_fall_s1/s1_31", "ROIs1970_fall_s1/s1_149", "ROIs1970_fall_s1/s1_64", "ROIs1970_fall_s1/s1_28", "ROIs1970_fall_s1/s1_144", "ROIs1970_fall_s1/s1_57", "ROIs1970_fall_s1/s1_35", "ROIs1970_fall_s1/s1_133", "ROIs1970_fall_s1/s1_30", "ROIs1970_fall_s1/s1_134", "ROIs1970_fall_s1/s1_141", "ROIs1970_fall_s1/s1_112", "ROIs1970_fall_s1/s1_116", "ROIs1970_fall_s1/s1_37", "ROIs1970_fall_s1/s1_26", "ROIs1970_fall_s1/s1_77", "ROIs1970_fall_s1/s1_100", "ROIs1970_fall_s1/s1_83", "ROIs1970_fall_s1/s1_71", "ROIs1970_fall_s1/s1_93", "ROIs1970_fall_s1/s1_119", "ROIs1970_fall_s1/s1_104", "ROIs1970_fall_s1/s1_136", "ROIs1970_fall_s1/s1_6", "ROIs1970_fall_s1/s1_41", "ROIs1970_fall_s1/s1_125", "ROIs1970_fall_s1/s1_91", "ROIs1970_fall_s1/s1_131", "ROIs1970_fall_s1/s1_120", "ROIs1970_fall_s1/s1_110", "ROIs1970_fall_s1/s1_19", "ROIs1970_fall_s1/s1_14", "ROIs1970_fall_s1/s1_81", "ROIs1970_fall_s1/s1_39", "ROIs1970_fall_s1/s1_109", "ROIs1970_fall_s1/s1_33", "ROIs1970_fall_s1/s1_88", "ROIs1970_fall_s1/s1_11", "ROIs1970_fall_s1/s1_128", "ROIs1970_fall_s1/s1_142", "ROIs1970_fall_s1/s1_122", "ROIs1970_fall_s1/s1_4", "ROIs1970_fall_s1/s1_27", "ROIs1970_fall_s1/s1_147", "ROIs1970_fall_s1/s1_85", "ROIs1970_fall_s1/s1_82", "ROIs1970_fall_s1/s1_105", "ROIs1158_spring_s1/s1_9", "ROIs1158_spring_s1/s1_1", "ROIs1158_spring_s1/s1_124", "ROIs1158_spring_s1/s1_40", "ROIs1158_spring_s1/s1_101", "ROIs1158_spring_s1/s1_21", "ROIs1158_spring_s1/s1_134", "ROIs1158_spring_s1/s1_145", "ROIs1158_spring_s1/s1_141", "ROIs1158_spring_s1/s1_66", "ROIs1158_spring_s1/s1_8", "ROIs1158_spring_s1/s1_26", "ROIs1158_spring_s1/s1_77", "ROIs1158_spring_s1/s1_113", "ROIs1158_spring_s1/s1_100", "ROIs1158_spring_s1/s1_117", "ROIs1158_spring_s1/s1_119", "ROIs1158_spring_s1/s1_6", "ROIs1158_spring_s1/s1_58", "ROIs1158_spring_s1/s1_120", "ROIs1158_spring_s1/s1_110", "ROIs1158_spring_s1/s1_126", "ROIs1158_spring_s1/s1_115", "ROIs1158_spring_s1/s1_121", "ROIs1158_spring_s1/s1_39", "ROIs1158_spring_s1/s1_109", "ROIs1158_spring_s1/s1_63", "ROIs1158_spring_s1/s1_75", "ROIs1158_spring_s1/s1_132", "ROIs1158_spring_s1/s1_128", "ROIs1158_spring_s1/s1_142", "ROIs1158_spring_s1/s1_15", "ROIs1158_spring_s1/s1_45", "ROIs1158_spring_s1/s1_97", "ROIs1158_spring_s1/s1_147", "ROIs1868_summer_s1/s1_90", "ROIs1868_summer_s1/s1_87", "ROIs1868_summer_s1/s1_25", "ROIs1868_summer_s1/s1_124", "ROIs1868_summer_s1/s1_114", "ROIs1868_summer_s1/s1_135", "ROIs1868_summer_s1/s1_40", "ROIs1868_summer_s1/s1_101", "ROIs1868_summer_s1/s1_42", "ROIs1868_summer_s1/s1_31", "ROIs1868_summer_s1/s1_36", "ROIs1868_summer_s1/s1_139", "ROIs1868_summer_s1/s1_56", "ROIs1868_summer_s1/s1_133", "ROIs1868_summer_s1/s1_55", "ROIs1868_summer_s1/s1_43", "ROIs1868_summer_s1/s1_113", "ROIs1868_summer_s1/s1_76", "ROIs1868_summer_s1/s1_123", "ROIs1868_summer_s1/s1_143", "ROIs1868_summer_s1/s1_93", "ROIs1868_summer_s1/s1_125", "ROIs1868_summer_s1/s1_89", "ROIs1868_summer_s1/s1_120", "ROIs1868_summer_s1/s1_126", "ROIs1868_summer_s1/s1_72", "ROIs1868_summer_s1/s1_115", "ROIs1868_summer_s1/s1_121", "ROIs1868_summer_s1/s1_146", "ROIs1868_summer_s1/s1_140", "ROIs1868_summer_s1/s1_95", "ROIs1868_summer_s1/s1_102", "ROIs1868_summer_s1/s1_7", "ROIs1868_summer_s1/s1_11", "ROIs1868_summer_s1/s1_132", "ROIs1868_summer_s1/s1_15", "ROIs1868_summer_s1/s1_137", "ROIs1868_summer_s1/s1_4", "ROIs1868_summer_s1/s1_27", "ROIs1868_summer_s1/s1_147", "ROIs1868_summer_s1/s1_86", "ROIs1868_summer_s1/s1_47", "ROIs2017_winter_s1/s1_68", "ROIs2017_winter_s1/s1_25", "ROIs2017_winter_s1/s1_62", "ROIs2017_winter_s1/s1_135", "ROIs2017_winter_s1/s1_42", "ROIs2017_winter_s1/s1_64", "ROIs2017_winter_s1/s1_21", "ROIs2017_winter_s1/s1_55", "ROIs2017_winter_s1/s1_112", "ROIs2017_winter_s1/s1_116", "ROIs2017_winter_s1/s1_8", "ROIs2017_winter_s1/s1_59", "ROIs2017_winter_s1/s1_49", "ROIs2017_winter_s1/s1_104", "ROIs2017_winter_s1/s1_81", "ROIs2017_winter_s1/s1_146", "ROIs2017_winter_s1/s1_75", "ROIs2017_winter_s1/s1_94", "ROIs2017_winter_s1/s1_102", "ROIs2017_winter_s1/s1_61", "ROIs2017_winter_s1/s1_47", "ROIs1868_summer_s1/s1_100", # note: this ROI is also used for testing in SEN12MS-CR-TS. If you wish to combine both datasets, please comment out this line ] self.splits["val"] = [ "ROIs2017_winter_s1/s1_22", "ROIs1868_summer_s1/s1_19", "ROIs1970_fall_s1/s1_65", "ROIs1158_spring_s1/s1_17", "ROIs2017_winter_s1/s1_107", "ROIs1868_summer_s1/s1_80", "ROIs1868_summer_s1/s1_127", "ROIs2017_winter_s1/s1_130", "ROIs1868_summer_s1/s1_17", "ROIs2017_winter_s1/s1_84", ] self.splits["test"] = [ "ROIs1158_spring_s1/s1_106", "ROIs1158_spring_s1/s1_123", "ROIs1158_spring_s1/s1_140", "ROIs1158_spring_s1/s1_31", "ROIs1158_spring_s1/s1_44", "ROIs1868_summer_s1/s1_119", "ROIs1868_summer_s1/s1_73", "ROIs1970_fall_s1/s1_139", "ROIs2017_winter_s1/s1_108", "ROIs2017_winter_s1/s1_63", ] self.splits["all"] = ( self.splits["train"] + self.splits["test"] + self.splits["val"] ) self.split = split assert split in [ "all", "train", "val", "test", ], "Input dataset must be either assigned as all, train, test, or val!" assert sample_type in ["pretrain"], "Input data must be pretrain!" assert cloud_masks in [ None, "cloud_cloudshadow_mask", "s2cloudless_map", "s2cloudless_mask", ], "Unknown cloud mask type!" self.modalities = ["S1", "S2"] self.cloud_masks = cloud_masks # e.g. 'cloud_cloudshadow_mask', 's2cloudless_map', 's2cloudless_mask' self.sample_type = sample_type # e.g. 'pretrain' self.time_points = range(1) self.n_input_t = 1 # specifies the number of samples, if only part of the time series is used as an input if self.cloud_masks in ["s2cloudless_map", "s2cloudless_mask"]: self.cloud_detector = S2PixelCloudDetector( threshold=0.4, all_bands=True, average_over=4, dilation_size=2 ) else: self.cloud_detector = None self.paths = self.get_paths() self.n_samples = len(self.paths) # raise a warning if no data has been found if not self.n_samples: self.throw_warn() self.method = rescale_method # indexes all patches contained in the current data split def get_paths( self, ): # assuming for the same ROI+num, the patch numbers are the same print(f"\nProcessing paths for {self.split} split of region {self.region}") paths = [] seeds_S1 = natsorted( [s1dir for s1dir in os.listdir(self.root_dir) if "_s1" in s1dir] ) for seed in seeds_S1: rois_S1 = natsorted(os.listdir(os.path.join(self.root_dir, seed))) for roi in rois_S1: roi_dir = os.path.join(self.root_dir, seed, roi) paths_S1 = natsorted( [os.path.join(roi_dir, s1patch) for s1patch in os.listdir(roi_dir)] ) paths_S2 = [ patch.replace("/s1", "/s2").replace("_s1", "_s2") for patch in paths_S1 ] paths_S2_cloudy = [ patch.replace("/s1", "/s2_cloudy").replace("_s1", "_s2_cloudy") for patch in paths_S1 ] for pdx, _ in enumerate(paths_S1): # omit patches that are potentially unpaired if not all( [ os.path.isfile(paths_S1[pdx]), os.path.isfile(paths_S2[pdx]), os.path.isfile(paths_S2_cloudy[pdx]), ] ): continue # don't add patch if not belonging to the selected split if not any( [ split_roi in paths_S1[pdx] for split_roi in self.splits[self.split] ] ): continue sample = { "S1": paths_S1[pdx], "S2": paths_S2[pdx], "S2_cloudy": paths_S2_cloudy[pdx], } paths.append(sample) return paths def __getitem__(self, pdx): # get the triplet of patch with ID pdx s1_tif = read_tif(self.paths[pdx]["S1"]) s2_tif = read_tif(self.paths[pdx]["S2"]) s2_cloudy_tif = read_tif(self.paths[pdx]["S2_cloudy"]) coord = list(s2_tif.bounds) s1 = process_SAR(read_img(s1_tif), self.method) s2 = read_img(s2_tif) # note: pre-processing happens after cloud detection s2_cloudy = read_img( s2_cloudy_tif ) # note: pre-processing happens after cloud detection mask = ( None if not self.cloud_masks else get_cloud_map(s2_cloudy, self.cloud_masks, self.cloud_detector) ) sample = { "input": { "S1": s1, "S2": process_MS(s2_cloudy, self.method), "masks": mask, "coverage": np.mean(mask), "S1 path": os.path.join(self.root_dir, self.paths[pdx]["S1"]), "S2 path": os.path.join(self.root_dir, self.paths[pdx]["S2_cloudy"]), "coord": coord, }, "target": { "S2": process_MS(s2, self.method), "S2 path": os.path.join(self.root_dir, self.paths[pdx]["S2"]), "coord": coord, }, } return sample def throw_warn(self): warnings.warn( """No data samples found! Please use the following directory structure: path/to/your/SEN12MSCR/directory: ├───ROIs1158_spring_s1 | ├─s1_1 | | |... | | ├─ROIs1158_spring_s1_1_p407.tif | | |... | ... ├───ROIs1158_spring_s2 | ├─s2_1 | | |... | | ├─ROIs1158_spring_s2_1_p407.tif | | |... | ... ├───ROIs1158_spring_s2_cloudy | ├─s2_cloudy_1 | | |... | | ├─ROIs1158_spring_s2_cloudy_1_p407.tif | | |... | ... ... Note: Please arrange the dataset in a format as e.g. provided by the script dl_data.sh. """ ) def __len__(self): # length of generated list return self.n_samples if __name__ == "__main__": dataset = SEN12MSCR( root="data2/SEN12MSCR", split="all", region="all", cloud_masks="s2cloudless_mask", sample_type="pretrain", rescale_method="default", ) for each in dataset: print(f"{each['input']['S1'].shape}") print(f"{each['input']['S2'].shape}") print(f"{each['input']['masks'].shape}") print(f"{each['target']['S2'].shape}") # (2, 256, 256) # (13, 256, 256) # (256, 256) # (13, 256, 256) break