Spaces:
Running
Running
File size: 1,981 Bytes
cec5823 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import torch
from rstor.data.synthetic_dataloader import DeadLeavesDataset
def test_dead_leaves_dataset():
    """Check construction, item retrieval and seeded determinism of DeadLeavesDataset.

    Every dataset here is built with ``noise_stddev=(0, 0)``, so the samples
    compared below are not affected by any noise configuration.
    """
    # Test case 1: default size/length/seed (noise and downsampling explicitly disabled)
    dataset = DeadLeavesDataset(noise_stddev=(0, 0), ds_factor=1)
    assert len(dataset) == 1000
    assert dataset.size == (128, 128)
    assert dataset.frozen_seed is None
    # No dead-leaves specific keyword arguments were passed, so the collected config is empty.
    assert dataset.config_dead_leaves == {}

    # Test case 2: custom parameters; the dead-leaves specific keyword arguments
    # must end up in config_dead_leaves, the generic ones as attributes.
    dataset = DeadLeavesDataset(
        size=(256, 256),
        length=500,
        frozen_seed=42,
        number_of_circles=5,
        background_color=(0.2, 0.4, 0.6),
        colored=True,
        radius_min=1,
        radius_alpha=3,
        noise_stddev=(0, 0),
        ds_factor=1,
    )
    assert len(dataset) == 500
    assert dataset.size == (256, 256)
    assert dataset.frozen_seed == 42
    assert dataset.config_dead_leaves == {
        'number_of_circles': 5,
        'background_color': (0.2, 0.4, 0.6),
        'colored': True,
        'radius_min': 1,
        'radius_alpha': 3,
    }

    # Test case 3: indexing returns an (input, target) pair; the input is a CHW tensor.
    item, item_tgt = dataset[0]
    assert isinstance(item, torch.Tensor)
    assert item.shape == (3, 256, 256)

    # Test case 4: two datasets built with the same frozen seed yield identical samples.
    dataset1 = DeadLeavesDataset(frozen_seed=42, noise_stddev=(0, 0), number_of_circles=256)
    dataset2 = DeadLeavesDataset(frozen_seed=42, noise_stddev=(0, 0), number_of_circles=256)
    item1, item_tgt1 = dataset1[0]
    item2, item_tgt2 = dataset2[0]
    # torch.equal also requires matching shapes; torch.eq would broadcast and
    # could let a shape mismatch pass unnoticed.
    assert torch.equal(item1, item2)
    # The target must be just as reproducible as the input.
    # NOTE(review): assumes the target is a torch.Tensor, like the input.
    assert torch.equal(item_tgt1, item_tgt2)
|