|
""" |
|
Python script to pre-compute cloud coverage statistics on the data of SEN12MS-CR-TS. |
|
The data loader performs online sampling of input and target patches depending on its flags |
|
(e.g.: split, region, n_input_samples, min_cov, max_cov, ) and the patches' calculated cloud coverage. |
|
If using sampler='random', patches can also vary across epochs to act as a data augmentation mechanism.
|
|
|
However, online computing of cloud masks can slow down data loading. A solution is to pre-compute |
|
cloud coverage and relieve the dataloader from re-computing each sample, which is what this script offers.
|
Currently, pre-calculated statistics are exported in an *.npy file, a collection of which is readily |
|
available for download via https://syncandshare.lrz.de/getlink/fiHhwCqr7ch3X39XoGYaUGM8/splits |
|
|
|
Pre-computed statistics can be imported via the dataloader's "import_data_path" argument. |
|
""" |
|
|
|
import os |
|
import sys |
|
import time |
|
import random |
|
import numpy as np |
|
from tqdm import tqdm |
|
|
|
import resource |
|
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) |
|
|
|
resource.setrlimit(resource.RLIMIT_NOFILE, (int(1024*1e3), rlimit[1])) |
|
|
|
import torch |
|
dirname = os.path.dirname(os.path.abspath(__file__)) |
|
sys.path.append(os.path.dirname(dirname)) |
|
from data.dataLoader import SEN12MSCRTS |
|
|
|
|
|
seed = 1 |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
|
|
def seed_worker(worker_id): |
|
worker_seed = torch.initial_seed() % 2**32 |
|
np.random.seed(worker_seed) |
|
random.seed(worker_seed) |
|
|
|
g = torch.Generator() |
|
g.manual_seed(seed) |
|
|
|
|
|
pathify = lambda path_list: [os.path.join(*path[0].split('/')[-6:]) for path in path_list] |
|
|
|
if __name__ == '__main__': |
|
|
|
root = '/home/data/SEN12MSCRTS' |
|
split = 'test' |
|
input_t = 3 |
|
region = 'all' |
|
sample_type = 'generic' |
|
import_data_path = None |
|
export_data_path = os.path.join(dirname, 'precomputed') |
|
vary = 'random' if split!='test' else 'fixed' |
|
n_epochs = 1 if vary=='fixed' or sample_type=='generic' else 30 |
|
max_samples = int(1e9) |
|
|
|
shuffle = False |
|
if export_data_path is not None: |
|
shuffle = False |
|
|
|
sen12mscrts = SEN12MSCRTS(root, split=split, sample_type=sample_type, n_input_samples=input_t, region=region, sampler=vary, import_data_path=import_data_path) |
|
|
|
|
|
|
|
dataloader = torch.utils.data.DataLoader(sen12mscrts, batch_size=1, shuffle=shuffle, worker_init_fn=seed_worker, generator=g, num_workers=0) |
|
|
|
if export_data_path is not None: |
|
data_pairs = {} |
|
epoch_count = 0 |
|
collect_var = [] |
|
|
|
|
|
start_timer = time.time() |
|
for epoch in range(1, n_epochs + 1): |
|
print(f'\nCurating indices for {epoch}. epoch.') |
|
for pdx, patch in enumerate(tqdm(dataloader)): |
|
|
|
if pdx>=max_samples: break |
|
|
|
if sample_type == 'generic': |
|
|
|
collect_var.append(torch.stack(patch['S2']).var()) |
|
|
|
if export_data_path is not None: |
|
if sample_type == 'cloudy_cloudfree': |
|
|
|
adj_pdx = epoch_count*dataloader.dataset.__len__() + pdx |
|
|
|
|
|
data_pairs[adj_pdx] = {'input': patch['input']['idx'], 'target': patch['target']['idx'], |
|
'coverage': {'input': patch['input']['coverage'], |
|
'output': patch['output']['coverage']}, |
|
'paths': {'input': {'S1': pathify(patch['input']['S1 path']), |
|
'S2': pathify(patch['input']['S2 path'])}, |
|
'output': {'S1': pathify(patch['target']['S1 path']), |
|
'S2': pathify(patch['target']['S2 path'])}}} |
|
elif sample_type == 'generic': |
|
|
|
|
|
data_pairs[pdx] = {'coverage': patch['coverage'], |
|
'paths': {'S1': pathify(patch['S1 path']), |
|
'S2': pathify(patch['S2 path'])}} |
|
if sample_type == 'generic': |
|
|
|
|
|
if export_data_path is not None: |
|
ds = dataloader.dataset |
|
if os.path.isdir(export_data_path): |
|
export_here = os.path.join(export_data_path, f'{sample_type}_{input_t}_{split}_{region}_{ds.cloud_masks}.npy') |
|
else: |
|
export_here = export_data_path |
|
np.save(export_here, data_pairs) |
|
print(f'\nEpoch {epoch_count+1}/{n_epochs}: Exported pre-computed dates to {export_here}') |
|
|
|
|
|
epoch_count += 1 |
|
|
|
print(f'The grand average variance of S2 samples in the {split} split is: {torch.mean(torch.tensor(collect_var))}') |
|
|
|
if export_data_path is not None: print('Completed exporting data.') |
|
|
|
|
|
elapsed = time.time() - start_timer |
|
print(f'Elapsed time is {elapsed}') |