File size: 1,865 Bytes
3c8ff2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Minimal Python script demonstrating how to use the PyTorch data loaders for SEN12MS-CR and SEN12MS-CR-TS

import os
import torch
from data.dataLoader import SEN12MSCR, SEN12MSCRTS

if __name__ == '__main__':
    # Demo entry point: pick a dataset, wrap it in a PyTorch DataLoader and
    # iterate over it once, pulling the individual fields out of every batch.

    # main parameters for instantiating SEN12MS-CR-TS
    dataset     = 'SEN12MS-CR-TS'               # choose either 'SEN12MS-CR' or 'SEN12MS-CR-TS'
    root        = '/home/data/'                 # path to your copy of SEN12MS-CR or SEN12MS-CR-TS
    split       = 'all'                         # ROI to sample from, belonging to splits [all | train | val | test]
    input_t     = 3                             # number of input time points to sample, only relevant for SEN12MS-CR-TS
    import_path = None                          # path to the suppl. file specifying what time points to load for input and output
    sample_type = 'cloudy_cloudfree'            # type of samples returned [cloudy_cloudfree | generic]

    # Validate explicitly rather than with `assert`: assertions are removed when
    # Python runs with -O, in which case a misspelled dataset name would silently
    # fall through to the SEN12MS-CR-TS branch below.
    supported = ('SEN12MS-CR', 'SEN12MS-CR-TS')
    if dataset not in supported:
        raise ValueError(f'Unknown dataset {dataset!r}, expected one of {supported}.')

    if dataset == 'SEN12MS-CR':
        # this loader is only given its root directory and the split
        loader = SEN12MSCR(os.path.join(root, 'SEN12MSCR'), split=split)
    else:
        # this loader additionally receives the sampling configuration
        loader = SEN12MSCRTS(
            os.path.join(root, 'SEN12MSCRTS'),
            split=split,
            sample_type=sample_type,
            n_input_samples=input_t,
            import_data_path=import_path,
        )
    dataloader = torch.utils.data.DataLoader(loader, batch_size=1, shuffle=False, num_workers=10)

    # iterate over split and do some data accessing for demonstration
    for pdx, patch in enumerate(dataloader):
        print(f'Fetching {pdx}. batch of data.')

        # NOTE(review): the layout of `patch` is defined by the dataset classes in
        # data.dataLoader, which are not visible here; only the keys read below
        # are known from this script.
        input_s1  = patch['input']['S1']
        input_s2  = patch['input']['S2']
        coverage  = patch['input']['coverage']
        input_c   = sum(coverage) / len(coverage)   # arithmetic mean over the coverage entries
        output_s2 = patch['target']['S2']
        if dataset == 'SEN12MS-CR-TS':
            # these two entries are only read for the time-series dataset;
            # presumably per-sample date information -- TODO confirm against the loader
            dates_s1 = patch['input']['S1 TD']
            dates_s2 = patch['input']['S2 TD']