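"""Dataloader factories for rstor: on-the-fly synthetic dead leaves generation
and restoration datasets stored on disk (DIV2K and pre-generated synthetic images).
"""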
from torch.utils.data import DataLoader
from rstor.data.synthetic_dataloader import DeadLeavesDataset, DeadLeavesDatasetGPU
from rstor.data.stored_images_dataloader import RestorationDataset
from rstor.properties import (
    DATALOADER, BATCH_SIZE, TRAIN, VALIDATION, LENGTH, CONFIG_DEAD_LEAVES, SIZE, NAME, CONFIG_DEGRADATION,
    DATASET_SYNTH_LIST, DATASET_DIV2K, DATASET_DL_DIV2K_512, DATASET_DL_DIV2K_1024,
    DATASET_PATH
)
from typing import Optional
from random import seed, shuffle


def get_data_loader_synthetic(config, frozen_seed=42):
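    """Build the train/validation DataLoaders over synthetic dead leaves images.

    Training images are generated with frozen_seed=None so they vary across epochs,
    while the validation set uses a frozen seed for reproducibility.
    """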
    # print(config[DATALOADER].get(CONFIG_DEAD_LEAVES, {}))
    if config[DATALOADER].get("gpu_gen", False):
        print("Using GPU dead leaves generator")
        ds = DeadLeavesDatasetGPU
    else:
        ds = DeadLeavesDataset
    dl_train = ds(config[DATALOADER][SIZE], config[DATALOADER][LENGTH][TRAIN],
                  frozen_seed=None, **config[DATALOADER].get(CONFIG_DEAD_LEAVES, {}))
    dl_valid = ds(config[DATALOADER][SIZE], config[DATALOADER][LENGTH][VALIDATION],
                  frozen_seed=frozen_seed, **config[DATALOADER].get(CONFIG_DEAD_LEAVES, {}))
    dl_dict = create_dataloaders(config, dl_train, dl_valid)
    return dl_dict


def create_dataloaders(config, dl_train, dl_valid) -> dict:
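    """Wrap the train/validation datasets into torch DataLoaders,
    using the batch sizes declared in config[DATALOADER][BATCH_SIZE].
    """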
    dl_dict = {
        TRAIN: DataLoader(
            dl_train,
            shuffle=True,
            batch_size=config[DATALOADER][BATCH_SIZE][TRAIN],
        ),
        VALIDATION: DataLoader(
            dl_valid,
            shuffle=False,
            batch_size=config[DATALOADER][BATCH_SIZE][VALIDATION]
        ),
        # TEST: DataLoader(dl_test, shuffle=False, batch_size=config[DATALOADER][BATCH_SIZE][TEST])
    }
    return dl_dict


def get_data_loader_from_disk(config, frozen_seed: Optional[int] = 42) -> dict:
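    """Build the train/validation DataLoaders from images stored on disk.

    DIV2K uses its official train/valid folders; pre-stored synthetic datasets
    are split 90%/10% after a shuffle seeded by frozen_seed.
    """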
    ds = RestorationDataset
    dataset_name = config[DATALOADER][NAME]  # NAME must be defined for on-disk datasets
    if dataset_name == DATASET_DIV2K:
        dataset_root = DATASET_PATH/DATASET_DIV2K
        train_root = dataset_root/"DIV2K_train_HR"/"DIV2K_train_HR"
        valid_root = dataset_root/"DIV2K_valid_HR"/"DIV2K_valid_HR"
        train_files = sorted(list(train_root.glob("*.png")))
        train_files = 5*train_files  # Repeat the 800 DIV2K training images to get 4000 elements per epoch
        valid_files = sorted(list(valid_root.glob("*.png")))
    elif dataset_name in DATASET_SYNTH_LIST:
        dataset_root = DATASET_PATH/dataset_name
        all_files = sorted(list(dataset_root.glob("*.png")))
        seed(frozen_seed)
        shuffle(all_files)  # Easy way to perform cross validation if needed
        cut_index = int(0.9*len(all_files))  # 90% train / 10% validation
        train_files = all_files[:cut_index]
        valid_files = all_files[cut_index:]
    else:
        raise ValueError(f"Unknown dataset name {dataset_name}")
    dl_train = ds(
        train_files,
        size=config[DATALOADER][SIZE],
        frozen_seed=None,  # training degradations are not frozen
        **config[DATALOADER].get(CONFIG_DEGRADATION, {})
    )
    dl_valid = ds(
        valid_files,
        size=config[DATALOADER][SIZE],
        frozen_seed=frozen_seed,  # reproducible validation degradations
        **config[DATALOADER].get(CONFIG_DEGRADATION, {})
    )
    dl_dict = create_dataloaders(config, dl_train, dl_valid)
    return dl_dict


def get_data_loader(config, frozen_seed=42):
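    """Entry point: dispatch to the on-disk dataset when config[DATALOADER]
    provides a dataset NAME, otherwise fall back to the synthetic dead leaves generator.
    """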
    dataset_name = config[DATALOADER].get(NAME, False)
    if dataset_name:
        return get_data_loader_from_disk(config, frozen_seed)
    else:
        return get_data_loader_synthetic(config, frozen_seed)


if __name__ == "__main__":
    # Example usage of both the synthetic and the stored-images datasets
    for dataset_name in [DATASET_DIV2K, None, DATASET_DL_DIV2K_512, DATASET_DL_DIV2K_1024]:
        if dataset_name is None:
            # Synthetic dead leaves images, generated on the fly
            dead_leaves_dataset = DeadLeavesDatasetGPU(colored=True)
            dl = DataLoader(dead_leaves_dataset, batch_size=4, shuffle=True)
        else:
            # Stored images dataset, loaded from disk
            config = {
                DATALOADER: {
                    NAME: dataset_name,
                    SIZE: (128, 128),
                    BATCH_SIZE: {
                        TRAIN: 4,
                        VALIDATION: 4
                    },
                }
            }
            dl_dict = get_data_loader(config)
            dl = dl_dict[TRAIN]
            # dl = dl_dict[VALIDATION]
        for i, (batch_inp, batch_target) in enumerate(dl):
            print(batch_inp.shape, batch_target.shape)  # [batch_size, 3, size[0], size[1]] for each batch
            if i == 1:  # Stop after two batches, just for demonstration
                import matplotlib.pyplot as plt
                # Stack the whole batch vertically to compare degraded inputs and targets
                plt.subplot(1, 2, 1)
                plt.imshow(batch_inp.permute(0, 2, 3, 1).reshape(-1, batch_inp.shape[-1], 3).cpu().numpy())
                plt.title("Degraded")
                plt.subplot(1, 2, 2)
                plt.imshow(batch_target.permute(0, 2, 3, 1).reshape(-1, batch_inp.shape[-1], 3).cpu().numpy())
                plt.title("Target")
                plt.show()
                # print(batch_target)
                break