# image-deblurring/src/rstor/data/stored_images_dataloader.py
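"""Dataset for restoration training on stored images.

Loads images from a folder (optionally preloading them in memory), applies optional
flip/rotation augmentations and cropping, then degrades each crop with a blur
(Gaussian or matrix kernel) followed by additive noise. ``__getitem__`` returns the
(degraded, clean) pair of tensors.
"""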
import torch
from torch.utils.data import DataLoader, Dataset
from rstor.data.augmentation import augment_flip
from rstor.data.degradation import DegradationBlurMat, DegradationBlurGauss, DegradationNoise
from rstor.properties import (DEVICE, AUGMENTATION_FLIP, AUGMENTATION_ROTATE,
                              DEGRADATION_BLUR_NONE, DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_GAUSS)
from rstor.properties import DATALOADER, BATCH_SIZE, TRAIN, VALIDATION, LENGTH, CONFIG_DEAD_LEAVES, SIZE
from typing import Tuple, Optional
from torchvision import transforms
# from torchvision.transforms import RandomCrop
from pathlib import Path
from tqdm import tqdm
from time import time
from torchvision.io import read_image

IMAGES_FOLDER = "images"


def load_image(path):
    # torchvision's read_image returns a uint8 tensor of shape [C, H, W]
    return read_image(str(path))


class RestorationDataset(Dataset):
    def __init__(
        self,
        images_path: Path,
        size: Tuple[int, int] = (128, 128),
        device: str = DEVICE,
        preloaded: bool = False,
        augmentation_list: Optional[list] = None,
        frozen_seed: Optional[int] = None,  # useful for validation set!
        blur_kernel_half_size: Tuple[int, int] = (0, 2),
        noise_stddev: Tuple[float, float] = (0., 50.),
        degradation_blur=DEGRADATION_BLUR_NONE,
        blur_index=None,
        **_extra_kwargs
    ):
        self.preloaded = preloaded
        self.augmentation_list = augmentation_list if augmentation_list is not None else []
        self.device = device
        self.frozen_seed = frozen_seed
        if not isinstance(images_path, list):
            self.path_list = sorted(images_path.glob("*.png"))
        else:
            self.path_list = images_path
        self.length = len(self.path_list)
        self.n_samples = len(self.path_list)
        # Optionally preload every image in memory (trades RAM for loading speed).
        if preloaded:
            self.data_list = [load_image(pth) for pth in tqdm(self.path_list)]
        else:
            self.data_list = self.path_list
        # Random flips/rotations only when training (frozen_seed is None); with a frozen
        # seed, augmentations are skipped and a deterministic CenterCrop replaces RandomCrop.
        self.transforms = []
        if self.frozen_seed is None:
            if AUGMENTATION_FLIP in self.augmentation_list:
                self.transforms.append(transforms.RandomHorizontalFlip(p=0.5))
                self.transforms.append(transforms.RandomVerticalFlip(p=0.5))
            if AUGMENTATION_ROTATE in self.augmentation_list:
                self.transforms.append(transforms.RandomRotation(degrees=180))
        crop = transforms.RandomCrop(size) if frozen_seed is None else transforms.CenterCrop(size)
        self.transforms.append(crop)
        self.transforms = transforms.Compose(self.transforms)
        # self.cropper = RandomCrop(size=size)
        self.degradation_blur_type = degradation_blur
        if degradation_blur == DEGRADATION_BLUR_GAUSS:
            self.degradation_blur = DegradationBlurGauss(self.length,
                                                         blur_kernel_half_size,
                                                         frozen_seed)
            self.blur_deg_str = "blur_kernel_half_size"
        elif degradation_blur == DEGRADATION_BLUR_MAT:
            self.degradation_blur = DegradationBlurMat(self.length,
                                                       frozen_seed,
                                                       blur_index)
            self.blur_deg_str = "blur_kernel_id"
        elif degradation_blur == DEGRADATION_BLUR_NONE:
            pass
        else:
            raise ValueError(f"Unknown degradation blur {degradation_blur}")
        self.degradation_noise = DegradationNoise(self.length,
                                                  noise_stddev,
                                                  frozen_seed)
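        # Per-sample bookkeeping: __getitem__ records here, for each index, the degradation
        # parameters actually applied (blur kernel id / half size and noise stddev).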
        self.current_degradation = {}

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Access a specific image from the dataset, augment and degrade it.

        Args:
            index (int): access index

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (degraded, clean) image tensors [C, H, W],
            with values normalized to [0, 1].
        """
        if self.preloaded:
            img_data = self.data_list[index]
        else:
            img_data = load_image(self.data_list[index])
        img_data = img_data.to(self.device)
        # if AUGMENTATION_FLIP in self.augmentation_list:
        #     img_data = augment_flip(img_data)
        # img_data = self.cropper(img_data)
        img_data = self.transforms(img_data)
        img_data = img_data.float() / 255.
        degraded_img = img_data.clone().unsqueeze(0)
        self.current_degradation[index] = {}
        if self.degradation_blur_type != DEGRADATION_BLUR_NONE:
            degraded_img = self.degradation_blur(degraded_img, index)
            self.current_degradation[index][self.blur_deg_str] = \
                self.degradation_blur.current_degradation[index][self.blur_deg_str]
        degraded_img = self.degradation_noise(degraded_img, index)
        self.current_degradation[index]["noise_stddev"] = \
            self.degradation_noise.current_degradation[index]["noise_stddev"]
        degraded_img = degraded_img.squeeze(0)
        return degraded_img, img_data

    def __len__(self):
        return self.n_samples
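

# A minimal sketch (not part of the original module) showing how a deterministic
# validation loader could be built on top of RestorationDataset. The images_path and
# batch_size arguments are placeholders, and frozen_seed=42 is an arbitrary fixed value.
def make_validation_dataloader(images_path: Path, batch_size: int = 8) -> DataLoader:
    # A fixed frozen_seed disables random flips/rotations and swaps RandomCrop for
    # CenterCrop, so every epoch evaluates exactly the same crops and degradations.
    dataset = RestorationDataset(
        images_path,
        size=(128, 128),
        preloaded=False,
        frozen_seed=42,
        degradation_blur=DEGRADATION_BLUR_NONE,
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)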


if __name__ == "__main__":
    dataset_restoration = RestorationDataset(
        Path("__dataset/div2k/DIV2K_train_HR/DIV2K_train_HR/"),
        preloaded=True,
    )
    dataloader = DataLoader(
        dataset_restoration,
        batch_size=16,
        shuffle=True
    )
    start = time()
    total = 0
    for degraded, target in tqdm(dataloader):
        # print(degraded.shape, target.shape)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        total += degraded.shape[0]
    end = time()
    print(f"Time elapsed: {(end - start) / total * 1000.:.2f} ms/image")