File size: 8,105 Bytes
3c8ff2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
 Python script to pre-compute cloud coverage statistics on the data of SEN12MS-CR-TS.
    The data loader performs online sampling of input and target patches depending on its flags
    (e.g.: split, region, n_input_samples, min_cov, max_cov, ) and the patches' calculated cloud coverage.
    If using sampler='random', patches can also vary across epochs to act as data augmentation mechanism.

    However, online computing of cloud masks can slow down data loading. A solution is to pre-compute
    cloud coverage and relieve the dataloader from re-computing each sample, which is what this script offers.
    Currently, pre-calculated statistics are exported in an *.npy file, a collection of which is readily
    available for download via https://syncandshare.lrz.de/getlink/fiHhwCqr7ch3X39XoGYaUGM8/splits
    
    Pre-computed statistics can be imported via the dataloader's "import_data_path" argument.
"""

import os
import sys
import time
import random
import numpy as np
from tqdm import tqdm

import resource
# Raise the soft limit on open file descriptors: loading the time-series
# dataset keeps many raster files open at once, which can exceed the default.
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
# see: https://docs.python.org/3/library/resource.html#resource.RLIM_INFINITY
resource.setrlimit(resource.RLIMIT_NOFILE, (int(1024*1e3), rlimit[1]))

import torch
# Make the repository root importable so the project-local data loader resolves.
dirname = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(dirname))
from data.dataLoader import SEN12MSCRTS

# fix all RNG seeds
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)

def seed_worker(worker_id):
    """Re-seed numpy and the stdlib RNG inside each DataLoader worker.

    Derives a per-worker seed from torch's initial seed so that random
    sampling is reproducible across runs and across worker processes.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    np.random.seed(derived_seed)
    random.seed(derived_seed)

# Dedicated RNG for the DataLoader, seeded for reproducible shuffling.
g = torch.Generator()
g.manual_seed(seed)


def pathify(path_list):
    """Trim each entry's path down to its last six components, re-joined with the OS separator."""
    return [os.path.join(*entry[0].split('/')[-6:]) for entry in path_list]

if __name__ == '__main__':
    # main parameters for instantiating SEN12MS-CR-TS
    root                = '/home/data/SEN12MSCRTS'                              # path to your copy of SEN12MS-CR-TS
    split               = 'test'                                                # ROI to sample from, belonging to splits [all | train | val | test]
    input_t             = 3                                                     # number of input time points to sample (irrelevant if choosing sample_type='generic')
    region              = 'all'                                                 # choose the region of data input. [all | africa | america | asiaEast | asiaWest | europa]
    sample_type         = 'generic'                                             # type of samples returned [cloudy_cloudfree | generic]
    import_data_path    = None                                                  # path to importing the suppl. file specifying what time points to load for input and output, e.g. os.path.join(os.getcwd(), 'util', '3_test_s2cloudless_mask.npy')
    export_data_path    = os.path.join(dirname, 'precomputed')                  # e.g. ...'/3_all_train_vary_s2cloudless_mask.npy'
    vary                = 'random' if split!='test' else 'fixed'                # whether to vary samples across epoch or not
    n_epochs            = 1 if vary=='fixed' or sample_type=='generic' else 30  # if not varying dates across epochs, then a single epoch is sufficient
    max_samples         = int(1e9)

    # if exporting data indices to file then DataLoader shuffling must stay disabled,
    # else pdx are not sorted (they may still be shuffled when importing later)
    shuffle             = False

    sen12mscrts         = SEN12MSCRTS(root, split=split, sample_type=sample_type, n_input_samples=input_t, region=region, sampler=vary, import_data_path=import_data_path)
    # instantiate dataloader, note: worker_init_fn is needed to get reproducible random samples across runs if vary_samples=True
    # note: if using 'export_data_path' then keep batch_size at 1 (unless moving data writing out of dataloader)
    #                                   and shuffle=False (processes patches in order, but later imports can still shuffle this)
    dataloader          = torch.utils.data.DataLoader(sen12mscrts, batch_size=1, shuffle=shuffle, worker_init_fn=seed_worker, generator=g, num_workers=0)

    if export_data_path is not None:
        data_pairs  = {}  # collect pre-computed dates in a dict to be exported
        epoch_count = 0   # count, for loading time points that vary across epochs
    collect_var = []      # collect variance across S2 intensities (only filled for sample_type='generic')

    # iterate over data to pre-compute indices for e.g. training or testing
    start_timer = time.time()
    for epoch in range(1, n_epochs + 1):
        print(f'\nCurating indices for {epoch}. epoch.')
        for pdx, patch in enumerate(tqdm(dataloader)):
            # stop sampling when sample count is exceeded
            if pdx>=max_samples: break

            if sample_type == 'generic':
                # collect variances in all samples' S2 intensities, finally compute grand average variance
                collect_var.append(torch.stack(patch['S2']).var())

            # bugfix: this export logic was previously nested inside the sample_type=='generic'
            # branch above, which made the 'cloudy_cloudfree' case below unreachable
            if export_data_path is not None:
                if sample_type == 'cloudy_cloudfree':
                    # compute epoch-sensitive index, such that exported dates can differ across epochs
                    adj_pdx = epoch_count*len(dataloader.dataset) + pdx
                    # performs repeated writing to file, only use this for processes dedicated for exporting
                    # and if so, only use a single thread of workers (--num_threads 1), this ain't thread-safe
                    data_pairs[adj_pdx] = {'input':     patch['input']['idx'], 'target': patch['target']['idx'],
                                           'coverage':  {'input': patch['input']['coverage'],
                                                         'output': patch['output']['coverage']},
                                           'paths':     {'input':  {'S1': pathify(patch['input']['S1 path']),
                                                                    'S2': pathify(patch['input']['S2 path'])},
                                                         'output': {'S1': pathify(patch['target']['S1 path']),
                                                                    'S2': pathify(patch['target']['S2 path'])}}}
                elif sample_type == 'generic':
                    # performs repeated writing to file, only use this for processes dedicated for exporting
                    # and if so, only use a single thread of workers (--num_threads 1), this ain't thread-safe
                    data_pairs[pdx] = {'coverage':  patch['coverage'],
                                       'paths':     {'S1': pathify(patch['S1 path']),
                                                     'S2': pathify(patch['S2 path'])}}

        # export collected dates after each epoch (the cumulative dict is re-saved,
        # so an interrupted multi-epoch run still leaves a usable file behind)
        if export_data_path is not None:
            ds = dataloader.dataset
            if os.path.isdir(export_data_path):
                export_here = os.path.join(export_data_path, f'{sample_type}_{input_t}_{split}_{region}_{ds.cloud_masks}.npy')
            else:
                export_here = export_data_path
            np.save(export_here, data_pairs)
            print(f'\nEpoch {epoch_count+1}/{n_epochs}: Exported pre-computed dates to {export_here}')

            # bookkeeping at the end of epoch
            epoch_count += 1

    # only populated for sample_type='generic'; guarding avoids a NaN mean on an empty list
    if collect_var:
        print(f'The grand average variance of S2 samples in the {split} split is: {torch.stack(collect_var).mean()}')

    if export_data_path is not None: print('Completed exporting data.')

    # benchmark speed of dataloader when (not) using 'import_data_path' flag
    elapsed = time.time() - start_timer
    print(f'Elapsed time is {elapsed}')