Spaces:
Running
Running
import torch | |
from rstor.data.stored_images_dataloader import RestorationDataset | |
from numba import cuda | |
from rstor.properties import DATASET_PATH, AUGMENTATION_FLIP, AUGMENTATION_ROTATE | |
def _assert_frozen_seed_repeatable(images_path, **dataset_kwargs):
    """Build two identically-configured datasets and check their first items match.

    Both datasets use ``frozen_seed=42`` and ``noise_stddev=(0, 0)``; any extra
    keyword arguments (e.g. ``augmentation_list``) are forwarded to both.
    The degraded item AND its target must be element-wise identical.
    """
    datasets = [
        RestorationDataset(images_path=images_path, frozen_seed=42,
                           noise_stddev=(0, 0), **dataset_kwargs)
        for _ in range(2)
    ]
    (item1, item_tgt1), (item2, item_tgt2) = (ds[0] for ds in datasets)
    # torch.equal also checks shapes; torch.all(torch.eq(...)) would broadcast.
    assert torch.equal(item1, item2)
    assert torch.equal(item_tgt1, item_tgt2)


def test_dataloader_stored():
    """Check RestorationDataset on the images stored under DATASET_PATH/"sample".

    Covers construction with default and custom parameters, item retrieval
    (type and shape of the degraded image and its target), and repeatability
    of items when a frozen seed is used, with and without flip/rotate
    augmentations. Noise is disabled (``noise_stddev=(0, 0)``) throughout.
    """
    # NOTE(review): returning early makes pytest report this test as passed
    # when CUDA is missing; pytest.skip(...) would report it as skipped.
    if not cuda.is_available():
        print("cuda unavailable, exiting")
        return

    sample_path = DATASET_PATH / "sample"

    # Test case 1: default parameters (noise disabled).
    dataset = RestorationDataset(noise_stddev=(0, 0), images_path=sample_path)
    assert len(dataset) == 2
    assert dataset.frozen_seed is None

    # Test case 2: custom parameters.
    dataset = RestorationDataset(images_path=sample_path,
                                 size=(64, 64),
                                 frozen_seed=42,
                                 noise_stddev=(0, 0))
    assert len(dataset) == 2
    assert dataset.frozen_seed == 42

    # Test case 3: item retrieval returns a (degraded, target) tensor pair
    # cropped to the requested size.
    item, item_tgt = dataset[0]
    assert isinstance(item, torch.Tensor)
    assert isinstance(item_tgt, torch.Tensor)
    assert item.shape == item_tgt.shape
    assert item.shape == (3, 64, 64)

    # Test case 4: repeatable results with a frozen seed.
    _assert_frozen_seed_repeatable(sample_path)

    # Test case 5: repeatable results with a frozen seed and augmentation.
    _assert_frozen_seed_repeatable(
        sample_path,
        augmentation_list=[AUGMENTATION_FLIP, AUGMENTATION_ROTATE],
    )

    print("done")