|
|
|
|
|
import os |
|
import torch |
|
from data.dataLoader import SEN12MSCR, SEN12MSCRTS |
|
|
|
def build_dataset(dataset, root, split='all', sample_type='cloudy_cloudfree',
                  input_t=3, import_path=None):
    """Instantiate the requested dataset rooted under ``root``.

    Args:
        dataset: Either ``'SEN12MS-CR'`` or ``'SEN12MS-CR-TS'``.
        root: Parent directory; the dataset is read from ``root/SEN12MSCR``
            or ``root/SEN12MSCRTS`` respectively.
        split: Split name forwarded to the dataset constructor.
        sample_type: Forwarded to ``SEN12MSCRTS`` only.
        input_t: Number of input time points (``n_input_samples``),
            forwarded to ``SEN12MSCRTS`` only.
        import_path: Forwarded as ``import_data_path`` to ``SEN12MSCRTS`` only.

    Returns:
        A ``SEN12MSCR`` or ``SEN12MSCRTS`` instance.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names. An
            explicit raise is used instead of ``assert`` so the check also
            holds under ``python -O``.
    """
    if dataset == 'SEN12MS-CR':
        return SEN12MSCR(os.path.join(root, 'SEN12MSCR'), split=split)
    if dataset == 'SEN12MS-CR-TS':
        return SEN12MSCRTS(
            os.path.join(root, 'SEN12MSCRTS'),
            split=split,
            sample_type=sample_type,
            n_input_samples=input_t,
            import_data_path=import_path,
        )
    raise ValueError(
        f"Unsupported dataset {dataset!r}; "
        "expected 'SEN12MS-CR' or 'SEN12MS-CR-TS'."
    )


def mean_coverage(coverage):
    """Return the arithmetic mean of the per-input coverage values.

    ``coverage`` is whatever the dataloader yields under
    ``patch['input']['coverage']``. It may be a list of per-time-point values
    or a 1-D tensor; this is not verified here, so confirm it against the
    dataset classes. Anything supporting ``sum`` and ``len`` works.

    Raises:
        ValueError: If ``coverage`` is empty. The mean is undefined, and a
            bare ``ZeroDivisionError`` would hide the cause.
    """
    count = len(coverage)
    if count == 0:
        raise ValueError("coverage is empty; cannot compute its mean")
    return sum(coverage) / count


def main(dataset='SEN12MS-CR-TS', root='/home/data/', split='all', input_t=3,
         import_path=None, sample_type='cloudy_cloudfree', num_workers=10):
    """Iterate once over the chosen dataset and unpack each batch's fields.

    This is a usage demo: the unpacked values are bound to local names to
    show how a batch is laid out, but they are not otherwise consumed.
    """
    loader = build_dataset(dataset, root, split=split, sample_type=sample_type,
                           input_t=input_t, import_path=import_path)
    dataloader = torch.utils.data.DataLoader(
        loader, batch_size=1, shuffle=False, num_workers=num_workers
    )

    for pdx, patch in enumerate(dataloader):
        print(f'Fetching {pdx}. batch of data.')

        input_s1 = patch['input']['S1']
        input_s2 = patch['input']['S2']
        input_c = mean_coverage(patch['input']['coverage'])
        output_s2 = patch['target']['S2']

        # Acquisition-date fields are only read for the time-series variant.
        if dataset == 'SEN12MS-CR-TS':
            dates_s1 = patch['input']['S1 TD']
            dates_s2 = patch['input']['S2 TD']


if __name__ == '__main__':
    main()
|
|