Spaces:
Running
Running
File size: 2,619 Bytes
cec5823 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import torch
from rstor.data.stored_images_dataloader import RestorationDataset
from numba import cuda
from rstor.properties import DATASET_PATH, AUGMENTATION_FLIP, AUGMENTATION_ROTATE
def test_dataloader_stored():
    """Exercise ``RestorationDataset`` on the images under ``DATASET_PATH/"sample"``.

    The test checks, in order:

    * the dataset length and the ``frozen_seed`` attribute with default and
      custom constructor arguments;
    * the type and shape of the pair returned by indexing the dataset;
    * that two datasets built with the same ``frozen_seed`` return an
      identical first input, both without and with an augmentation list.

    When numba reports that CUDA is unavailable, the function prints a
    message and returns early without checking anything.
    """
    if not cuda.is_available():
        print("cuda unavailable, exiting")
        return

    sample_dir = DATASET_PATH/"sample"
    # Every dataset below uses noise_stddev=(0, 0) -- presumably this disables
    # the added noise; confirm against RestorationDataset.
    zero_stddev = (0, 0)

    # Case 1: only the path and the noise range are provided.
    default_ds = RestorationDataset(images_path=sample_dir, noise_stddev=zero_stddev)
    assert len(default_ds) == 2
    assert default_ds.frozen_seed is None

    # Case 2: explicit size and frozen seed.
    custom_ds = RestorationDataset(
        images_path=sample_dir,
        noise_stddev=zero_stddev,
        size=(64, 64),
        frozen_seed=42,
    )
    assert len(custom_ds) == 2
    assert custom_ds.frozen_seed == 42

    # Case 3: item retrieval yields an (input, target) pair of equal shape.
    degraded, target = custom_ds[0]
    assert isinstance(degraded, torch.Tensor)
    assert degraded.shape == target.shape == (3, 64, 64)

    # Case 4: a frozen seed gives repeatable inputs, first without extra
    # arguments and then with flip + rotate augmentation enabled.
    augmentation_list = [AUGMENTATION_FLIP, AUGMENTATION_ROTATE]
    for extra_kwargs in ({}, {"augmentation_list": augmentation_list}):
        first_ds, second_ds = (
            RestorationDataset(
                images_path=sample_dir,
                frozen_seed=42,
                noise_stddev=zero_stddev,
                **extra_kwargs,
            )
            for _ in range(2)
        )
        first_input, _ = first_ds[0]
        second_input, _ = second_ds[0]
        assert torch.eq(first_input, second_input).all()

    # Case 5 (manual, not automated): to inspect a sample visually, build a
    # dataset with `augmentation_list`, take `item, _ = dataset[0]` and show
    # `item.permute(1, 2, 0).detach().cpu()` with matplotlib's `imshow`.
    print("done")
|