Spaces:
Running
Running
File size: 5,756 Bytes
cec5823 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
# -*- coding: utf-8 -*-
"""
Created on Sun Mar 24 01:21:46 2024
@author: jamyl
"""
import random
from typing import Optional, Sequence

import cv2
import torch
from scipy.io import loadmat

from rstor.properties import DATASET_BLUR_KERNEL_PATH
class Degradation:
    """Base class for the image degradations defined in this module.

    Subclasses are callables ``degradation(x, idx)``. ``idx`` selects the
    pre-drawn parameters when ``frozen_seed`` is set, and it is the key under
    which the parameters actually applied are stored in
    ``current_degradation``.
    """

    def __init__(self,
                 length: int = 1000,
                 frozen_seed: Optional[int] = None):
        """
        Args:
            length: number of indices for which subclasses pre-draw
                parameters. The base class itself does not use it.
            frozen_seed: optional seed. ``None`` means subclasses draw their
                parameters at call time.
        """
        self.frozen_seed = frozen_seed
        # idx -> dict of the degradation parameters last applied for that idx.
        self.current_degradation: dict = {}
class DegradationNoise(Degradation):
    """Additive Gaussian noise degradation.

    The noise standard deviation is drawn uniformly from the ``noise_stddev``
    range. It is divided by 255 before scaling unit-variance noise, so the
    range is presumably expressed on a 0-255 intensity scale for images in
    [0, 1] (TODO confirm against the data pipeline).

    When ``frozen_seed`` is given, one standard deviation per index is drawn
    once at construction. ``__call__(x, idx)`` then always uses the same level
    for a given ``idx``; the noise samples themselves still come from torch's
    RNG.
    """

    def __init__(self,
                 length: int = 1000,
                 noise_stddev: Sequence[float] = (0., 50.),
                 frozen_seed: Optional[int] = None):
        """
        Args:
            length: number of per-index noise levels to pre-draw when
                ``frozen_seed`` is set.
            noise_stddev: ``(min, max)`` range of the noise standard
                deviation. The default is a tuple to avoid a shared mutable
                default argument.
            frozen_seed: if not None, the global ``random`` module is seeded
                with it and ``length`` levels are pre-drawn. In that case
                ``self.noise_stddev`` holds the per-index list instead of the
                range.
        """
        super().__init__(length, frozen_seed)
        self.noise_stddev = noise_stddev
        if frozen_seed is not None:
            # NOTE: this reseeds the *global* `random` module as a side effect.
            random.seed(frozen_seed)
            low, high = self.noise_stddev[0], self.noise_stddev[1]
            # random.uniform(a, b) == a + (b - a) * random.random()
            self.noise_stddev = [random.uniform(low, high) for _ in range(length)]

    def __call__(self, x: torch.Tensor, idx: int) -> torch.Tensor:
        """Add Gaussian noise to ``x`` IN PLACE and return it.

        Args:
            x: tensor of shape [b, c, h, w] with c in {1, 3}.
            idx: sample index. It selects the pre-drawn level when the seed is
                frozen, and the level used is recorded under this key in
                ``self.current_degradation``.

        Returns:
            The same tensor object ``x``, modified in place.
        """
        # WARNING! INPLACE OPERATIONS!!!!!
        assert x.ndim == 4
        assert x.shape[1] in [1, 3]
        if self.frozen_seed is not None:
            std_dev = self.noise_stddev[idx]
        else:
            std_dev = random.uniform(self.noise_stddev[0], self.noise_stddev[1])
        if std_dev > 0.:
            x += (std_dev/255.)*torch.randn(*x.shape, device=x.device)
        # The level is recorded even when it is 0 (no noise added).
        self.current_degradation[idx] = {
            "noise_stddev": std_dev
        }
        return x
class DegradationBlurMat(Degradation):
    """Blur degradation using kernels loaded from a ``.mat`` file.

    The kernels stored under the ``"kernels"`` key of
    ``DATASET_BLUR_KERNEL_PATH`` are normalised to sum to 1. A 1x1 identity
    kernel (no blur) is appended as the last option.
    """

    def __init__(self,
                 length: int = 1000,
                 frozen_seed: Optional[int] = None,
                 blur_index: Optional[int] = None):
        """
        Args:
            length: number of per-index kernel ids to pre-draw.
            frozen_seed: if not None, the global ``random`` module is seeded
                with it and one kernel id per index is pre-drawn.
            blur_index: if not None, every index uses this kernel id. This
                overrides ``frozen_seed``. ``self.frozen_seed`` is set to 42
                so that ``__call__`` reads ``self.kernel_ids``.
        """
        super().__init__(length, frozen_seed)
        kernels = loadmat(DATASET_BLUR_KERNEL_PATH)["kernels"].squeeze()
        # Kernels have varying spatial sizes, so they are kept as a tuple of
        # [1, 1, kh, kw] tensors rather than stacked into a single tensor.
        self.kernels = tuple([
            torch.from_numpy(kernel/kernel.sum(keepdims=True)).unsqueeze(0).unsqueeze(0)
            for kernel in kernels] + [torch.ones((1, 1)).unsqueeze(0).unsqueeze(0)])
        self.n_kernels = len(self.kernels)
        if frozen_seed is not None:
            # NOTE: this reseeds the *global* `random` module as a side effect.
            random.seed(frozen_seed)
            self.kernel_ids = [random.randint(0, self.n_kernels-1) for _ in range(length)]
        if blur_index is not None:
            self.frozen_seed = 42
            self.kernel_ids = [blur_index for _ in range(length)]

    def __call__(self, x: torch.Tensor, idx: int) -> torch.Tensor:
        """Blur ``x`` with one of the loaded kernels.

        Args:
            x: tensor of shape [b, c, h, w] with c in {1, 3}.
            idx: sample index. It selects the pre-drawn kernel when the seed
                is frozen, and the kernel id used is recorded under this key
                in ``self.current_degradation``.

        Returns:
            A new tensor with the same shape as ``x``. Replicate padding with
            ``padding="same"`` keeps the size independent of the kernel.
        """
        assert x.ndim == 4
        assert x.shape[1] in [1, 3]
        device = x.device
        n_channels = x.shape[1]
        if self.frozen_seed is not None:
            kernel_id = self.kernel_ids[idx]
        else:
            kernel_id = random.randint(0, self.n_kernels-1)
        # Match the input's floating dtype so float64 inputs work.
        # Non-floating inputs keep the previous float32 kernel.
        dtype = x.dtype if x.is_floating_point() else torch.float32
        # One copy of the kernel per channel for a depthwise (grouped) conv.
        # Using n_channels instead of a hard-coded 3 makes 1-channel input work.
        kernel = self.kernels[kernel_id].to(device=device, dtype=dtype).repeat(n_channels, 1, 1, 1)
        _, _, kh, kw = kernel.shape
        # An nn.Conv2d layer is used because it supports both
        # padding="same" and padding_mode="replicate".
        # NOTE(review): constructing the layer presumably runs its default
        # weight initialisation before the weight is overwritten.
        conv_layer = torch.nn.Conv2d(in_channels=n_channels,
                                     out_channels=n_channels,
                                     kernel_size=(kh, kw),
                                     padding="same",
                                     padding_mode='replicate',
                                     groups=n_channels,
                                     bias=False)
        # Install the predefined kernel as a frozen weight.
        conv_layer.weight = torch.nn.Parameter(kernel, requires_grad=False)
        x = conv_layer(x)
        self.current_degradation[idx] = {
            "blur_kernel_id": kernel_id
        }
        return x
class DegradationBlurGauss(Degradation):
    """Gaussian blur degradation implemented with ``cv2.GaussianBlur``.

    Half sizes ``k`` are drawn independently for each axis from the inclusive
    integer range ``blur_kernel_half_size``. They are turned into odd kernel
    sizes ``2 * k + 1``, so a half size of 0 gives a 1-pixel kernel. Sigma is
    passed as 0, so cv2 derives it from the kernel size.
    """

    def __init__(self,
                 length: int = 1000,
                 blur_kernel_half_size: Sequence[int] = (0, 2),
                 frozen_seed: Optional[int] = None):
        """
        Args:
            length: number of per-index kernel sizes to pre-draw when
                ``frozen_seed`` is set.
            blur_kernel_half_size: inclusive ``(min, max)`` range of the half
                size. The default is a tuple to avoid a shared mutable default
                argument.
            frozen_seed: if not None, the global ``random`` module is seeded
                with it and ``length`` ``(half_x, half_y)`` pairs are
                pre-drawn. In that case ``self.blur_kernel_half_size`` holds
                that list instead of the range.
        """
        super().__init__(length, frozen_seed)
        self.blur_kernel_half_size = blur_kernel_half_size
        if frozen_seed is not None:
            # NOTE: this reseeds the *global* `random` module as a side effect.
            random.seed(self.frozen_seed)
            low, high = self.blur_kernel_half_size[0], self.blur_kernel_half_size[1]
            self.blur_kernel_half_size = [
                (random.randint(low, high), random.randint(low, high))
                for _ in range(length)
            ]

    def __call__(self, x: torch.Tensor, idx: int) -> torch.Tensor:
        """Gaussian-blur every image of ``x``.

        Args:
            x: tensor of shape [b, c, h, w] with c in {1, 3}. It is moved to
                the CPU and converted to numpy for cv2.
            idx: sample index. It selects the pre-drawn kernel size when the
                seed is frozen, and the size used is recorded under this key
                in ``self.current_degradation``.

        Returns:
            A new [b, c, h, w] tensor on ``x``'s device.
        """
        assert x.ndim == 4
        assert x.shape[1] in [1, 3]
        device = x.device
        if self.frozen_seed is not None:
            k_size_x, k_size_y = self.blur_kernel_half_size[idx]
        else:
            k_size_x = random.randint(self.blur_kernel_half_size[0], self.blur_kernel_half_size[1])
            k_size_y = random.randint(self.blur_kernel_half_size[0], self.blur_kernel_half_size[1])
        # Half size -> odd full kernel size.
        k_size_x = 2 * k_size_x + 1
        k_size_y = 2 * k_size_y + 1
        blurred = []
        # Blur each batch element separately. The previous squeeze(0) path
        # only worked for a batch size of 1.
        for image in x:
            # [c, h, w] -> [h, w, c], the channel-last layout cv2 works with.
            image_np = image.permute(1, 2, 0).cpu().numpy()
            # cv2 reads ksize as (width, height).
            out = cv2.GaussianBlur(image_np, (k_size_x, k_size_y), 0)
            if out.ndim == 2:
                # cv2 may drop a trailing singleton channel axis (c == 1).
                out = out[:, :, None]
            blurred.append(torch.from_numpy(out).permute(2, 0, 1))
        x = torch.stack(blurred, dim=0).to(device)
        # NOTE: despite the key name, the stored values are the full (odd)
        # kernel sizes, not the half sizes. Kept as-is for existing consumers.
        self.current_degradation[idx] = {
            "blur_kernel_half_size": (k_size_x, k_size_y),
        }
        return x
|