diffcr-models/UnCRtainTS/util/pre_compute_data_samples.py
"""
Python script to pre-compute cloud coverage statistics on the data of SEN12MS-CR-TS.
The data loader performs online sampling of input and target patches depending on its flags
(e.g.: split, region, n_input_samples, min_cov, max_cov, ) and the patches' calculated cloud coverage.
If using sampler='random', patches can also vary across epochs to act as data augmentation mechanism.
However, online computing of cloud masks can slow down data loading. A solution is to pre-compute
cloud coverage an relief the dataloader from re-computing each sample, which is what this script offers.
Currently, pre-calculated statistics are exported in an *.npy file, a collection of which is readily
available for download via https://syncandshare.lrz.de/getlink/fiHhwCqr7ch3X39XoGYaUGM8/splits
Pre-computed statistics can be imported via the dataloader's "import_data_path" argument.
"""
import os
import sys
import time
import random
import numpy as np
from tqdm import tqdm
import resource
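# raise the soft limit on the number of open file descriptors: loading long Sentinel
# time series keeps many GeoTIFF file handles open simultaneously (assumption: the
# default soft limit, often 1024, is too low for this dataset)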
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
# see: https://docs.python.org/3/library/resource.html#resource.RLIM_INFINITY
resource.setrlimit(resource.RLIMIT_NOFILE, (int(1024*1e3), rlimit[1]))
import torch
dirname = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(dirname))
from data.dataLoader import SEN12MSCRTS
# fix all RNG seeds
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
g = torch.Generator()
g.manual_seed(seed)
pathify = lambda path_list: [os.path.join(*path[0].split('/')[-6:]) for path in path_list]
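# note on pathify (illustration, not part of the export logic): the DataLoader wraps
# each path string in a singleton batch, hence path[0]; keeping only the last six
# path components presumably strips the absolute prefix so that the exported paths
# are relative to the dataset root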
if __name__ == '__main__':

    # main parameters for instantiating SEN12MS-CR-TS
    root        = '/home/data/SEN12MSCRTS'  # path to your copy of SEN12MS-CR-TS
    split       = 'test'     # ROI to sample from, one of [all | train | val | test]
    input_t     = 3          # number of input time points to sample (irrelevant if choosing sample_type='generic')
    region      = 'all'      # region of the input data, one of [all | africa | america | asiaEast | asiaWest | europa]
    sample_type = 'generic'  # type of samples returned, one of [cloudy_cloudfree | generic]
    import_data_path = None  # path to the suppl. file specifying which time points to load for input and output,
                             # e.g. os.path.join(os.getcwd(), 'util', '3_test_s2cloudless_mask.npy')
    export_data_path = os.path.join(dirname, 'precomputed')  # e.g. ...'/3_all_train_vary_s2cloudless_mask.npy'

    vary     = 'random' if split != 'test' else 'fixed'  # whether to vary samples across epochs or not
    n_epochs = 1 if vary == 'fixed' or sample_type == 'generic' else 30  # if dates do not vary across epochs, a single epoch suffices
    max_samples = int(1e9)

    shuffle = False
    if export_data_path is not None:
        # when exporting data indices to file, DataLoader shuffling must stay disabled, else pdx are not sorted
        # (samples may still get shuffled when importing the file later, which may or may not be desired)
        shuffle = False

    sen12mscrts = SEN12MSCRTS(root, split=split, sample_type=sample_type, n_input_samples=input_t,
                              region=region, sampler=vary, import_data_path=import_data_path)

    # instantiate the dataloader; note: worker_init_fn is needed to get reproducible random samples across runs if vary='random'
    # note: when using 'export_data_path', keep batch_size at 1 (unless moving the data writing out of the dataloader)
    #       and shuffle=False (patches are then processed in order, but later imports can still shuffle them)
    dataloader = torch.utils.data.DataLoader(sen12mscrts, batch_size=1, shuffle=shuffle,
                                             worker_init_fn=seed_worker, generator=g, num_workers=0)
    if export_data_path is not None:
        data_pairs = {}  # collect pre-computed dates in a dict to be exported
    epoch_count = 0      # epoch counter, for loading time points that vary across epochs
    collect_var = []     # collect the variance across S2 intensities

    # iterate over the data to pre-compute indices, e.g. for training or testing
    start_timer = time.time()
    for epoch in range(1, n_epochs + 1):
        print(f'\nCurating indices for epoch {epoch}/{n_epochs}.')
        for pdx, patch in enumerate(tqdm(dataloader)):
            # stop sampling when the maximum sample count is exceeded
            if pdx >= max_samples: break

            if sample_type == 'generic':
                # collect the variance of each sample's S2 intensities; the grand average variance is computed after the loop
                collect_var.append(torch.stack(patch['S2']).var())

            if export_data_path is not None:
                if sample_type == 'cloudy_cloudfree':
                    # compute an epoch-sensitive index, such that exported dates can differ across epochs
                    adj_pdx = epoch_count * len(dataloader.dataset) + pdx
                    # the dict is repeatedly written to file at the end of each epoch; only use this in processes
                    # dedicated to exporting, and only with a single worker (num_workers=0), this is not thread-safe
                    data_pairs[adj_pdx] = {'input': patch['input']['idx'], 'target': patch['target']['idx'],
                                           'coverage': {'input': patch['input']['coverage'],
                                                        'output': patch['target']['coverage']},
                                           'paths': {'input': {'S1': pathify(patch['input']['S1 path']),
                                                               'S2': pathify(patch['input']['S2 path'])},
                                                     'output': {'S1': pathify(patch['target']['S1 path']),
                                                                'S2': pathify(patch['target']['S2 path'])}}}
                elif sample_type == 'generic':
                    # see the thread-safety note above: export with a single worker only
                    data_pairs[pdx] = {'coverage': patch['coverage'],
                                       'paths': {'S1': pathify(patch['S1 path']),
                                                 'S2': pathify(patch['S2 path'])}}
        if sample_type == 'generic':
            # export the collected dates, either here after each epoch or once after all epochs
            if export_data_path is not None:
                ds = dataloader.dataset
                if os.path.isdir(export_data_path):
                    export_here = os.path.join(export_data_path, f'{sample_type}_{input_t}_{split}_{region}_{ds.cloud_masks}.npy')
                else:
                    export_here = export_data_path
                np.save(export_here, data_pairs)
                print(f'\nEpoch {epoch_count+1}/{n_epochs}: Exported pre-computed dates to {export_here}')
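                # note: np.save pickles the dict into the .npy container, so restoring
                # it later requires np.load(export_here, allow_pickle=True).item()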
        # bookkeeping at the end of the epoch
        epoch_count += 1

    if sample_type == 'generic':
        print(f'The grand average variance of S2 samples in the {split} split is: {torch.stack(collect_var).mean()}')
    if export_data_path is not None: print('Completed exporting data.')

    # benchmark the speed of the dataloader when (not) using the 'import_data_path' flag
    elapsed = time.time() - start_timer
    print(f'Elapsed time is {elapsed:.2f} s')
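    # a follow-up run may then skip the online cloud-mask computation by feeding the
    # export back in, e.g. (sketch, reusing the constructor call from above):
    #   SEN12MSCRTS(root, split=split, sample_type=sample_type, n_input_samples=input_t,
    #               region=region, sampler=vary, import_data_path=export_here)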