|
import math |
|
from functools import partial |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from basicsr.data import degradations as degradations |
|
from basicsr.data.transforms import augment |
|
from basicsr.utils import img2tensor |
|
from torch.nn.functional import interpolate |
|
from torchvision.transforms import Compose |
|
from utils.basicsr_custom import ( |
|
random_mixed_kernels, |
|
random_add_gaussian_noise, |
|
random_add_jpg_compression, |
|
) |
|
|
|
|
|
def _degraded_only(x):
    """Pair a degraded tensor with ``None`` in the ground-truth slot."""
    return x, None


def _random_inpainting_degradation(x):
    """Zero out a random 90% of pixel locations, then add Gaussian noise.

    One spatial mask is drawn and shared by every channel. Noise with
    std 0.1 is added after masking and the result is clipped to [0, 1].

    Args:
        x: Image tensor indexed as (channels, height, width). Height and
            width may differ and any channel count is accepted.

    Returns:
        ``(degraded, None)``.
    """
    height, width = x.shape[1], x.shape[2]
    total = height * width

    # The range is degenerate (always 0.9); the draw is kept so the NumPy
    # RNG stream, and therefore seeded masks, stay the same as before.
    low, high = 0.9, 0.9
    prob = np.random.uniform(low, high)

    mask_vec = torch.ones([1, total])
    samples = np.random.choice(total, int(total * prob), replace=False)
    mask_vec[:, samples] = 0

    # A (1, H, W) mask broadcasts over the channel axis, so no particular
    # channel count has to be assumed.
    mask = mask_vec.view(1, height, width).to(device=x.device, dtype=x.dtype)
    return add_gaussian_noise(x * mask, 0.1).clip(0, 1), None


def _difface_degradation(x):
    """Apply the blur -> downsample -> noise -> JPEG -> upsample pipeline.

    Args:
        x: Image tensor convertible with ``.numpy()`` and indexed as
            (channels, height, width).
            NOTE(review): presumably RGB in [0, 1] on the CPU - confirm
            against the caller.

    Returns:
        ``(img_lq, img_gt)``: the degraded image quantised to 8-bit levels,
        and the (possibly flipped / grayscaled) ground truth clipped to
        [0, 1].
    """
    blur_kernel_size = 41
    kernel_list = ['iso', 'aniso']
    kernel_prob = [0.5, 0.5]
    blur_sigma = [0.1, 15]
    downsample_range = [0.8, 32]
    noise_range = [0, 20]
    jpeg_range = [30, 100]
    gt_gray = True
    gray_prob = 0.01

    # CHW tensor -> HWC float32 array with the channel order reversed; the
    # cv2 / basicsr calls below treat the array as BGR.
    x = x.permute(1, 2, 0).numpy()[..., ::-1].astype(np.float32)

    # basicsr augmentation with horizontal flip enabled, rotation disabled.
    img_gt = augment(x.copy(), hflip=True, rotation=False)
    h, w, _ = img_gt.shape

    # Blur with a randomly drawn iso/aniso kernel.
    kernel = degradations.random_mixed_kernels(
        kernel_list,
        kernel_prob,
        blur_kernel_size,
        blur_sigma,
        blur_sigma, [-math.pi, math.pi],
        noise_range=None)
    img_lq = cv2.filter2D(img_gt, -1, kernel)

    # Downsample by a random factor. The target size is clamped to one
    # pixel because cv2.resize rejects a zero-sized destination, which
    # small inputs would otherwise hit at large scale factors.
    scale = np.random.uniform(downsample_range[0], downsample_range[1])
    low_res_size = (max(int(w // scale), 1), max(int(h // scale), 1))
    img_lq = cv2.resize(img_lq, low_res_size, interpolation=cv2.INTER_LINEAR)

    if noise_range is not None:
        img_lq = random_add_gaussian_noise(img_lq, noise_range)

    if jpeg_range is not None:
        img_lq = random_add_jpg_compression(img_lq, jpeg_range)

    # Resize back to the ground-truth resolution.
    img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

    # Occasionally turn the degraded image (and, with gt_gray, the ground
    # truth as well) into a 3-channel grayscale image.
    if np.random.uniform() < gray_prob:
        img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
        img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
        if gt_gray:
            img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
            img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])

    img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

    # Snap the degraded image onto the 256 levels of an 8-bit image.
    img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.

    return img_lq, img_gt.clip(0, 1)


def create_degradation(degradation):
    """Build the degradation callable registered under ``degradation``.

    Supported names:
        * ``'sr_bicubic_x8_gaussian_noise_005'``
        * ``'gaussian_noise_035'``
        * ``'colorization_gaussian_noise_025'``
        * ``'random_inpainting_gaussian_noise_01'``
        * ``'difface'``

    Args:
        degradation: Name of the degradation to build.

    Returns:
        A callable that takes an image tensor and returns a 2-tuple whose
        first item is the degraded tensor. The second item is ``None``,
        except for ``'difface'`` where it is the ground-truth tensor
        produced alongside the degraded one.

    Raises:
        NotImplementedError: If ``degradation`` is not a supported name.
    """
    if degradation == 'sr_bicubic_x8_gaussian_noise_005':
        return Compose([
            # Bicubic x8 downscale (adds a batch axis), noise, then
            # nearest-exact upscale back by the same factor.
            partial(down_scale, scale_factor=1.0 / 8.0, mode='bicubic'),
            partial(add_gaussian_noise, std=0.05),
            partial(interpolate, scale_factor=8.0, mode='nearest-exact'),
            partial(torch.clip, min=0, max=1),
            partial(torch.squeeze, dim=0),
            _degraded_only,
        ])
    elif degradation == 'gaussian_noise_035':
        return Compose([
            partial(add_gaussian_noise, std=0.35),
            partial(torch.clip, min=0, max=1),
            partial(torch.squeeze, dim=0),
            _degraded_only,
        ])
    elif degradation == 'colorization_gaussian_noise_025':
        return Compose([
            # Average over dim 0 while keeping that axis with size one.
            lambda x: torch.mean(x, dim=0, keepdim=True),
            partial(add_gaussian_noise, std=0.25),
            partial(torch.clip, min=0, max=1),
            _degraded_only,
        ])
    elif degradation == 'random_inpainting_gaussian_noise_01':
        return _random_inpainting_degradation
    elif degradation == 'difface':
        return _difface_degradation
    else:
        raise NotImplementedError(f'Unknown degradation: {degradation!r}')
|
|
|
|
|
def down_scale(x, scale_factor, mode):
    """Resize ``x`` with anti-aliased interpolation and clip to [0, 1].

    A leading axis is added to ``x`` before resizing and is NOT removed
    again, so the result has one more dimension than the input. The work
    runs with gradient tracking disabled.

    Args:
        x: Tensor to resize; a batch axis is prepended before the call to
            ``interpolate``.
        scale_factor: Spatial scale multiplier forwarded to ``interpolate``.
        mode: Interpolation mode forwarded to ``interpolate``.

    Returns:
        The resized tensor, still carrying the added leading axis, with
        values limited to [0, 1].
    """
    resize_options = {
        'scale_factor': scale_factor,
        'mode': mode,
        'antialias': True,
        'align_corners': False,
    }
    with torch.no_grad():
        batched = x.unsqueeze(0)
        resized = interpolate(batched, **resize_options)
        return resized.clip(0, 1)
|
|
|
|
|
@torch.no_grad()
def add_gaussian_noise(x, std):
    """Return ``x`` plus zero-mean Gaussian noise scaled by ``std``.

    The input is not modified in place, and the computation runs with
    gradient tracking disabled.

    Args:
        x: Tensor to perturb; the noise matches its shape, dtype and device.
        std: Multiplier applied to the standard-normal noise.

    Returns:
        A new tensor equal to ``x + std * noise``.
    """
    noise = torch.randn_like(x)
    return x + noise * std
|
|