Spaces:
Running
Running
import torch | |
import torch.nn.functional as F | |
from torch.utils.data import Dataset | |
from typing import Tuple | |
from rstor.data.degradation import DegradationBlurMat, DegradationBlurGauss, DegradationNoise | |
from rstor.properties import DEVICE, AUGMENTATION_FLIP, DEGRADATION_BLUR_NONE, DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_GAUSS | |
from rstor.synthetic_data.dead_leaves_cpu import cpu_dead_leaves_chart | |
from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart | |
import cv2 | |
from skimage.filters import gaussian | |
import random | |
import numpy as np | |
from rstor.utils import DEFAULT_TORCH_FLOAT_TYPE | |
class DeadLeavesDataset(Dataset):
    """CPU dataset yielding (degraded chart, clean chart) pairs of dead-leaves images.

    Each item is generated on the fly with ``cpu_dead_leaves_chart`` at
    ``size * ds_factor`` resolution, blurred with a gaussian and decimated by
    ``ds_factor``, then optionally blurred and finally passed through the
    noise degradation.

    Args:
        size: (height, width) of the returned charts, after downsampling.
        length: Number of items reported by ``__len__``.
        frozen_seed: When not None, item ``idx`` is generated with seed
            ``frozen_seed + idx`` (useful for a validation set).
        blur_kernel_half_size: Forwarded to ``DegradationBlurGauss``.
            Defaults to ``[0, 2]`` (presumably a [min, max] range -- confirm
            against ``DegradationBlurGauss``).
        ds_factor: Super-sampling factor used before downsampling.
        noise_stddev: Forwarded to ``DegradationNoise``. Defaults to
            ``[0., 50.]`` (presumably a [min, max] range -- confirm against
            ``DegradationNoise``).
        degradation_blur: One of ``DEGRADATION_BLUR_NONE``,
            ``DEGRADATION_BLUR_MAT`` or ``DEGRADATION_BLUR_GAUSS``.
        **config_dead_leaves: Extra keyword arguments forwarded to
            ``cpu_dead_leaves_chart``. The original inline notes list
            ``number_of_circles``, ``background_color``, ``colored``,
            ``radius_mean`` and ``radius_stddev`` as examples.

    Raises:
        ValueError: If ``degradation_blur`` is not one of the known constants.
    """

    def __init__(
        self,
        size: Tuple[int, int] = (128, 128),
        length: int = 1000,
        frozen_seed: int = None,  # useful for validation set!
        blur_kernel_half_size: Tuple[int, int] = None,
        ds_factor: int = 5,
        noise_stddev: Tuple[float, float] = None,
        degradation_blur=DEGRADATION_BLUR_NONE,
        **config_dead_leaves
    ):
        # Resolve defaults here so that no mutable list is shared between
        # instances through the function signature.
        if blur_kernel_half_size is None:
            blur_kernel_half_size = [0, 2]
        if noise_stddev is None:
            noise_stddev = [0., 50.]

        self.frozen_seed = frozen_seed
        self.ds_factor = ds_factor
        # Charts are generated at a higher resolution and downsampled later.
        self.size = (size[0]*ds_factor, size[1]*ds_factor)
        self.length = length
        self.config_dead_leaves = config_dead_leaves
        self.blur_kernel_half_size = blur_kernel_half_size
        self.noise_stddev = noise_stddev
        self.degradation_blur_type = degradation_blur

        # Always define these attributes, even when no blur is applied.
        self.degradation_blur = None
        self.blur_deg_str = None
        if degradation_blur == DEGRADATION_BLUR_GAUSS:
            self.degradation_blur = DegradationBlurGauss(self.length,
                                                         blur_kernel_half_size,
                                                         frozen_seed)
            self.blur_deg_str = "blur_kernel_half_size"
        elif degradation_blur == DEGRADATION_BLUR_MAT:
            self.degradation_blur = DegradationBlurMat(self.length,
                                                       frozen_seed)
            self.blur_deg_str = "blur_kernel_id"
        elif degradation_blur != DEGRADATION_BLUR_NONE:
            raise ValueError(f"Unknown degradation blur {degradation_blur}")

        self.degradation_noise = DegradationNoise(self.length,
                                                  noise_stddev,
                                                  frozen_seed)
        # Maps item index -> parameters of the degradations applied to it.
        self.current_degradation = {}

    def __len__(self):
        """Return the number of items in the dataset."""
        return self.length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate one dead-leaves chart and its degraded version.

        Args:
            idx: Index of the item to generate.

        Returns:
            Tuple ``(degraded_chart, target_chart)``, both with the leading
            batch dimension removed.
        """
        # TODO: there is a bug in this CPU version, the dead leaves don't appear to be right
        seed = self.frozen_seed + idx if self.frozen_seed is not None else None
        chart = cpu_dead_leaves_chart(self.size, seed=seed, **self.config_dead_leaves)
        if self.ds_factor > 1:
            # Blur the two spatial axes only (sigma 0 on the last axis),
            # then keep one pixel out of ds_factor.
            sigma = 3/5
            chart = gaussian(
                chart, sigma=(sigma, sigma, 0), mode='nearest',
                cval=0, preserve_range=True, truncate=4.0)
            chart = chart[::self.ds_factor, ::self.ds_factor]
        # NOTE(review): the tensor keeps the numpy dtype of `chart`; the GPU
        # dataset casts to DEFAULT_TORCH_FLOAT_TYPE instead -- confirm that
        # downstream code handles the dtype produced here.
        th_chart = torch.from_numpy(chart).permute(2, 0, 1).unsqueeze(0)
        degraded_chart = th_chart
        self.current_degradation[idx] = {}
        if self.degradation_blur_type != DEGRADATION_BLUR_NONE:
            degraded_chart = self.degradation_blur(degraded_chart, idx)
            self.current_degradation[idx][self.blur_deg_str] = (
                self.degradation_blur.current_degradation[idx][self.blur_deg_str]
            )
        degraded_chart = self.degradation_noise(degraded_chart, idx)
        self.current_degradation[idx]["noise_stddev"] = (
            self.degradation_noise.current_degradation[idx]["noise_stddev"]
        )
        degraded_chart = degraded_chart.squeeze(0)
        th_chart = th_chart.squeeze(0)
        return degraded_chart, th_chart
class DeadLeavesDatasetGPU(Dataset):
    """CUDA dataset yielding (degraded chart, clean chart) pairs of dead-leaves images.

    Each item is generated on the fly with ``gpu_dead_leaves_chart`` at
    ``size * ds_factor`` resolution, downsampled with a strided gaussian
    convolution, then passed through the blur and noise degradations.

    Args:
        size: (height, width) of the returned charts, after downsampling.
        length: Number of items reported by ``__len__``.
        frozen_seed: When not None, item ``idx`` is generated with seed
            ``frozen_seed + idx`` (useful for a validation set).
        blur_kernel_half_size: Forwarded to ``DegradationBlurGauss``.
            Defaults to ``[0, 2]`` (presumably a [min, max] range -- confirm
            against ``DegradationBlurGauss``).
        ds_factor: Super-sampling factor used before downsampling.
        noise_stddev: Forwarded to ``DegradationNoise``. Defaults to
            ``[0., 50.]`` (presumably a [min, max] range -- confirm against
            ``DegradationNoise``).
        use_gaussian_kernel: Use ``DegradationBlurGauss`` when True,
            ``DegradationBlurMat`` otherwise.
        **config_dead_leaves: Extra keyword arguments forwarded to
            ``gpu_dead_leaves_chart``. The original inline notes list
            ``number_of_circles``, ``background_color``, ``colored``,
            ``radius_mean`` and ``radius_stddev`` as examples.
    """

    def __init__(
        self,
        size: Tuple[int, int] = (128, 128),
        length: int = 1000,
        frozen_seed: int = None,  # useful for validation set!
        blur_kernel_half_size: Tuple[int, int] = None,
        ds_factor: int = 5,
        noise_stddev: Tuple[float, float] = None,
        use_gaussian_kernel=True,
        **config_dead_leaves
    ):
        # Resolve defaults here so that no mutable list is shared between
        # instances through the function signature.
        if blur_kernel_half_size is None:
            blur_kernel_half_size = [0, 2]
        if noise_stddev is None:
            noise_stddev = [0., 50.]

        self.frozen_seed = frozen_seed
        self.ds_factor = ds_factor
        # Charts are generated at a higher resolution and downsampled later.
        self.size = (size[0]*ds_factor, size[1]*ds_factor)
        self.length = length
        self.config_dead_leaves = config_dead_leaves

        # Downsample kernel, shape [3, 1, k_size, k_size].
        # k_size = 5 fits sigma = 3/5: the weight at distance 2 relative to
        # the centre is exp(-4 / (2 * sigma**2)) ~= 0.0038 (negligible).
        self.downsample_kernel = self._make_downsample_kernel(
            sigma=3/5,
            k_size=5,
            channels=3,
            device='cuda',
            dtype=DEFAULT_TORCH_FLOAT_TYPE,  # same dtype as the charts
        )

        self.use_gaussian_kernel = use_gaussian_kernel
        if use_gaussian_kernel:
            self.degradation_blur = DegradationBlurGauss(length,
                                                         blur_kernel_half_size,
                                                         frozen_seed)
        else:
            self.degradation_blur = DegradationBlurMat(length,
                                                       frozen_seed)
        self.degradation_noise = DegradationNoise(length,
                                                  noise_stddev,
                                                  frozen_seed)
        # Maps item index -> parameters of the degradations applied to it.
        self.current_degradation = {}

    @staticmethod
    def _make_downsample_kernel(sigma=3/5, k_size=5, channels=3,
                                device='cuda', dtype=torch.float32):
        """Build a normalized 2D gaussian kernel of shape [channels, 1, k_size, k_size]."""
        offsets = torch.arange(k_size, device=device) - k_size // 2
        rows, cols = torch.meshgrid(offsets, offsets, indexing='ij')
        # Squared distance to the kernel centre. It must NOT be squared a
        # second time: the gaussian is exp(-r^2 / (2 * sigma^2)).
        dist_sq = rows**2 + cols**2
        kernel = (-dist_sq / (2 * sigma**2)).exp()
        kernel = kernel / kernel.sum()
        return kernel.to(dtype).repeat(channels, 1, 1, 1)

    @staticmethod
    def _downsample(chart, kernel, ds_factor):
        """Blur a [1, c, h, w] chart with `kernel` and keep one pixel out of `ds_factor`."""
        half = kernel.shape[-1] // 2
        # Pad BOTH spatial dimensions so that output pixel (i, j) is centred
        # on input pixel (i * ds_factor, j * ds_factor) and the output is
        # ceil(h / ds_factor) x ceil(w / ds_factor) for any ds_factor.
        chart = F.pad(chart,
                      pad=(half, half, half, half),
                      mode="replicate")
        return F.conv2d(chart,
                        kernel,
                        padding='valid',
                        groups=kernel.shape[0],
                        stride=ds_factor)

    def __len__(self) -> int:
        """Return the number of items in the dataset."""
        return self.length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get a single dead-leaves chart and its degraded version.

        Args:
            idx (int): index of the item to retrieve

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: degraded chart, target chart
        """
        seed = self.frozen_seed + idx if self.frozen_seed is not None else None

        # Returns a numba device array
        numba_chart = gpu_dead_leaves_chart(self.size, seed=seed, **self.config_dead_leaves)
        th_chart = torch.as_tensor(numba_chart, dtype=DEFAULT_TORCH_FLOAT_TYPE, device="cuda")[
            None].permute(0, 3, 1, 2)  # [1, c, h, w]
        if self.ds_factor > 1:
            # Downsample using strided gaussian conv (sigma=3/5)
            th_chart = self._downsample(th_chart, self.downsample_kernel, self.ds_factor)

        degraded_chart = self.degradation_blur(th_chart, idx)
        degraded_chart = self.degradation_noise(degraded_chart, idx)

        # Record which degradation parameters were applied to this item.
        blur_deg_str = "blur_kernel_half_size" if self.use_gaussian_kernel else "blur_kernel_id"
        self.current_degradation[idx] = {
            blur_deg_str: self.degradation_blur.current_degradation[idx][blur_deg_str],
            "noise_stddev": self.degradation_noise.current_degradation[idx]["noise_stddev"]
        }

        degraded_chart = degraded_chart.squeeze(0)
        th_chart = th_chart.squeeze(0)
        return degraded_chart, th_chart