"""Dataloader factory: builds train/validation DataLoaders from either
synthetic dead-leaves generators or restoration datasets stored on disk."""
from torch.utils.data import DataLoader
from rstor.data.synthetic_dataloader import DeadLeavesDataset, DeadLeavesDatasetGPU
from rstor.data.stored_images_dataloader import RestorationDataset
from rstor.properties import (
    DATALOADER, BATCH_SIZE, TRAIN, VALIDATION, LENGTH, CONFIG_DEAD_LEAVES,
    SIZE, NAME, CONFIG_DEGRADATION, DATASET_SYNTH_LIST, DATASET_DIV2K, DATASET_PATH,
    DATASET_DL_DIV2K_512, DATASET_DL_DIV2K_1024
)
from typing import Optional
from random import seed, shuffle


def get_data_loader_synthetic(config, frozen_seed=42):
    """Build train/validation DataLoaders backed by a synthetic dead-leaves generator."""
    # print(config[DATALOADER].get(CONFIG_DEAD_LEAVES, {}))
    if config[DATALOADER].get("gpu_gen", False):
        print("Using GPU dead leaves generator")
        ds = DeadLeavesDatasetGPU
    else:
        ds = DeadLeavesDataset
    dl_train = ds(config[DATALOADER][SIZE], config[DATALOADER][LENGTH][TRAIN],
                  frozen_seed=None,
                  **config[DATALOADER].get(CONFIG_DEAD_LEAVES, {}))
    dl_valid = ds(config[DATALOADER][SIZE], config[DATALOADER][LENGTH][VALIDATION],
                  frozen_seed=frozen_seed,
                  **config[DATALOADER].get(CONFIG_DEAD_LEAVES, {}))
    dl_dict = create_dataloaders(config, dl_train, dl_valid)
    return dl_dict


def create_dataloaders(config, dl_train, dl_valid) -> dict:
    """Wrap the train/validation datasets into torch DataLoaders."""
    dl_dict = {
        TRAIN: DataLoader(
            dl_train,
            shuffle=True,
            batch_size=config[DATALOADER][BATCH_SIZE][TRAIN],
        ),
        VALIDATION: DataLoader(
            dl_valid,
            shuffle=False,
            batch_size=config[DATALOADER][BATCH_SIZE][VALIDATION]
        ),
        # TEST: DataLoader(dl_test, shuffle=False, batch_size=config[DATALOADER][BATCH_SIZE][TEST])
    }
    return dl_dict


def get_data_loader_from_disk(config, frozen_seed: Optional[int] = 42) -> dict:
    """Build train/validation DataLoaders from restoration images stored on disk."""
    ds = RestorationDataset
    dataset_name = config[DATALOADER][NAME]  # NAME must be present here!
    if dataset_name == DATASET_DIV2K:
        dataset_root = DATASET_PATH/DATASET_DIV2K
        train_root = dataset_root/"DIV2K_train_HR"/"DIV2K_train_HR"
        valid_root = dataset_root/"DIV2K_valid_HR"/"DIV2K_valid_HR"
        train_files = sorted(list(train_root.glob("*.png")))
        train_files = 5*train_files  # Repeat the file list just to get roughly 4000 elements per epoch
        valid_files = sorted(list(valid_root.glob("*.png")))
    elif dataset_name in DATASET_SYNTH_LIST:
        dataset_root = DATASET_PATH/dataset_name
        all_files = sorted(list(dataset_root.glob("*.png")))
        seed(frozen_seed)
        shuffle(all_files)  # Easy way to perform cross validation if needed
        cut_index = int(0.9*len(all_files))  # 90% train / 10% validation split
        train_files = all_files[:cut_index]
        valid_files = all_files[cut_index:]
    else:
        raise ValueError(f"Unknown dataset name {dataset_name}")
    dl_train = ds(
        train_files,
        size=config[DATALOADER][SIZE],
        frozen_seed=None,
        **config[DATALOADER].get(CONFIG_DEGRADATION, {})
    )
    dl_valid = ds(
        valid_files,
        size=config[DATALOADER][SIZE],
        frozen_seed=frozen_seed,
        **config[DATALOADER].get(CONFIG_DEGRADATION, {})
    )
    dl_dict = create_dataloaders(config, dl_train, dl_valid)
    return dl_dict


def get_data_loader(config, frozen_seed=42):
    """Dispatch to the disk-backed loader when a dataset NAME is given, otherwise to the synthetic one."""
    dataset_name = config[DATALOADER].get(NAME, False)
    if dataset_name:
        return get_data_loader_from_disk(config, frozen_seed)
    else:
        return get_data_loader_synthetic(config, frozen_seed)


if __name__ == "__main__":
    # Example usage: synthetic and stored-image datasets
    for dataset_name in [DATASET_DIV2K, None, DATASET_DL_DIV2K_512, DATASET_DL_DIV2K_1024]:
        if dataset_name is None:
            # Synthetic dead leaves dataset (GPU generator)
            dead_leaves_dataset = DeadLeavesDatasetGPU(colored=True)
            dl = DataLoader(dead_leaves_dataset, batch_size=4, shuffle=True)
        else:
            # Stored images dataset
            config = {
                DATALOADER: {
                    NAME: dataset_name,
                    SIZE: (128, 128),
                    BATCH_SIZE: {
                        TRAIN: 4,
                        VALIDATION: 4
                    },
                }
            }
            dl_dict = get_data_loader(config)
            dl = dl_dict[TRAIN]
            # dl = dl_dict[VALIDATION]
        for i, (batch_inp, batch_target) in enumerate(dl):
            print(batch_inp.shape, batch_target.shape)
            # Should print channel-first tensors [batch_size, 3, size[0], size[1]] for each batch
            if i == 1:  # Just to break the loop after two batches for demonstration
                import matplotlib.pyplot as plt
                # Stack the whole batch vertically for display
                plt.subplot(1, 2, 1)
                plt.imshow(batch_inp.permute(0, 2, 3, 1).reshape(-1, batch_inp.shape[-1], 3).cpu().numpy())
                plt.title("Degraded")
                plt.subplot(1, 2, 2)
                plt.imshow(batch_target.permute(0, 2, 3, 1).reshape(-1, batch_inp.shape[-1], 3).cpu().numpy())
                plt.title("Target")
                plt.show()
                # print(batch_target)
                break
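
    # A minimal sketch of the purely synthetic path (no NAME key, so get_data_loader
    # falls back to get_data_loader_synthetic). The LENGTH and CONFIG_DEAD_LEAVES values
    # below are illustrative assumptions, not values taken from the repository:
    #
    # synthetic_config = {
    #     DATALOADER: {
    #         SIZE: (128, 128),
    #         LENGTH: {TRAIN: 4000, VALIDATION: 400},
    #         BATCH_SIZE: {TRAIN: 4, VALIDATION: 4},
    #         CONFIG_DEAD_LEAVES: {},   # extra kwargs forwarded to DeadLeavesDataset
    #         # "gpu_gen": True,        # uncomment to use DeadLeavesDatasetGPU instead
    #     }
    # }
    # dl_dict = get_data_loader(synthetic_config)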