import torch
from torch.utils.data import DataLoader, Dataset
from rstor.data.augmentation import augment_flip
from rstor.data.degradation import DegradationBlurMat, DegradationBlurGauss, DegradationNoise
from rstor.properties import DEVICE, AUGMENTATION_FLIP, AUGMENTATION_ROTATE, DEGRADATION_BLUR_NONE, DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_GAUSS
from rstor.properties import DATALOADER, BATCH_SIZE, TRAIN, VALIDATION, LENGTH, CONFIG_DEAD_LEAVES, SIZE
from typing import Tuple, Optional
from torchvision import transforms
from pathlib import Path
from tqdm import tqdm
from time import time
from torchvision.io import read_image
IMAGES_FOLDER = "images"


def load_image(path):
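    """Read an image file as a uint8 tensor of shape [C, H, W]."""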
    return read_image(str(path))


class RestorationDataset(Dataset):
    def __init__(
        self,
        images_path: Path,
        size: Tuple[int, int] = (128, 128),
        device: str = DEVICE,
        preloaded: bool = False,
        augmentation_list: Optional[list] = None,
        frozen_seed: Optional[int] = None,  # useful for validation set!
        blur_kernel_half_size: Tuple[int, int] = (0, 2),
        noise_stddev: Tuple[float, float] = (0., 50.),
        degradation_blur=DEGRADATION_BLUR_NONE,
        blur_index=None,
        **_extra_kwargs
    ):
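        """Restoration dataset yielding (degraded, target) image crop pairs.

        Args:
            images_path: folder containing .png images, or an explicit list of image paths.
            size: output crop size (height, width).
            device: device the images are moved to before augmentation and degradation.
            preloaded: if True, decode all images once at construction time and keep them in memory.
            augmentation_list: optional list of augmentation flags
                (AUGMENTATION_FLIP, AUGMENTATION_ROTATE).
            frozen_seed: if set, random augmentations are disabled, crops become deterministic
                center crops and the seed is forwarded to the degradations (useful for validation).
            blur_kernel_half_size: blur kernel half size bounds, forwarded to DegradationBlurGauss.
            noise_stddev: noise standard deviation bounds, forwarded to DegradationNoise.
            degradation_blur: DEGRADATION_BLUR_NONE, DEGRADATION_BLUR_GAUSS or DEGRADATION_BLUR_MAT.
            blur_index: forwarded to DegradationBlurMat (MAT mode only).
        """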
        self.preloaded = preloaded
        self.augmentation_list = augmentation_list if augmentation_list is not None else []
        self.device = device
        self.frozen_seed = frozen_seed
        if not isinstance(images_path, list):
            self.path_list = sorted(list(images_path.glob("*.png")))
        else:
            self.path_list = images_path

        self.length = len(self.path_list)
        self.n_samples = self.length
        # Optionally decode and keep all images in memory to avoid disk reads later on.
        if preloaded:
            self.data_list = [load_image(pth) for pth in tqdm(self.path_list)]
        else:
            self.data_list = self.path_list

        self.transforms = []

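        # Random flips/rotations and random crops are only used when no frozen seed is given;
        # with a frozen seed (e.g. a validation set) the crop is a deterministic center crop.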
        if self.frozen_seed is None:
            if AUGMENTATION_FLIP in self.augmentation_list:
                self.transforms.append(transforms.RandomHorizontalFlip(p=0.5))
                self.transforms.append(transforms.RandomVerticalFlip(p=0.5))
            if AUGMENTATION_ROTATE in self.augmentation_list:
                self.transforms.append(transforms.RandomRotation(degrees=180))

        crop = transforms.RandomCrop(size) if frozen_seed is None else transforms.CenterCrop(size)
        self.transforms.append(crop)
        self.transforms = transforms.Compose(self.transforms)

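        # Degradation operators: an optional blur (Gaussian kernel or pre-computed kernel
        # matrix) followed by additive noise, applied per item in __getitem__.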
        self.degradation_blur_type = degradation_blur
        if degradation_blur == DEGRADATION_BLUR_GAUSS:
            self.degradation_blur = DegradationBlurGauss(self.length,
                                                         blur_kernel_half_size,
                                                         frozen_seed)
            self.blur_deg_str = "blur_kernel_half_size"
        elif degradation_blur == DEGRADATION_BLUR_MAT:
            self.degradation_blur = DegradationBlurMat(self.length,
                                                       frozen_seed,
                                                       blur_index)
            self.blur_deg_str = "blur_kernel_id"
        elif degradation_blur == DEGRADATION_BLUR_NONE:
            pass
        else:
            raise ValueError(f"Unknown degradation blur {degradation_blur}")

        self.degradation_noise = DegradationNoise(self.length,
                                                  noise_stddev,
                                                  frozen_seed)
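        # Per-index record of the degradation parameters actually applied
        # (blur kernel, noise standard deviation), e.g. for logging or evaluation.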
        self.current_degradation = {}

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Access a specific image from the dataset, crop/augment it and degrade it.

        Args:
            index (int): access index

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (degraded, target) image tensors [C, H, W]
        """
        if self.preloaded:
            img_data = self.data_list[index]
        else:
            img_data = load_image(self.data_list[index])
        img_data = img_data.to(self.device)

        img_data = self.transforms(img_data)
        img_data = img_data.float()/255.
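        # Add a batch dimension, assuming the degradation operators expect [N, C, H, W];
        # it is removed again before returning.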
        degraded_img = img_data.clone().unsqueeze(0)

        self.current_degradation[index] = {}
        if self.degradation_blur_type != DEGRADATION_BLUR_NONE:
            degraded_img = self.degradation_blur(degraded_img, index)
            self.current_degradation[index][self.blur_deg_str] = self.degradation_blur.current_degradation[index][self.blur_deg_str]

        degraded_img = self.degradation_noise(degraded_img, index)
        self.current_degradation[index]["noise_stddev"] = self.degradation_noise.current_degradation[index]["noise_stddev"]

        degraded_img = degraded_img.squeeze(0)

        return degraded_img, img_data

    def __len__(self):
        return self.n_samples

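# Example (sketch, the path and parameters below are illustrative): building a
# deterministic validation set with Gaussian blur + noise degradations.
#
#   val_dataset = RestorationDataset(
#       Path("__dataset/div2k/DIV2K_valid_HR/DIV2K_valid_HR/"),
#       size=(128, 128),
#       frozen_seed=42,
#       degradation_blur=DEGRADATION_BLUR_GAUSS,
#   )
#   val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)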

if __name__ == "__main__":
    dataset_restoration = RestorationDataset(
        Path("__dataset/div2k/DIV2K_train_HR/DIV2K_train_HR/"),
        preloaded=True,
    )
    dataloader = DataLoader(
        dataset_restoration,
        batch_size=16,
        shuffle=True
    )
    start = time()
    total = 0
    for degraded, target in tqdm(dataloader):
        # Each batch is a (degraded, target) pair of [N, C, H, W] tensors.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        total += degraded.shape[0]
    end = time()
    print(f"Time elapsed: {(end-start)/total*1000.:.2f}ms/image")