import torch
from rstor.data.synthetic_dataloader import DeadLeavesDataset


def test_dead_leaves_dataset():
    """Exercise DeadLeavesDataset construction, item retrieval and seeding.

    Covers default and custom constructor arguments, the shape and type of
    retrieved items, and reproducibility of samples when ``frozen_seed`` is set.
    """
    # Test case 1: Default parameters
    dataset = DeadLeavesDataset(noise_stddev=(0, 0), ds_factor=1)
    assert len(dataset) == 1000
    assert dataset.size == (128, 128)
    assert dataset.frozen_seed is None
    assert dataset.config_dead_leaves == {}

    # Test case 2: Custom parameters; the dead-leaves specific kwargs are
    # expected to be collected into `config_dead_leaves`.
    dataset = DeadLeavesDataset(
        size=(256, 256),
        length=500,
        frozen_seed=42,
        number_of_circles=5,
        background_color=(0.2, 0.4, 0.6),
        colored=True,
        radius_min=1,
        radius_alpha=3,
        noise_stddev=(0, 0),
        ds_factor=1,
    )
    assert len(dataset) == 500
    assert dataset.size == (256, 256)
    assert dataset.frozen_seed == 42
    assert dataset.config_dead_leaves == {
        'number_of_circles': 5,
        'background_color': (0.2, 0.4, 0.6),
        'colored': True,
        'radius_min': 1,
        'radius_alpha': 3,
    }

    # Test case 3: Check item retrieval (input and target are both returned)
    item, item_tgt = dataset[0]
    assert isinstance(item, torch.Tensor)
    assert isinstance(item_tgt, torch.Tensor)
    assert item.shape == (3, 256, 256)

    # Test case 4: Repeatable results with frozen seed.
    # torch.equal also requires identical shapes; torch.eq would broadcast.
    dataset1 = DeadLeavesDataset(frozen_seed=42, noise_stddev=(0, 0), number_of_circles=256)
    dataset2 = DeadLeavesDataset(frozen_seed=42, noise_stddev=(0, 0), number_of_circles=256)
    item1, item_tgt1 = dataset1[0]
    item2, item_tgt2 = dataset2[0]
    assert torch.equal(item1, item2)
    assert torch.equal(item_tgt1, item_tgt2)