# -*- coding: utf-8 -*- """ Created on Sun Mar 24 01:21:46 2024 @author: jamyl """ import torch from rstor.properties import DATASET_BLUR_KERNEL_PATH import random from scipy.io import loadmat import cv2 class Degradation(): def __init__(self, length: int = 1000, frozen_seed: int = None): self.frozen_seed = frozen_seed self.current_degradation = {} class DegradationNoise(Degradation): def __init__(self, length: int = 1000, noise_stddev: float = [0., 50.], frozen_seed: int = None): super().__init__(length, frozen_seed) self.noise_stddev = noise_stddev if frozen_seed is not None: random.seed(frozen_seed) self.noise_stddev = [(self.noise_stddev[1] - self.noise_stddev[0]) * random.random() + self.noise_stddev[0] for _ in range(length)] def __call__(self, x: torch.Tensor, idx: int): # WARNING! INPLACE OPERATIONS!!!!! # expects x of shape [b, c, h, w] assert x.ndim == 4 assert x.shape[1] in [1, 3] if self.frozen_seed is not None: std_dev = self.noise_stddev[idx] else: std_dev = (self.noise_stddev[1] - self.noise_stddev[0]) * random.random() + self.noise_stddev[0] if std_dev > 0.: # x += (std_dev/255.)*np.random.randn(*x.shape) x += (std_dev/255.)*torch.randn(*x.shape, device=x.device) self.current_degradation[idx] = { "noise_stddev": std_dev } return x class DegradationBlurMat(Degradation): def __init__(self, length: int = 1000, frozen_seed: int = None, blur_index: int = None): super().__init__(length, frozen_seed) kernels = loadmat(DATASET_BLUR_KERNEL_PATH)["kernels"].squeeze() # conversion to torch (the shape of the kernel is not constant) self.kernels = tuple([ torch.from_numpy(kernel/kernel.sum(keepdims=True)).unsqueeze(0).unsqueeze(0) for kernel in kernels] + [torch.ones((1, 1)).unsqueeze(0).unsqueeze(0)]) self.n_kernels = len(self.kernels) if frozen_seed is not None: random.seed(frozen_seed) self.kernel_ids = [random.randint(0, self.n_kernels-1) for _ in range(length)] if blur_index is not None: self.frozen_seed = 42 self.kernel_ids = [blur_index for _ in range(length)] def __call__(self, x: torch.Tensor, idx: int): # expects x of shape [b, c, h, w] assert x.ndim == 4 assert x.shape[1] in [1, 3] device = x.device if self.frozen_seed is not None: kernel_id = self.kernel_ids[idx] else: kernel_id = random.randint(0, self.n_kernels-1) kernel = self.kernels[kernel_id].to(device).repeat(3, 1, 1, 1).float() # repeat for grouped conv _, _, kh, kw = kernel.shape # We use padding = same to make # sure that the output size does not depend on the kernel. # define nn.Conf layer to define both padding mode and padding value... conv_layer = torch.nn.Conv2d(in_channels=x.shape[1], out_channels=x.shape[1], kernel_size=(kh, kw), padding="same", padding_mode='replicate', groups=3, bias=False) # Set the predefined kernel as weights and freeze the parameters with torch.no_grad(): conv_layer.weight = torch.nn.Parameter(kernel) conv_layer.weight.requires_grad = False # breakpoint() x = conv_layer(x) # Alternative Functional version with 0 padding : # x = F.conv2d(x, kernel, padding="same", groups=3) self.current_degradation[idx] = { "blur_kernel_id": kernel_id } return x class DegradationBlurGauss(Degradation): def __init__(self, length: int = 1000, blur_kernel_half_size: int = [0, 2], frozen_seed: int = None): super().__init__(length, frozen_seed) self.blur_kernel_half_size = blur_kernel_half_size # conversion to torch (the shape of the kernel is not constant) if frozen_seed is not None: random.seed(self.frozen_seed) self.blur_kernel_half_size = [ ( random.randint(self.blur_kernel_half_size[0], self.blur_kernel_half_size[1]), random.randint(self.blur_kernel_half_size[0], self.blur_kernel_half_size[1]) ) for _ in range(length) ] def __call__(self, x: torch.Tensor, idx: int): # expects x of shape [b, c, h, w] assert x.ndim == 4 assert x.shape[1] in [1, 3] device = x.device if self.frozen_seed is not None: k_size_x, k_size_y = self.blur_kernel_half_size[idx] else: k_size_x = random.randint(self.blur_kernel_half_size[0], self.blur_kernel_half_size[1]) k_size_y = random.randint(self.blur_kernel_half_size[0], self.blur_kernel_half_size[1]) k_size_x = 2 * k_size_x + 1 k_size_y = 2 * k_size_y + 1 x = x.squeeze(0).permute(1, 2, 0).cpu().numpy() x = cv2.GaussianBlur(x, (k_size_x, k_size_y), 0) x = torch.from_numpy(x).to(device).permute(2, 0, 1).unsqueeze(0) self.current_degradation[idx] = { "blur_kernel_half_size": (k_size_x, k_size_y), } return x