diffcr-models / UnCRtainTS /standalone_dataloader.py
XavierJiezou's picture
Upload folder using huggingface_hub
3c8ff2e verified
# minimal Python script to demonstrate utilizing the pyTorch data loader for SEN12MS-CR and SEN12MS-CR-TS
import os
import torch
from data.dataLoader import SEN12MSCR, SEN12MSCRTS
if __name__ == '__main__':
# main parameters for instantiating SEN12MS-CR-TS
dataset = 'SEN12MS-CR-TS' # choose either 'SEN12MS-CR' or 'SEN12MS-CR-TS'
root = '/home/data/' # path to your copy of SEN12MS-CR or SEN12MS-CR-TS
split = 'all' # ROI to sample from, belonging to splits [all | train | val | test]
input_t = 3 # number of input time points to sample, only relevant for SEN12MS-CR-TS
import_path = None # path to importing the suppl. file specifying what time points to load for input and output
sample_type = 'cloudy_cloudfree' # type of samples returned [cloudy_cloudfree | generic]
assert dataset in ['SEN12MS-CR', 'SEN12MS-CR-TS']
if dataset =='SEN12MS-CR': loader = SEN12MSCR(os.path.join(root, 'SEN12MSCR'), split=split)
else: loader = SEN12MSCRTS(os.path.join(root, 'SEN12MSCRTS'), split=split, sample_type=sample_type, n_input_samples=input_t, import_data_path=import_path)
dataloader = torch.utils.data.DataLoader(loader, batch_size=1, shuffle=False, num_workers=10)
# iterate over split and do some data accessing for demonstration
for pdx, patch in enumerate(dataloader):
print(f'Fetching {pdx}. batch of data.')
input_s1 = patch['input']['S1']
input_s2 = patch['input']['S2']
input_c = sum(patch['input']['coverage'])/len(patch['input']['coverage'])
output_s2 = patch['target']['S2']
if dataset=='SEN12MS-CR-TS':
dates_s1 = patch['input']['S1 TD']
dates_s2 = patch['input']['S2 TD']