File size: 2,619 Bytes
cec5823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import unittest

import torch
from numba import cuda

from rstor.data.stored_images_dataloader import RestorationDataset
from rstor.properties import DATASET_PATH, AUGMENTATION_FLIP, AUGMENTATION_ROTATE


def test_dataloader_stored():
    """Check RestorationDataset construction, item retrieval and seeding.

    Uses the images stored under ``DATASET_PATH / "sample"``; the assertions
    expect exactly two of them. Every dataset is built with
    ``noise_stddev=(0, 0)``.

    The test is skipped (not silently passed) when numba reports that CUDA
    is unavailable.
    """
    if not cuda.is_available():
        # An early ``return`` would be counted as a PASS by pytest; raising
        # SkipTest makes the missing-CUDA situation visible in the report.
        raise unittest.SkipTest("cuda unavailable, skipping")

    sample_path = DATASET_PATH / "sample"

    # Test case 1: Default parameters
    dataset = RestorationDataset(noise_stddev=(0, 0),
                                 images_path=sample_path)
    assert len(dataset) == 2
    assert dataset.frozen_seed is None

    # Test case 2: Custom parameters
    dataset = RestorationDataset(images_path=sample_path,
                                 size=(64, 64),
                                 frozen_seed=42,
                                 noise_stddev=(0, 0))
    assert len(dataset) == 2
    assert dataset.frozen_seed == 42

    # Test case 3: Check item retrieval (uses the 64x64 dataset from case 2)
    item, item_tgt = dataset[0]
    assert isinstance(item, torch.Tensor)
    assert item.shape == item_tgt.shape
    assert item.shape == (3, 64, 64)

    # Test case 4: Repeatable results with frozen seed
    dataset1 = RestorationDataset(images_path=sample_path,
                                  frozen_seed=42, noise_stddev=(0, 0))
    dataset2 = RestorationDataset(images_path=sample_path,
                                  frozen_seed=42, noise_stddev=(0, 0))
    item1, item_tgt1 = dataset1[0]
    item2, item_tgt2 = dataset2[0]
    # torch.equal also requires identical shapes; torch.all(torch.eq(...))
    # would broadcast and could pass on mismatched shapes.
    assert torch.equal(item1, item2)
    assert torch.equal(item_tgt1, item_tgt2)

    # Test case 5: Repeatable results with frozen seed and augmentation
    augmentation_list = [AUGMENTATION_FLIP, AUGMENTATION_ROTATE]
    dataset1 = RestorationDataset(images_path=sample_path,
                                  frozen_seed=42, noise_stddev=(0, 0),
                                  augmentation_list=augmentation_list)
    dataset2 = RestorationDataset(images_path=sample_path,
                                  frozen_seed=42, noise_stddev=(0, 0),
                                  augmentation_list=augmentation_list)
    item1, item_tgt1 = dataset1[0]
    item2, item_tgt2 = dataset2[0]
    assert torch.equal(item1, item2)
    assert torch.equal(item_tgt1, item_tgt2)

    # Test case 6: Visualize (manual debugging aid, disabled by default)
    # dataset = RestorationDataset(images_path=sample_path,
    #                              noise_stddev=(0, 0),
    #                              augmentation_list=augmentation_list)
    # item, item_tgt = dataset[0]
    # import matplotlib.pyplot as plt
    # plt.figure()
    # plt.imshow(item.permute(1, 2, 0).detach().cpu())
    # plt.show()
    # breakpoint()
    print("done")