diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..20685910623de1d926427fe93b40e97d1d6bc67a --- /dev/null +++ b/.gitignore @@ -0,0 +1,30 @@ +datasets/* +.ipynb_checkpoints +.idea +__pycache__ + +datasets/ +tmp_imgs +runs/ +runs_last/ +saved_models/* +saved_models/ +pre_trained/ +save_log/* +tmp/* + +*.pyc +*.pth +*.png +*.jpg +*.mp4 +*.txt +*.json +*.zip +*.mp4 +*.csv + +!__assets__/lr_inputs/* +!__assets__/* +!__assets__/visual_results/* +!requirements.txt \ No newline at end of file diff --git a/__assets__/logo.png b/__assets__/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..1cb677e309058f64ffd1bfd25671cf3f31b7fdff Binary files /dev/null and b/__assets__/logo.png differ diff --git a/__assets__/lr_inputs/41.png b/__assets__/lr_inputs/41.png new file mode 100644 index 0000000000000000000000000000000000000000..53c034638cb7369464833ebf1e39b3914999b0d7 Binary files /dev/null and b/__assets__/lr_inputs/41.png differ diff --git a/__assets__/lr_inputs/f91.jpg b/__assets__/lr_inputs/f91.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a656c6e467db079eac509dca09718275ab35d928 Binary files /dev/null and b/__assets__/lr_inputs/f91.jpg differ diff --git a/__assets__/lr_inputs/image-00164.jpg b/__assets__/lr_inputs/image-00164.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9aadd5dd0a8a1db4f999f71fb332543c3a62cfed Binary files /dev/null and b/__assets__/lr_inputs/image-00164.jpg differ diff --git a/__assets__/lr_inputs/image-00186.png b/__assets__/lr_inputs/image-00186.png new file mode 100644 index 0000000000000000000000000000000000000000..8fd0125c9b1faf74eb774120ba82e669be88cee3 Binary files /dev/null and b/__assets__/lr_inputs/image-00186.png differ diff --git a/__assets__/lr_inputs/image-00277.png b/__assets__/lr_inputs/image-00277.png new file mode 100644 index 0000000000000000000000000000000000000000..7f8630b0591c3e59496a73e7932902c2a3d37de7 Binary files /dev/null and b/__assets__/lr_inputs/image-00277.png differ diff --git a/__assets__/lr_inputs/image-00440.png b/__assets__/lr_inputs/image-00440.png new file mode 100644 index 0000000000000000000000000000000000000000..ec234d3508e698f7a28f5ff5e2aad55ff74ffde7 Binary files /dev/null and b/__assets__/lr_inputs/image-00440.png differ diff --git a/__assets__/lr_inputs/image-00542.png b/__assets__/lr_inputs/image-00542.png new file mode 100644 index 0000000000000000000000000000000000000000..fadcb24c76f3239fefc065cd3fc2e16e7f3ca727 Binary files /dev/null and b/__assets__/lr_inputs/image-00542.png differ diff --git a/__assets__/lr_inputs/img_eva.jpeg b/__assets__/lr_inputs/img_eva.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..672cd170f3c43f96e873ef0339a6b44347bdf30d Binary files /dev/null and b/__assets__/lr_inputs/img_eva.jpeg differ diff --git a/__assets__/lr_inputs/screenshot_resize.jpg b/__assets__/lr_inputs/screenshot_resize.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c65070e9fac905f574d7a77fbdeca5261b405b03 Binary files /dev/null and b/__assets__/lr_inputs/screenshot_resize.jpg differ diff --git a/__assets__/visual_results/0079_2_visual.png b/__assets__/visual_results/0079_2_visual.png new file mode 100644 index 0000000000000000000000000000000000000000..a055a958022095f711c79283764fa4bd8a7d2ca1 Binary files /dev/null and b/__assets__/visual_results/0079_2_visual.png differ diff --git a/__assets__/visual_results/0079_visual.png 
b/__assets__/visual_results/0079_visual.png new file mode 100644 index 0000000000000000000000000000000000000000..06c09af12f571d1173b5a030bf8f487b19b549b0 Binary files /dev/null and b/__assets__/visual_results/0079_visual.png differ diff --git a/__assets__/visual_results/eva_visual.png b/__assets__/visual_results/eva_visual.png new file mode 100644 index 0000000000000000000000000000000000000000..a7dc176be42cf97d306babcb1d8dc0778182dc78 Binary files /dev/null and b/__assets__/visual_results/eva_visual.png differ diff --git a/__assets__/visual_results/f91_visual.png b/__assets__/visual_results/f91_visual.png new file mode 100644 index 0000000000000000000000000000000000000000..9baaae2f12e77f4f0ee6fe18e23e15af5c43fa2d Binary files /dev/null and b/__assets__/visual_results/f91_visual.png differ diff --git a/__assets__/visual_results/kiteret_visual.png b/__assets__/visual_results/kiteret_visual.png new file mode 100644 index 0000000000000000000000000000000000000000..b2d7ac6eca2e938654fd040408174981771a4d88 Binary files /dev/null and b/__assets__/visual_results/kiteret_visual.png differ diff --git a/__assets__/visual_results/pokemon2_visual.png b/__assets__/visual_results/pokemon2_visual.png new file mode 100644 index 0000000000000000000000000000000000000000..4c1dd0e6161f78f3289f510b99108cbc1a84d3da Binary files /dev/null and b/__assets__/visual_results/pokemon2_visual.png differ diff --git a/__assets__/visual_results/pokemon_visual.png b/__assets__/visual_results/pokemon_visual.png new file mode 100644 index 0000000000000000000000000000000000000000..c076013da30cf08f612d4ea550b430a4347d8b5e Binary files /dev/null and b/__assets__/visual_results/pokemon_visual.png differ diff --git a/__assets__/visual_results/wataru_visual.png b/__assets__/visual_results/wataru_visual.png new file mode 100644 index 0000000000000000000000000000000000000000..2d22e5dab212147653ce608d44cf588532745e3a Binary files /dev/null and b/__assets__/visual_results/wataru_visual.png differ diff --git a/__assets__/workflow.png b/__assets__/workflow.png new file mode 100644 index 0000000000000000000000000000000000000000..a40321fd8aa5bdeed20d5301de694629c8fd6b0f Binary files /dev/null and b/__assets__/workflow.png differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..f7658844a251711a2eeb2da3f54a5108284bae74 --- /dev/null +++ b/app.py @@ -0,0 +1,117 @@ +import os, sys +import cv2 +import gradio as gr +import torch +import numpy as np +from torchvision.utils import save_image + + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from test_code.inference import super_resolve_img +from test_code.test_utils import load_grl, load_rrdb + + +def auto_download_if_needed(weight_path): + if os.path.exists(weight_path): + return + + if not os.path.exists("pretrained"): + os.makedirs("pretrained") + + if weight_path == "pretrained/4x_APISR_GRL_GAN_generator.pth": + os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth") + os.system("mv 4x_APISR_GRL_GAN_generator.pth pretrained") + + if weight_path == "pretrained/2x_APISR_RRDB_GAN_generator.pth": + os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth") + os.system("mv 2x_APISR_RRDB_GAN_generator.pth pretrained") + + + +def inference(img_path, model_name): + + try: + weight_dtype = torch.float32 + + # Load the model + if model_name == "4xGRL": + weight_path = 
"pretrained/4x_APISR_GRL_GAN_generator.pth" + auto_download_if_needed(weight_path) + generator = load_grl(weight_path, scale=4) # Directly use default way now + + elif model_name == "2xRRDB": + weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth" + auto_download_if_needed(weight_path) + generator = load_rrdb(weight_path, scale=2) # Directly use default way now + + else: + raise gr.Error(error) + + generator = generator.to(dtype=weight_dtype) + + + # In default, we will automatically use crop to match 4x size + super_resolved_img = super_resolve_img(generator, img_path, output_path=None, weight_dtype=weight_dtype, crop_for_4x=True) + save_image(super_resolved_img, "SR_result.png") + outputs = cv2.imread("SR_result.png") + outputs = cv2.cvtColor(outputs, cv2.COLOR_RGB2BGR) + + return outputs + + + except Exception as error: + raise gr.Error(f"global exception: {error}") + + + +if __name__ == '__main__': + + MARKDOWN = \ + """ + ## APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) + + [GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598) + + If APISR is helpful for you, please help star the GitHub Repo. Thanks! + """ + + block = gr.Blocks().queue() + with block: + with gr.Row(): + gr.Markdown(MARKDOWN) + with gr.Row(elem_classes=["container"]): + with gr.Column(scale=2): + input_image = gr.Image(type="filepath", label="Input") + model_name = gr.Dropdown( + [ + "2xRRDB", + "4xGRL" + ], + type="value", + value="4xGRL", + label="model", + ) + run_btn = gr.Button(value="Submit") + + with gr.Column(scale=3): + output_image = gr.Image(type="numpy", label="Output image") + + with gr.Row(elem_classes=["container"]): + gr.Examples( + [ + ["__assets__/lr_inputs/image-00277.png"], + ["__assets__/lr_inputs/image-00542.png"], + ["__assets__/lr_inputs/41.png"], + ["__assets__/lr_inputs/f91.jpg"], + ["__assets__/lr_inputs/image-00440.png"], + ["__assets__/lr_inputs/image-00164.png"], + ["__assets__/lr_inputs/img_eva.jpeg"], + ], + [input_image], + ) + + run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image]) + + block.launch() \ No newline at end of file diff --git a/architecture/cunet.py b/architecture/cunet.py new file mode 100644 index 0000000000000000000000000000000000000000..6ea00b837cd91bfb33b42e82f30805eb14fcdfbe --- /dev/null +++ b/architecture/cunet.py @@ -0,0 +1,189 @@ +# Github Repository: https://github.com/bilibili/ailab/blob/main/Real-CUGAN/README_EN.md +# Code snippet (with certain modificaiton) from: https://github.com/bilibili/ailab/blob/main/Real-CUGAN/VapourSynth/upcunet_v3_vs.py + +import torch +from torch import nn as nn +from torch.nn import functional as F +import os, sys +import numpy as np +from time import time as ttime, sleep + + +class UNet_Full(nn.Module): + + def __init__(self): + super(UNet_Full, self).__init__() + self.unet1 = UNet1(3, 3, deconv=True) + self.unet2 = UNet2(3, 3, deconv=False) + + def forward(self, x): + n, c, h0, w0 = x.shape + + ph = ((h0 - 1) // 2 + 1) * 2 + pw = ((w0 - 1) // 2 + 1) * 2 + x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), 'reflect') # In order to ensure that it can be divided by 2 + + x1 = self.unet1(x) + x2 = self.unet2(x1) + + x1 = F.pad(x1, (-20, -20, -20, -20)) + output = torch.add(x2, x1) + + if (w0 != pw or h0 != ph): + output = output[:, :, :h0 * 2, :w0 * 2] + + return output + + +class SEBlock(nn.Module): + def __init__(self, in_channels, reduction=8, bias=False): + super(SEBlock, self).__init__() + self.conv1 = nn.Conv2d(in_channels, 
in_channels // reduction, 1, 1, 0, bias=bias) + self.conv2 = nn.Conv2d(in_channels // reduction, in_channels, 1, 1, 0, bias=bias) + + def forward(self, x): + if ("Half" in x.type()): # torch.HalfTensor/torch.cuda.HalfTensor + x0 = torch.mean(x.float(), dim=(2, 3), keepdim=True).half() + else: + x0 = torch.mean(x, dim=(2, 3), keepdim=True) + x0 = self.conv1(x0) + x0 = F.relu(x0, inplace=True) + x0 = self.conv2(x0) + x0 = torch.sigmoid(x0) + x = torch.mul(x, x0) + return x + +class UNetConv(nn.Module): + def __init__(self, in_channels, mid_channels, out_channels, se): + super(UNetConv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, 3, 1, 0), + nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(mid_channels, out_channels, 3, 1, 0), + nn.LeakyReLU(0.1, inplace=True), + ) + if se: + self.seblock = SEBlock(out_channels, reduction=8, bias=True) + else: + self.seblock = None + + def forward(self, x): + z = self.conv(x) + if self.seblock is not None: + z = self.seblock(z) + return z + +class UNet1(nn.Module): + def __init__(self, in_channels, out_channels, deconv): + super(UNet1, self).__init__() + self.conv1 = UNetConv(in_channels, 32, 64, se=False) + self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0) + self.conv2 = UNetConv(64, 128, 64, se=True) + self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0) + self.conv3 = nn.Conv2d(64, 64, 3, 1, 0) + + if deconv: + self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3) + else: + self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0) + + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv1_down(x1) + x2 = F.leaky_relu(x2, 0.1, inplace=True) + x2 = self.conv2(x2) + x2 = self.conv2_up(x2) + x2 = F.leaky_relu(x2, 0.1, inplace=True) + + x1 = F.pad(x1, (-4, -4, -4, -4)) + x3 = self.conv3(x1 + x2) + x3 = F.leaky_relu(x3, 0.1, inplace=True) + z = self.conv_bottom(x3) + return z + + +class UNet2(nn.Module): + def __init__(self, in_channels, out_channels, deconv): + super(UNet2, self).__init__() + + self.conv1 = UNetConv(in_channels, 32, 64, se=False) + self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0) + self.conv2 = UNetConv(64, 64, 128, se=True) + self.conv2_down = nn.Conv2d(128, 128, 2, 2, 0) + self.conv3 = UNetConv(128, 256, 128, se=True) + self.conv3_up = nn.ConvTranspose2d(128, 128, 2, 2, 0) + self.conv4 = UNetConv(128, 64, 64, se=True) + self.conv4_up = nn.ConvTranspose2d(64, 64, 2, 2, 0) + self.conv5 = nn.Conv2d(64, 64, 3, 1, 0) + + if deconv: + self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3) + else: + self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0) + + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv1_down(x1) + x2 = F.leaky_relu(x2, 0.1, inplace=True) + x2 = self.conv2(x2) + + x3 = self.conv2_down(x2) + x3 = F.leaky_relu(x3, 0.1, inplace=True) + x3 = self.conv3(x3) + x3 = self.conv3_up(x3) + x3 = F.leaky_relu(x3, 0.1, inplace=True) + + x2 = F.pad(x2, (-4, -4, -4, -4)) + x4 = self.conv4(x2 + x3) + x4 = self.conv4_up(x4) + x4 = 
F.leaky_relu(x4, 0.1, inplace=True) + + x1 = F.pad(x1, (-16, -16, -16, -16)) + x5 = self.conv5(x1 + x4) + x5 = F.leaky_relu(x5, 0.1, inplace=True) + + z = self.conv_bottom(x5) + return z + + + +def main(): + root_path = os.path.abspath('.') + sys.path.append(root_path) + + from opt import opt # Manage GPU to choose + import time + + model = UNet_Full().cuda() + pytorch_total_params = sum(p.numel() for p in model.parameters()) + print(f"CuNet has param {pytorch_total_params//1000} K params") + + + # Count the number of FLOPs to double check + x = torch.randn((1, 3, 180, 180)).cuda() + start = time.time() + x = model(x) + print("output size is ", x.shape) + total = time.time() - start + print(total) + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/architecture/dataset.py b/architecture/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b1e1340284b3a6b096af46684f095786087dfb22 --- /dev/null +++ b/architecture/dataset.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from torchvision.models import vgg19 +import torchvision.transforms as transforms +from torch.utils.data import DataLoader, Dataset +from torchvision.utils import save_image, make_grid +from torchvision.transforms import ToTensor + +import numpy as np +import cv2 +import glob +import random +from PIL import Image +from tqdm import tqdm + + +# from degradation.degradation_main import degredate_process, preparation +from opt import opt + + +class ImageDataset(Dataset): + @torch.no_grad() + def __init__(self, train_lr_paths, degrade_hr_paths, train_hr_paths): + # print("low_res path sample is ", train_lr_paths[0]) + # print(train_hr_paths[0]) + # hr_height, hr_width = hr_shape + self.transform = transforms.Compose( + [ + transforms.ToTensor(), + ] + ) + + self.files_lr = train_lr_paths + self.files_degrade_hr = degrade_hr_paths + self.files_hr = train_hr_paths + + assert(len(self.files_lr) == len(self.files_hr)) + assert(len(self.files_lr) == len(self.files_degrade_hr)) + + + def augment(self, imgs, hflip=True, rotation=True): + """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). + + All the images in the list use the same augmentation. + + Args: + imgs (list[ndarray] | ndarray): Images to be augmented. If the input + is an ndarray, it will be transformed to a list. + hflip (bool): Horizontal flip. Default: True. + rotation (bool): Rotation. Default: True. + + Returns: + imgs (list[ndarray] | ndarray): Augmented images and flows. If returned + results only have one element, just return ndarray. 
+ + """ + hflip = hflip and random.random() < 0.5 + vflip = rotation and random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) + if rot90: + img = img.transpose(1, 0, 2) + return img + + + if not isinstance(imgs, list): + imgs = [imgs] + + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + + + return imgs + + + def __getitem__(self, index): + + # Read File + img_lr = cv2.imread(self.files_lr[index % len(self.files_lr)]) # Should be BGR + img_degrade_hr = cv2.imread(self.files_degrade_hr[index % len(self.files_degrade_hr)]) + img_hr = cv2.imread(self.files_hr[index % len(self.files_hr)]) + + # Augmentation + if random.random() < opt["augment_prob"]: + img_lr, img_degrade_hr, img_hr = self.augment([img_lr, img_degrade_hr, img_hr]) + + # Transform to Tensor + img_lr = self.transform(img_lr) + img_degrade_hr = self.transform(img_degrade_hr) + img_hr = self.transform(img_hr) # ToTensor() is already in the range [0, 1] + + + return {"lr": img_lr, "degrade_hr": img_degrade_hr, "hr": img_hr} + + def __len__(self): + assert(len(self.files_hr) == len(self.files_lr)) + return len(self.files_hr) diff --git a/architecture/discriminator.py b/architecture/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..372ca50dbe3ed9643ffb08399af8c6bab6b28ed0 --- /dev/null +++ b/architecture/discriminator.py @@ -0,0 +1,241 @@ +# -*- coding: utf-8 -*- +from torch import nn as nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm +import torch +import functools + +class UNetDiscriminatorSN(nn.Module): + """Defines a U-Net discriminator with spectral normalization (SN) + + It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + Arg: + num_in_ch (int): Channel number of inputs. Default: 3. + num_feat (int): Channel number of base intermediate features. Default: 64. + skip_connection (bool): Whether to use skip connections between U-Net. Default: True. 
+ """ + + def __init__(self, num_in_ch, num_feat=64, skip_connection=True): + super(UNetDiscriminatorSN, self).__init__() + self.skip_connection = skip_connection + norm = spectral_norm + # the first convolution + self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) + # downsample + self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) + self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) + self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) + # upsample + self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) + self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) + self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) + # extra convolutions + self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) + + def forward(self, x): + + # downsample + x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) + x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) + x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) + x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) + + # upsample + x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) + x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x4 = x4 + x2 + x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) + x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x5 = x5 + x1 + x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) + x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x6 = x6 + x0 + + # extra convolutions + out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) + out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) + out = self.conv9(out) + + return out + + + +def get_conv_layer(input_nc, ndf, kernel_size, stride, padding, bias=True, use_sn=False): + if not use_sn: + return nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias) + return spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)) + + +class PatchDiscriminator(nn.Module): + """Defines a PatchGAN discriminator, the receptive field of default config is 70x70. + + Args: + use_sn (bool): Use spectra_norm or not, if use_sn is True, then norm_type should be none. 
+    """
+
+    def __init__(self,
+                 num_in_ch,
+                 num_feat=64,
+                 num_layers=3,
+                 max_nf_mult=8,
+                 norm_type='batch',
+                 use_sigmoid=False,
+                 use_sn=False):
+        super(PatchDiscriminator, self).__init__()
+
+        norm_layer = self._get_norm_layer(norm_type)
+        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
+            use_bias = norm_layer.func != nn.BatchNorm2d
+        else:
+            use_bias = norm_layer != nn.BatchNorm2d
+
+        kw = 4
+        padw = 1
+        sequence = [
+            get_conv_layer(num_in_ch, num_feat, kernel_size=kw, stride=2, padding=padw, use_sn=use_sn),
+            nn.LeakyReLU(0.2, True)
+        ]
+        nf_mult = 1
+        nf_mult_prev = 1
+        for n in range(1, num_layers):  # gradually increase the number of filters
+            nf_mult_prev = nf_mult
+            nf_mult = min(2**n, max_nf_mult)
+            sequence += [
+                get_conv_layer(
+                    num_feat * nf_mult_prev,
+                    num_feat * nf_mult,
+                    kernel_size=kw,
+                    stride=2,
+                    padding=padw,
+                    bias=use_bias,
+                    use_sn=use_sn),
+                norm_layer(num_feat * nf_mult),
+                nn.LeakyReLU(0.2, True)
+            ]
+
+        nf_mult_prev = nf_mult
+        nf_mult = min(2**num_layers, max_nf_mult)
+        sequence += [
+            get_conv_layer(
+                num_feat * nf_mult_prev,
+                num_feat * nf_mult,
+                kernel_size=kw,
+                stride=1,
+                padding=padw,
+                bias=use_bias,
+                use_sn=use_sn),
+            norm_layer(num_feat * nf_mult),
+            nn.LeakyReLU(0.2, True)
+        ]
+
+        # output 1 channel prediction map (I think this is essentially pixel-by-pixel feedback)
+        sequence += [get_conv_layer(num_feat * nf_mult, 1, kernel_size=kw, stride=1, padding=padw, use_sn=use_sn)]
+
+        if use_sigmoid:
+            sequence += [nn.Sigmoid()]
+        self.model = nn.Sequential(*sequence)
+
+    def _get_norm_layer(self, norm_type='batch'):
+        if norm_type == 'batch':
+            norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
+        elif norm_type == 'instance':
+            norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
+        elif norm_type == 'batchnorm2d':
+            norm_layer = nn.BatchNorm2d
+        elif norm_type == 'none':
+            norm_layer = nn.Identity
+        else:
+            raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
+
+        return norm_layer
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class MultiScaleDiscriminator(nn.Module):
+    """Define a multi-scale discriminator, where each discriminator is an instance of PatchDiscriminator.
+
+    Args:
+        num_layers (int or list): If the type of this variable is int, then degrade to PatchDiscriminator.
+                                  If the type of this variable is list, then the length of the list is
+                                  the number of discriminators.
+        use_downscale (bool): Progressively downscale the input to feed into different discriminators.
+                              If set to True, then the discriminators are usually the same.
+ """ + + def __init__(self, + num_in_ch, + num_feat=64, + num_layers=[3, 3, 3], + max_nf_mult=8, + norm_type='none', + use_sigmoid=False, + use_sn=True, + use_downscale=True): + super(MultiScaleDiscriminator, self).__init__() + + if isinstance(num_layers, int): + num_layers = [num_layers] + + # check whether the discriminators are the same + if use_downscale: + assert len(set(num_layers)) == 1 + self.use_downscale = use_downscale + + self.num_dis = len(num_layers) + self.dis_list = nn.ModuleList() + for nl in num_layers: + self.dis_list.append( + PatchDiscriminator( + num_in_ch, + num_feat=num_feat, + num_layers=nl, + max_nf_mult=max_nf_mult, + norm_type=norm_type, + use_sigmoid=use_sigmoid, + use_sn=use_sn, + )) + + def forward(self, x): + outs = [] + h, w = x.size()[2:] + + y = x + for i in range(self.num_dis): + if i != 0 and self.use_downscale: + y = F.interpolate(y, size=(h // 2, w // 2), mode='bilinear', align_corners=True) + h, w = y.size()[2:] + outs.append(self.dis_list[i](y)) + + return outs + + +def main(): + from pthflops import count_ops + from torchsummary import summary + + model = UNetDiscriminatorSN(3) + pytorch_total_params = sum(p.numel() for p in model.parameters()) + + # Create a network and a corresponding input + device = 'cuda' + inp = torch.rand(1, 3, 400, 400) + + # Count the number of FLOPs + count_ops(model, inp) + summary(model.cuda(), (3, 400, 400), batch_size=1) + # print(f"pathGAN has param {pytorch_total_params//1000} K params") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/architecture/grl.py b/architecture/grl.py new file mode 100644 index 0000000000000000000000000000000000000000..3d899698484839a8d5f0c20cf91f7fdb47968846 --- /dev/null +++ b/architecture/grl.py @@ -0,0 +1,616 @@ +""" +Efficient and Explicit Modelling of Image Hierarchies for Image Restoration +Image restoration transformers with global, regional, and local modelling +A clean version of the. +Shared buffers are used for relative_coords_table, relative_position_index, and attn_mask. +""" +import cv2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import ToTensor +from torchvision.utils import save_image +from fairscale.nn import checkpoint_wrapper +from omegaconf import OmegaConf +from timm.models.layers import to_2tuple, trunc_normal_ + +# Import files from local folder +import os, sys +root_path = os.path.abspath('.') +sys.path.append(root_path) + +from architecture.grl_common import Upsample, UpsampleOneStep +from architecture.grl_common.mixed_attn_block_efficient import ( + _get_stripe_info, + EfficientMixAttnTransformerBlock, +) +from architecture.grl_common.ops import ( + bchw_to_blc, + blc_to_bchw, + calculate_mask, + calculate_mask_all, + get_relative_coords_table_all, + get_relative_position_index_simple, +) +from architecture.grl_common.swin_v1_block import ( + build_last_conv, +) + + +class TransformerStage(nn.Module): + """Transformer stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads_window (list[int]): Number of window attention heads in different layers. + num_heads_stripe (list[int]): Number of stripe attention heads in different layers. + stripe_size (list[int]): Stripe size. Default: [8, 8] + stripe_groups (list[int]): Number of stripe groups. Default: [None, None]. + stripe_shift (bool): whether to shift the stripes. This is used as an ablation study. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qkv_proj_type (str): QKV projection type. Default: linear. Choices: linear, separable_conv. + anchor_proj_type (str): Anchor projection type. Default: avgpool. Choices: avgpool, maxpool, conv2d, separable_conv, patchmerging. + anchor_one_stage (bool): Whether to use one operator or multiple progressive operators to reduce feature map resolution. Default: True. + anchor_window_down_factor (int): The downscale factor used to get the anchors. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pretrained_window_size (list[int]): pretrained window size. This is actually not used. Default: [0, 0]. + pretrained_stripe_size (list[int]): pretrained stripe size. This is actually not used. Default: [0, 0]. + conv_type: The convolutional block before residual connection. + init_method: initialization method of the weight parameters used to train large scale models. + Choices: n, normal -- Swin V1 init method. + l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer. + r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1 + w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1 + t, trunc_normal_ -- nn.Linear, trunc_normal; nn.Conv2d, weight_rescale + fairscale_checkpoint (bool): Whether to use fairscale checkpoint. + offload_to_cpu (bool): used by fairscale_checkpoint + args: + out_proj_type (str): Type of the output projection in the self-attention modules. Default: linear. Choices: linear, conv2d. + local_connection (bool): Whether to enable the local modelling module (two convs followed by Channel attention). For GRL base model, this is used. "local_connection": local_connection, + euclidean_dist (bool): use Euclidean distance or inner product as the similarity metric. An ablation study. 
+ """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads_window, + num_heads_stripe, + window_size, + stripe_size, + stripe_groups, + stripe_shift, + mlp_ratio=4.0, + qkv_bias=True, + qkv_proj_type="linear", + anchor_proj_type="avgpool", + anchor_one_stage=True, + anchor_window_down_factor=1, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + pretrained_window_size=[0, 0], + pretrained_stripe_size=[0, 0], + conv_type="1conv", + init_method="", + fairscale_checkpoint=False, + offload_to_cpu=False, + args=None, + ): + super().__init__() + + self.dim = dim + self.input_resolution = input_resolution + self.init_method = init_method + + self.blocks = nn.ModuleList() + for i in range(depth): + block = EfficientMixAttnTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads_w=num_heads_window, + num_heads_s=num_heads_stripe, + window_size=window_size, + window_shift=i % 2 == 0, + stripe_size=stripe_size, + stripe_groups=stripe_groups, + stripe_type="H" if i % 2 == 0 else "W", + stripe_shift=i % 4 in [2, 3] if stripe_shift else False, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qkv_proj_type=qkv_proj_type, + anchor_proj_type=anchor_proj_type, + anchor_one_stage=anchor_one_stage, + anchor_window_down_factor=anchor_window_down_factor, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size, + pretrained_stripe_size=pretrained_stripe_size, + res_scale=0.1 if init_method == "r" else 1.0, + args=args, + ) + # print(fairscale_checkpoint, offload_to_cpu) + if fairscale_checkpoint: + block = checkpoint_wrapper(block, offload_to_cpu=offload_to_cpu) + self.blocks.append(block) + + self.conv = build_last_conv(conv_type, dim) + + def _init_weights(self): + for n, m in self.named_modules(): + if self.init_method == "w": + if isinstance(m, (nn.Linear, nn.Conv2d)) and n.find("cpb_mlp") < 0: + print("nn.Linear and nn.Conv2d weight initilization") + m.weight.data *= 0.1 + elif self.init_method == "l": + if isinstance(m, nn.LayerNorm): + print("nn.LayerNorm initialization") + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 0) + elif self.init_method.find("t") >= 0: + scale = 0.1 ** (len(self.init_method) - 1) * int(self.init_method[-1]) + if isinstance(m, nn.Linear) and n.find("cpb_mlp") < 0: + trunc_normal_(m.weight, std=scale) + elif isinstance(m, nn.Conv2d): + m.weight.data *= 0.1 + print( + "Initialization nn.Linear - trunc_normal; nn.Conv2d - weight rescale." + ) + else: + raise NotImplementedError( + f"Parameter initialization method {self.init_method} not implemented in TransformerStage." + ) + + def forward(self, x, x_size, table_index_mask): + res = x + for blk in self.blocks: + res = blk(res, x_size, table_index_mask) + res = bchw_to_blc(self.conv(blc_to_bchw(res, x_size))) + + return res + x + + def flops(self): + pass + + +class GRL(nn.Module): + r"""Image restoration transformer with global, non-local, and local connections + Args: + img_size (int | list[int]): Input image size. Default 64 + in_channels (int): Number of input image channels. Default: 3 + out_channels (int): Number of output image channels. Default: None + embed_dim (int): Patch embedding dimension. Default: 96 + upscale (int): Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range (float): Image range. 1. or 255. + upsampler (str): The reconstruction reconstruction module. 
'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + depths (list[int]): Depth of each Swin Transformer layer. + num_heads_window (list[int]): Number of window attention heads in different layers. + num_heads_stripe (list[int]): Number of stripe attention heads in different layers. + window_size (int): Window size. Default: 8. + stripe_size (list[int]): Stripe size. Default: [8, 8] + stripe_groups (list[int]): Number of stripe groups. Default: [None, None]. + stripe_shift (bool): whether to shift the stripes. This is used as an ablation study. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qkv_proj_type (str): QKV projection type. Default: linear. Choices: linear, separable_conv. + anchor_proj_type (str): Anchor projection type. Default: avgpool. Choices: avgpool, maxpool, conv2d, separable_conv, patchmerging. + anchor_one_stage (bool): Whether to use one operator or multiple progressive operators to reduce feature map resolution. Default: True. + anchor_window_down_factor (int): The downscale factor used to get the anchors. + out_proj_type (str): Type of the output projection in the self-attention modules. Default: linear. Choices: linear, conv2d. + local_connection (bool): Whether to enable the local modelling module (two convs followed by Channel attention). For GRL base model, this is used. + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + pretrained_window_size (list[int]): pretrained window size. This is actually not used. Default: [0, 0]. + pretrained_stripe_size (list[int]): pretrained stripe size. This is actually not used. Default: [0, 0]. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + conv_type (str): The convolutional block before residual connection. Default: 1conv. Choices: 1conv, 3conv, 1conv1x1, linear + init_method: initialization method of the weight parameters used to train large scale models. + Choices: n, normal -- Swin V1 init method. + l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer. + r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1 + w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1 + t, trunc_normal_ -- nn.Linear, trunc_normal; nn.Conv2d, weight_rescale + fairscale_checkpoint (bool): Whether to use fairscale checkpoint. + offload_to_cpu (bool): used by fairscale_checkpoint + euclidean_dist (bool): use Euclidean distance or inner product as the similarity metric. An ablation study. 
+ + """ + + def __init__( + self, + img_size=64, + in_channels=3, + out_channels=None, + embed_dim=96, + upscale=2, + img_range=1.0, + upsampler="", + depths=[6, 6, 6, 6, 6, 6], + num_heads_window=[3, 3, 3, 3, 3, 3], + num_heads_stripe=[3, 3, 3, 3, 3, 3], + window_size=8, + stripe_size=[8, 8], # used for stripe window attention + stripe_groups=[None, None], + stripe_shift=False, + mlp_ratio=4.0, + qkv_bias=True, + qkv_proj_type="linear", + anchor_proj_type="avgpool", + anchor_one_stage=True, + anchor_window_down_factor=1, + out_proj_type="linear", + local_connection=False, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + pretrained_window_size=[0, 0], + pretrained_stripe_size=[0, 0], + conv_type="1conv", + init_method="n", # initialization method of the weight parameters used to train large scale models. + fairscale_checkpoint=False, # fairscale activation checkpointing + offload_to_cpu=False, + euclidean_dist=False, + **kwargs, + ): + super(GRL, self).__init__() + # Process the input arguments + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + num_out_feats = 64 + self.embed_dim = embed_dim + self.upscale = upscale + self.upsampler = upsampler + self.img_range = img_range + if in_channels == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + + max_stripe_size = max([0 if s is None else s for s in stripe_size]) + max_stripe_groups = max([0 if s is None else s for s in stripe_groups]) + max_stripe_groups *= anchor_window_down_factor + self.pad_size = max(window_size, max_stripe_size, max_stripe_groups) + # if max_stripe_size >= window_size: + # self.pad_size *= anchor_window_down_factor + # if stripe_groups[0] is None and stripe_groups[1] is None: + # self.pad_size = max(stripe_size) + # else: + # self.pad_size = window_size + self.input_resolution = to_2tuple(img_size) + self.window_size = to_2tuple(window_size) + self.shift_size = [w // 2 for w in self.window_size] + self.stripe_size = stripe_size + self.stripe_groups = stripe_groups + self.pretrained_window_size = pretrained_window_size + self.pretrained_stripe_size = pretrained_stripe_size + self.anchor_window_down_factor = anchor_window_down_factor + + # Head of the network. First convolution. 
+ self.conv_first = nn.Conv2d(in_channels, embed_dim, 3, 1, 1) + + # Body of the network + self.norm_start = norm_layer(embed_dim) + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + # stochastic depth decay rule + args = OmegaConf.create( + { + "out_proj_type": out_proj_type, + "local_connection": local_connection, + "euclidean_dist": euclidean_dist, + } + ) + for k, v in self.set_table_index_mask(self.input_resolution).items(): + self.register_buffer(k, v) + + self.layers = nn.ModuleList() + for i in range(len(depths)): + layer = TransformerStage( + dim=embed_dim, + input_resolution=self.input_resolution, + depth=depths[i], + num_heads_window=num_heads_window[i], + num_heads_stripe=num_heads_stripe[i], + window_size=self.window_size, + stripe_size=stripe_size, + stripe_groups=stripe_groups, + stripe_shift=stripe_shift, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qkv_proj_type=qkv_proj_type, + anchor_proj_type=anchor_proj_type, + anchor_one_stage=anchor_one_stage, + anchor_window_down_factor=anchor_window_down_factor, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[ + sum(depths[:i]) : sum(depths[: i + 1]) + ], # no impact on SR results + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size, + pretrained_stripe_size=pretrained_stripe_size, + conv_type=conv_type, + init_method=init_method, + fairscale_checkpoint=fairscale_checkpoint, + offload_to_cpu=offload_to_cpu, + args=args, + ) + self.layers.append(layer) + self.norm_end = norm_layer(embed_dim) + + # Tail of the network + self.conv_after_body = build_last_conv(conv_type, embed_dim) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == "pixelshuffle": + # for classical SR + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_out_feats, 3, 1, 1), nn.LeakyReLU(inplace=True) + ) + self.upsample = Upsample(upscale, num_out_feats) + self.conv_last = nn.Conv2d(num_out_feats, out_channels, 3, 1, 1) + elif self.upsampler == "pixelshuffledirect": + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep( + upscale, + embed_dim, + out_channels, + ) + elif self.upsampler == "nearest+conv": + # for real-world SR (less artifacts) + assert self.upscale == 4, "only support x4 now." + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_out_feats, 3, 1, 1), nn.LeakyReLU(inplace=True) + ) + self.conv_up1 = nn.Conv2d(num_out_feats, num_out_feats, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_out_feats, num_out_feats, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_out_feats, num_out_feats, 3, 1, 1) + self.conv_last = nn.Conv2d(num_out_feats, out_channels, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, out_channels, 3, 1, 1) + + self.apply(self._init_weights) + if init_method in ["l", "w"] or init_method.find("t") >= 0: + for layer in self.layers: + layer._init_weights() + + def set_table_index_mask(self, x_size): + """ + Two used cases: + 1) At initialization: set the shared buffers. 
+ 2) During forward pass: get the new buffers if the resolution of the input changes + """ + # ss - stripe_size, sss - stripe_shift_size + ss, sss = _get_stripe_info(self.stripe_size, self.stripe_groups, True, x_size) + df = self.anchor_window_down_factor + + table_w = get_relative_coords_table_all( + self.window_size, self.pretrained_window_size + ) + table_sh = get_relative_coords_table_all(ss, self.pretrained_stripe_size, df) + table_sv = get_relative_coords_table_all( + ss[::-1], self.pretrained_stripe_size, df + ) + + index_w = get_relative_position_index_simple(self.window_size) + index_sh_a2w = get_relative_position_index_simple(ss, df, False) + index_sh_w2a = get_relative_position_index_simple(ss, df, True) + index_sv_a2w = get_relative_position_index_simple(ss[::-1], df, False) + index_sv_w2a = get_relative_position_index_simple(ss[::-1], df, True) + + mask_w = calculate_mask(x_size, self.window_size, self.shift_size) + mask_sh_a2w = calculate_mask_all(x_size, ss, sss, df, False) + mask_sh_w2a = calculate_mask_all(x_size, ss, sss, df, True) + mask_sv_a2w = calculate_mask_all(x_size, ss[::-1], sss[::-1], df, False) + mask_sv_w2a = calculate_mask_all(x_size, ss[::-1], sss[::-1], df, True) + return { + "table_w": table_w, + "table_sh": table_sh, + "table_sv": table_sv, + "index_w": index_w, + "index_sh_a2w": index_sh_a2w, + "index_sh_w2a": index_sh_w2a, + "index_sv_a2w": index_sv_a2w, + "index_sv_w2a": index_sv_w2a, + "mask_w": mask_w, + "mask_sh_a2w": mask_sh_a2w, + "mask_sh_w2a": mask_sh_w2a, + "mask_sv_a2w": mask_sv_a2w, + "mask_sv_w2a": mask_sv_w2a, + } + + def get_table_index_mask(self, device=None, input_resolution=None): + # Used during forward pass + if input_resolution == self.input_resolution: + return { + "table_w": self.table_w, + "table_sh": self.table_sh, + "table_sv": self.table_sv, + "index_w": self.index_w, + "index_sh_a2w": self.index_sh_a2w, + "index_sh_w2a": self.index_sh_w2a, + "index_sv_a2w": self.index_sv_a2w, + "index_sv_w2a": self.index_sv_w2a, + "mask_w": self.mask_w, + "mask_sh_a2w": self.mask_sh_a2w, + "mask_sh_w2a": self.mask_sh_w2a, + "mask_sv_a2w": self.mask_sv_a2w, + "mask_sv_w2a": self.mask_sv_w2a, + } + else: + table_index_mask = self.set_table_index_mask(input_resolution) + for k, v in table_index_mask.items(): + table_index_mask[k] = v.to(device) + return table_index_mask + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # Only used to initialize linear layers + # weight_shape = m.weight.shape + # if weight_shape[0] > 256 and weight_shape[1] > 256: + # std = 0.004 + # else: + # std = 0.02 + # print(f"Standard deviation during initialization {std}.") + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"absolute_pos_embed"} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"relative_position_bias_table"} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.pad_size - h % self.pad_size) % self.pad_size + mod_pad_w = (self.pad_size - w % self.pad_size) % self.pad_size + # print("padding size", h, w, self.pad_size, mod_pad_h, mod_pad_w) + + try: + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect") + except BaseException: + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant") + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = 
bchw_to_blc(x) + x = self.norm_start(x) + x = self.pos_drop(x) + + table_index_mask = self.get_table_index_mask(x.device, x_size) + for layer in self.layers: + x = layer(x, x_size, table_index_mask) + + x = self.norm_end(x) # B L C + x = blc_to_bchw(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == "pixelshuffle": + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == "pixelshuffledirect": + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == "nearest+conv": + # for real-world SR (claimed to have less artifacts) + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu( + self.conv_up1( + torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest") + ) + ) + x = self.lrelu( + self.conv_up2( + torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest") + ) + ) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + if self.in_channels == self.out_channels: + x = x + self.conv_last(res) + else: + x = self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, : H * self.upscale, : W * self.upscale] + + def flops(self): + pass + + def convert_checkpoint(self, state_dict): + for k in list(state_dict.keys()): + if ( + k.find("relative_coords_table") >= 0 + or k.find("relative_position_index") >= 0 + or k.find("attn_mask") >= 0 + or k.find("model.table_") >= 0 + or k.find("model.index_") >= 0 + or k.find("model.mask_") >= 0 + # or k.find(".upsample.") >= 0 + ): + state_dict.pop(k) + print(k) + return state_dict + + +if __name__ == "__main__": + # The version of GRL we use + model = GRL( + upscale = 4, + img_size = 64, + window_size = 8, + depths = [4, 4, 4, 4], + embed_dim = 64, + num_heads_window = [2, 2, 2, 2], + num_heads_stripe = [2, 2, 2, 2], + mlp_ratio = 2, + qkv_proj_type = "linear", + anchor_proj_type = "avgpool", + anchor_window_down_factor = 2, + out_proj_type = "linear", + conv_type = "1conv", + upsampler = "nearest+conv", # Change + ).cuda() + + # Parameter analysis + num_params = 0 + for p in model.parameters(): + if p.requires_grad: + num_params += p.numel() + print(f"Number of parameters {num_params / 10 ** 6: 0.2f}") + + # Print param + for name, param in model.named_parameters(): + print(name, param.dtype) + + + # Count the number of FLOPs to double check + x = torch.randn((1, 3, 180, 180)).cuda() # Don't use input size that is too big (we don't have @torch.no_grad here) + x = model(x) + print("output size is ", x.shape) + diff --git a/architecture/grl_common/__init__.py b/architecture/grl_common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..711842e9c7673427001700c318bf446227f5e834 --- /dev/null +++ b/architecture/grl_common/__init__.py @@ -0,0 +1,8 @@ +from architecture.grl_common.resblock import ResBlock +from architecture.grl_common.upsample import ( + Upsample, + UpsampleOneStep, +) + + +__all__ = ["Upsample", "UpsampleOneStep", "ResBlock"] diff --git 
a/architecture/grl_common/common_edsr.py b/architecture/grl_common/common_edsr.py new file mode 100644 index 0000000000000000000000000000000000000000..8d0da6e0ad593d97bb1bccb5f7a75a622670322e --- /dev/null +++ b/architecture/grl_common/common_edsr.py @@ -0,0 +1,227 @@ +""" +EDSR common.py +Since a lot of models are developed on top of EDSR, here we include some common functions from EDSR. +In this repository, the common functions is used by edsr_esa.py and ipt.py +""" + + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def default_conv(in_channels, out_channels, kernel_size, bias=True): + return nn.Conv2d( + in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias + ) + + +class MeanShift(nn.Conv2d): + def __init__( + self, + rgb_range, + rgb_mean=(0.4488, 0.4371, 0.4040), + rgb_std=(1.0, 1.0, 1.0), + sign=-1, + ): + + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = torch.Tensor(rgb_std) + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std + for p in self.parameters(): + p.requires_grad = False + + +class BasicBlock(nn.Sequential): + def __init__( + self, + conv, + in_channels, + out_channels, + kernel_size, + stride=1, + bias=False, + bn=True, + act=nn.ReLU(True), + ): + + m = [conv(in_channels, out_channels, kernel_size, bias=bias)] + if bn: + m.append(nn.BatchNorm2d(out_channels)) + if act is not None: + m.append(act) + + super(BasicBlock, self).__init__(*m) + + +class ESA(nn.Module): + def __init__(self, esa_channels, n_feats): + super(ESA, self).__init__() + f = esa_channels + self.conv1 = nn.Conv2d(n_feats, f, kernel_size=1) + self.conv_f = nn.Conv2d(f, f, kernel_size=1) + # self.conv_max = conv(f, f, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(f, f, kernel_size=3, stride=2, padding=0) + self.conv3 = nn.Conv2d(f, f, kernel_size=3, padding=1) + # self.conv3_ = conv(f, f, kernel_size=3, padding=1) + self.conv4 = nn.Conv2d(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + # self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.conv1(x) + c1 = self.conv2(c1_) + v_max = F.max_pool2d(c1, kernel_size=7, stride=3) + c3 = self.conv3(v_max) + # v_range = self.relu(self.conv_max(v_max)) + # c3 = self.relu(self.conv3(v_range)) + # c3 = self.conv3_(c3) + c3 = F.interpolate( + c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False + ) + cf = self.conv_f(c1_) + c4 = self.conv4(c3 + cf) + m = self.sigmoid(c4) + + return x * m + + +# class ESA(nn.Module): +# def __init__(self, esa_channels, n_feats, conv=nn.Conv2d): +# super(ESA, self).__init__() +# f = n_feats // 4 +# self.conv1 = conv(n_feats, f, kernel_size=1) +# self.conv_f = conv(f, f, kernel_size=1) +# self.conv_max = conv(f, f, kernel_size=3, padding=1) +# self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0) +# self.conv3 = conv(f, f, kernel_size=3, padding=1) +# self.conv3_ = conv(f, f, kernel_size=3, padding=1) +# self.conv4 = conv(f, n_feats, kernel_size=1) +# self.sigmoid = nn.Sigmoid() +# self.relu = nn.ReLU(inplace=True) +# +# def forward(self, x): +# c1_ = (self.conv1(x)) +# c1 = self.conv2(c1_) +# v_max = F.max_pool2d(c1, kernel_size=7, stride=3) +# v_range = self.relu(self.conv_max(v_max)) +# c3 = self.relu(self.conv3(v_range)) +# c3 = self.conv3_(c3) +# c3 = F.interpolate(c3, (x.size(2), x.size(3)), mode='bilinear', align_corners=False) +# cf = self.conv_f(c1_) +# c4 = self.conv4(c3 + cf) +# m = self.sigmoid(c4) +# +# 
return x * m + + +class ResBlock(nn.Module): + def __init__( + self, + conv, + n_feats, + kernel_size, + bias=True, + bn=False, + act=nn.ReLU(True), + res_scale=1, + esa_block=True, + depth_wise_kernel=7, + ): + + super(ResBlock, self).__init__() + m = [] + for i in range(2): + m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) + + self.body = nn.Sequential(*m) + self.esa_block = esa_block + if self.esa_block: + esa_channels = 16 + self.c5 = nn.Conv2d( + n_feats, + n_feats, + depth_wise_kernel, + padding=depth_wise_kernel // 2, + groups=n_feats, + bias=True, + ) + self.esa = ESA(esa_channels, n_feats) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x).mul(self.res_scale) + res += x + if self.esa_block: + res = self.esa(self.c5(res)) + + return res + + +class Upsampler(nn.Sequential): + def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): + + m = [] + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(conv(n_feats, 4 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(2)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == "relu": + m.append(nn.ReLU(True)) + elif act == "prelu": + m.append(nn.PReLU(n_feats)) + + elif scale == 3: + m.append(conv(n_feats, 9 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(3)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == "relu": + m.append(nn.ReLU(True)) + elif act == "prelu": + m.append(nn.PReLU(n_feats)) + else: + raise NotImplementedError + + super(Upsampler, self).__init__(*m) + + +class LiteUpsampler(nn.Sequential): + def __init__(self, conv, scale, n_feats, n_out=3, bn=False, act=False, bias=True): + + m = [] + m.append(conv(n_feats, n_out * (scale**2), 3, bias)) + m.append(nn.PixelShuffle(scale)) + # if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
+ # for _ in range(int(math.log(scale, 2))): + # m.append(conv(n_feats, 4 * n_out, 3, bias)) + # m.append(nn.PixelShuffle(2)) + # if bn: + # m.append(nn.BatchNorm2d(n_out)) + # if act == 'relu': + # m.append(nn.ReLU(True)) + # elif act == 'prelu': + # m.append(nn.PReLU(n_out)) + + # elif scale == 3: + # m.append(conv(n_feats, 9 * n_out, 3, bias)) + # m.append(nn.PixelShuffle(3)) + # if bn: + # m.append(nn.BatchNorm2d(n_out)) + # if act == 'relu': + # m.append(nn.ReLU(True)) + # elif act == 'prelu': + # m.append(nn.PReLU(n_out)) + # else: + # raise NotImplementedError + + super(LiteUpsampler, self).__init__(*m) diff --git a/architecture/grl_common/mixed_attn_block.py b/architecture/grl_common/mixed_attn_block.py new file mode 100644 index 0000000000000000000000000000000000000000..d845b7fa67302ec4f566c94aa0a4a0b9b0185f45 --- /dev/null +++ b/architecture/grl_common/mixed_attn_block.py @@ -0,0 +1,1126 @@ +import math +from abc import ABC +from math import prod + +import torch +import torch.nn as nn +import torch.nn.functional as F +from architecture.grl_common.ops import ( + bchw_to_bhwc, + bchw_to_blc, + blc_to_bchw, + blc_to_bhwc, + calculate_mask, + calculate_mask_all, + get_relative_coords_table_all, + get_relative_position_index_simple, + window_partition, + window_reverse, +) +from architecture.grl_common.swin_v1_block import Mlp +from timm.models.layers import DropPath + + +class CPB_MLP(nn.Sequential): + def __init__(self, in_channels, out_channels, channels=512): + m = [ + nn.Linear(in_channels, channels, bias=True), + nn.ReLU(inplace=True), + nn.Linear(channels, out_channels, bias=False), + ] + super(CPB_MLP, self).__init__(*m) + + +class AffineTransformWindow(nn.Module): + r"""Affine transformation of the attention map. + The window is a square window. 
+ Supports attention between different window sizes + """ + + def __init__( + self, + num_heads, + input_resolution, + window_size, + pretrained_window_size=[0, 0], + shift_size=0, + anchor_window_down_factor=1, + args=None, + ): + super(AffineTransformWindow, self).__init__() + # print("AffineTransformWindow", args) + self.num_heads = num_heads + self.input_resolution = input_resolution + self.window_size = window_size + self.pretrained_window_size = pretrained_window_size + self.shift_size = shift_size + self.anchor_window_down_factor = anchor_window_down_factor + self.use_buffer = args.use_buffer + + logit_scale = torch.log(10 * torch.ones((num_heads, 1, 1))) + self.logit_scale = nn.Parameter(logit_scale, requires_grad=True) + + # mlp to generate continuous relative position bias + self.cpb_mlp = CPB_MLP(2, num_heads) + if self.use_buffer: + table = get_relative_coords_table_all( + window_size, pretrained_window_size, anchor_window_down_factor + ) + index = get_relative_position_index_simple( + window_size, anchor_window_down_factor + ) + self.register_buffer("relative_coords_table", table) + self.register_buffer("relative_position_index", index) + + if self.shift_size > 0: + attn_mask = calculate_mask( + input_resolution, self.window_size, self.shift_size + ) + else: + attn_mask = None + self.register_buffer("attn_mask", attn_mask) + + def forward(self, attn, x_size): + B_, H, N, _ = attn.shape + device = attn.device + # logit scale + attn = attn * torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp() + + # relative position bias + if self.use_buffer: + table = self.relative_coords_table + index = self.relative_position_index + else: + table = get_relative_coords_table_all( + self.window_size, + self.pretrained_window_size, + self.anchor_window_down_factor, + ).to(device) + index = get_relative_position_index_simple( + self.window_size, self.anchor_window_down_factor + ).to(device) + + bias_table = self.cpb_mlp(table) # 2*Wh-1, 2*Ww-1, num_heads + bias_table = bias_table.view(-1, self.num_heads) + + win_dim = prod(self.window_size) + bias = bias_table[index.view(-1)] + bias = bias.view(win_dim, win_dim, -1).permute(2, 0, 1).contiguous() + # nH, Wh*Ww, Wh*Ww + bias = 16 * torch.sigmoid(bias) + attn = attn + bias.unsqueeze(0) + + # W-MSA/SW-MSA + if self.use_buffer: + mask = self.attn_mask + # during test and window shift, recalculate the mask + if self.input_resolution != x_size and self.shift_size > 0: + mask = calculate_mask(x_size, self.window_size, self.shift_size) + mask = mask.to(attn.device) + else: + if self.shift_size > 0: + mask = calculate_mask(x_size, self.window_size, self.shift_size) + mask = mask.to(attn.device) + else: + mask = None + + # shift attention mask + if mask is not None: + nW = mask.shape[0] + mask = mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask + attn = attn.view(-1, self.num_heads, N, N) + + return attn + + +class AffineTransformStripe(nn.Module): + r"""Affine transformation of the attention map. + The window is a stripe window. 
Supports attention between different window sizes + """ + + def __init__( + self, + num_heads, + input_resolution, + stripe_size, + stripe_groups, + stripe_shift, + pretrained_stripe_size=[0, 0], + anchor_window_down_factor=1, + window_to_anchor=True, + args=None, + ): + super(AffineTransformStripe, self).__init__() + self.num_heads = num_heads + self.input_resolution = input_resolution + self.stripe_size = stripe_size + self.stripe_groups = stripe_groups + self.pretrained_stripe_size = pretrained_stripe_size + # TODO: be careful when determining the pretrained_stripe_size + self.stripe_shift = stripe_shift + stripe_size, shift_size = self._get_stripe_info(input_resolution) + self.anchor_window_down_factor = anchor_window_down_factor + self.window_to_anchor = window_to_anchor + self.use_buffer = args.use_buffer + + logit_scale = torch.log(10 * torch.ones((num_heads, 1, 1))) + self.logit_scale = nn.Parameter(logit_scale, requires_grad=True) + + # mlp to generate continuous relative position bias + self.cpb_mlp = CPB_MLP(2, num_heads) + if self.use_buffer: + table = get_relative_coords_table_all( + stripe_size, pretrained_stripe_size, anchor_window_down_factor + ) + index = get_relative_position_index_simple( + stripe_size, anchor_window_down_factor, window_to_anchor + ) + self.register_buffer("relative_coords_table", table) + self.register_buffer("relative_position_index", index) + + if self.stripe_shift: + attn_mask = calculate_mask_all( + input_resolution, + stripe_size, + shift_size, + anchor_window_down_factor, + window_to_anchor, + ) + else: + attn_mask = None + self.register_buffer("attn_mask", attn_mask) + + def forward(self, attn, x_size): + B_, H, N1, N2 = attn.shape + device = attn.device + # logit scale + attn = attn * torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp() + + # relative position bias + stripe_size, shift_size = self._get_stripe_info(x_size) + fixed_stripe_size = ( + self.stripe_groups[0] is None and self.stripe_groups[1] is None + ) + if not self.use_buffer or ( + self.use_buffer + and self.input_resolution != x_size + and not fixed_stripe_size + ): + # during test and stripe size is not fixed. 
+ pretrained_stripe_size = ( + self.pretrained_stripe_size + ) # or stripe_size; Needs further pondering + table = get_relative_coords_table_all( + stripe_size, pretrained_stripe_size, self.anchor_window_down_factor + ) + table = table.to(device) + index = get_relative_position_index_simple( + stripe_size, self.anchor_window_down_factor, self.window_to_anchor + ).to(device) + else: + table = self.relative_coords_table + index = self.relative_position_index + # The same table size-> 1, Wh+AWh-1, Ww+AWw-1, 2 + # But different index size -> # Wh*Ww, AWh*AWw + # if N1 < N2: + # index = index.transpose(0, 1) + + bias_table = self.cpb_mlp(table).view(-1, self.num_heads) + # if not self.training: + # print(bias_table.shape, index.max(), index.min()) + bias = bias_table[index.view(-1)] + bias = bias.view(N1, N2, -1).permute(2, 0, 1).contiguous() + # nH, Wh*Ww, Wh*Ww + bias = 16 * torch.sigmoid(bias) + # print(N1, N2, attn.shape, bias.unsqueeze(0).shape) + attn = attn + bias.unsqueeze(0) + + # W-MSA/SW-MSA + if self.use_buffer: + mask = self.attn_mask + # during test and window shift, recalculate the mask + if self.input_resolution != x_size and self.stripe_shift > 0: + mask = calculate_mask_all( + x_size, + stripe_size, + shift_size, + self.anchor_window_down_factor, + self.window_to_anchor, + ) + mask = mask.to(device) + else: + if self.stripe_shift > 0: + mask = calculate_mask_all( + x_size, + stripe_size, + shift_size, + self.anchor_window_down_factor, + self.window_to_anchor, + ) + mask = mask.to(attn.device) + else: + mask = None + + # shift attention mask + if mask is not None: + nW = mask.shape[0] + mask = mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(B_ // nW, nW, self.num_heads, N1, N2) + mask + attn = attn.view(-1, self.num_heads, N1, N2) + + return attn + + def _get_stripe_info(self, input_resolution): + stripe_size, shift_size = [], [] + for s, g, d in zip(self.stripe_size, self.stripe_groups, input_resolution): + if g is None: + stripe_size.append(s) + shift_size.append(s // 2 if self.stripe_shift else 0) + else: + stripe_size.append(d // g) + shift_size.append(0 if g == 1 else d // (g * 2)) + return stripe_size, shift_size + + +class Attention(ABC, nn.Module): + def __init__(self): + super(Attention, self).__init__() + + def attn(self, q, k, v, attn_transform, x_size, reshape=True): + # cosine attention map + B_, _, H, head_dim = q.shape + if self.euclidean_dist: + attn = torch.norm(q.unsqueeze(-2) - k.unsqueeze(-3), dim=-1) + else: + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) + attn = attn_transform(attn, x_size) + # attention + attn = self.softmax(attn) + attn = self.attn_drop(attn) + x = attn @ v # B_, H, N1, head_dim + if reshape: + x = x.transpose(1, 2).reshape(B_, -1, H * head_dim) + # B_, N, C + return x + + +class WindowAttention(Attention): + r"""Window attention. QKV is the input to the forward method. + Args: + num_heads (int): Number of attention heads. + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 
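+
+        A minimal usage sketch (the 4x4 window, the 8x8 resolution and the
+        ``args`` namespace below are assumptions made only for this example)::
+
+            import torch
+            from types import SimpleNamespace
+
+            args = SimpleNamespace(euclidean_dist=False, use_buffer=True)
+            attn = WindowAttention((8, 8), (4, 4), num_heads=2, args=args)
+            qkv = torch.randn(1, 64, 36)   # (B, H*W, 3*C) with C = 12
+            out = attn(qkv, (8, 8))        # -> (1, 64, 12), i.e. (B, H*W, C)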
+ """ + + def __init__( + self, + input_resolution, + window_size, + num_heads, + window_shift=False, + attn_drop=0.0, + pretrained_window_size=[0, 0], + args=None, + ): + + super(WindowAttention, self).__init__() + self.input_resolution = input_resolution + self.window_size = window_size + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + self.shift_size = window_size[0] // 2 if window_shift else 0 + self.euclidean_dist = args.euclidean_dist + + self.attn_transform = AffineTransformWindow( + num_heads, + input_resolution, + window_size, + pretrained_window_size, + self.shift_size, + args=args, + ) + self.attn_drop = nn.Dropout(attn_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, qkv, x_size): + """ + Args: + qkv: input QKV features with shape of (B, L, 3C) + x_size: use x_size to determine whether the relative positional bias table and index + need to be regenerated. + """ + H, W = x_size + B, L, C = qkv.shape + qkv = qkv.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + qkv = torch.roll( + qkv, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) + + # partition windows + qkv = window_partition(qkv, self.window_size) # nW*B, wh, ww, C + qkv = qkv.view(-1, prod(self.window_size), C) # nW*B, wh*ww, C + + B_, N, _ = qkv.shape + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + # attention + x = self.attn(q, k, v, self.attn_transform, x_size) + + # merge windows + x = x.view(-1, *self.window_size, C // 3) + x = window_reverse(x, self.window_size, x_size) # B, H, W, C/3 + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + x = x.view(B, L, C // 3) + + return x + + def extra_repr(self) -> str: + return ( + f"window_size={self.window_size}, shift_size={self.shift_size}, " + f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}" + ) + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class StripeAttention(Attention): + r"""Stripe attention + Args: + stripe_size (tuple[int]): The height and width of the stripe. + num_heads (int): Number of attention heads. + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training. 
+ """ + + def __init__( + self, + input_resolution, + stripe_size, + stripe_groups, + stripe_shift, + num_heads, + attn_drop=0.0, + pretrained_stripe_size=[0, 0], + args=None, + ): + + super(StripeAttention, self).__init__() + self.input_resolution = input_resolution + self.stripe_size = stripe_size # Wh, Ww + self.stripe_groups = stripe_groups + self.stripe_shift = stripe_shift + self.num_heads = num_heads + self.pretrained_stripe_size = pretrained_stripe_size + self.euclidean_dist = args.euclidean_dist + + self.attn_transform = AffineTransformStripe( + num_heads, + input_resolution, + stripe_size, + stripe_groups, + stripe_shift, + pretrained_stripe_size, + anchor_window_down_factor=1, + args=args, + ) + self.attn_drop = nn.Dropout(attn_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, qkv, x_size): + """ + Args: + x: input features with shape of (B, L, C) + stripe_size: use stripe_size to determine whether the relative positional bias table and index + need to be regenerated. + """ + H, W = x_size + B, L, C = qkv.shape + qkv = qkv.view(B, H, W, C) + + running_stripe_size, running_shift_size = self.attn_transform._get_stripe_info( + x_size + ) + # cyclic shift + if self.stripe_shift: + qkv = torch.roll( + qkv, + shifts=(-running_shift_size[0], -running_shift_size[1]), + dims=(1, 2), + ) + + # partition windows + qkv = window_partition(qkv, running_stripe_size) # nW*B, wh, ww, C + qkv = qkv.view(-1, prod(running_stripe_size), C) # nW*B, wh*ww, C + + B_, N, _ = qkv.shape + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + # attention + x = self.attn(q, k, v, self.attn_transform, x_size) + + # merge windows + x = x.view(-1, *running_stripe_size, C // 3) + x = window_reverse(x, running_stripe_size, x_size) # B H W C/3 + + # reverse the shift + if self.stripe_shift: + x = torch.roll(x, shifts=running_shift_size, dims=(1, 2)) + + x = x.view(B, L, C // 3) + return x + + def extra_repr(self) -> str: + return ( + f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, " + f"pretrained_stripe_size={self.pretrained_stripe_size}, num_heads={self.num_heads}" + ) + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class AnchorStripeAttention(Attention): + r"""Stripe attention + Args: + stripe_size (tuple[int]): The height and width of the stripe. + num_heads (int): Number of attention heads. + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training. 
+ """ + + def __init__( + self, + input_resolution, + stripe_size, + stripe_groups, + stripe_shift, + num_heads, + attn_drop=0.0, + pretrained_stripe_size=[0, 0], + anchor_window_down_factor=1, + args=None, + ): + + super(AnchorStripeAttention, self).__init__() + self.input_resolution = input_resolution + self.stripe_size = stripe_size # Wh, Ww + self.stripe_groups = stripe_groups + self.stripe_shift = stripe_shift + self.num_heads = num_heads + self.pretrained_stripe_size = pretrained_stripe_size + self.anchor_window_down_factor = anchor_window_down_factor + self.euclidean_dist = args.euclidean_dist + + self.attn_transform1 = AffineTransformStripe( + num_heads, + input_resolution, + stripe_size, + stripe_groups, + stripe_shift, + pretrained_stripe_size, + anchor_window_down_factor, + window_to_anchor=False, + args=args, + ) + + self.attn_transform2 = AffineTransformStripe( + num_heads, + input_resolution, + stripe_size, + stripe_groups, + stripe_shift, + pretrained_stripe_size, + anchor_window_down_factor, + window_to_anchor=True, + args=args, + ) + + self.attn_drop = nn.Dropout(attn_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, qkv, anchor, x_size): + """ + Args: + qkv: input features with shape of (B, L, C) + anchor: + x_size: use stripe_size to determine whether the relative positional bias table and index + need to be regenerated. + """ + H, W = x_size + B, L, C = qkv.shape + qkv = qkv.view(B, H, W, C) + + stripe_size, shift_size = self.attn_transform1._get_stripe_info(x_size) + anchor_stripe_size = [s // self.anchor_window_down_factor for s in stripe_size] + anchor_shift_size = [s // self.anchor_window_down_factor for s in shift_size] + # cyclic shift + if self.stripe_shift: + qkv = torch.roll(qkv, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) + anchor = torch.roll( + anchor, + shifts=(-anchor_shift_size[0], -anchor_shift_size[1]), + dims=(1, 2), + ) + + # partition windows + qkv = window_partition(qkv, stripe_size) # nW*B, wh, ww, C + qkv = qkv.view(-1, prod(stripe_size), C) # nW*B, wh*ww, C + anchor = window_partition(anchor, anchor_stripe_size) + anchor = anchor.view(-1, prod(anchor_stripe_size), C // 3) + + B_, N1, _ = qkv.shape + N2 = anchor.shape[1] + qkv = qkv.reshape(B_, N1, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + anchor = anchor.reshape(B_, N2, self.num_heads, -1).permute(0, 2, 1, 3) + + # attention + x = self.attn(anchor, k, v, self.attn_transform1, x_size, False) + x = self.attn(q, anchor, x, self.attn_transform2, x_size) + + # merge windows + x = x.view(B_, *stripe_size, C // 3) + x = window_reverse(x, stripe_size, x_size) # B H' W' C + + # reverse the shift + if self.stripe_shift: + x = torch.roll(x, shifts=shift_size, dims=(1, 2)) + + x = x.view(B, H * W, C // 3) + return x + + def extra_repr(self) -> str: + return ( + f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, " + f"pretrained_stripe_size={self.pretrained_stripe_size}, num_heads={self.num_heads}, anchor_window_down_factor={self.anchor_window_down_factor}" + ) + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class 
SeparableConv(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size, stride, bias, args): + m = [ + nn.Conv2d( + in_channels, + in_channels, + kernel_size, + stride, + kernel_size // 2, + groups=in_channels, + bias=bias, + ) + ] + if args.separable_conv_act: + m.append(nn.GELU()) + m.append(nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=bias)) + super(SeparableConv, self).__init__(*m) + + +class QKVProjection(nn.Module): + def __init__(self, dim, qkv_bias, proj_type, args): + super(QKVProjection, self).__init__() + self.proj_type = proj_type + if proj_type == "linear": + self.body = nn.Linear(dim, dim * 3, bias=qkv_bias) + else: + self.body = SeparableConv(dim, dim * 3, 3, 1, qkv_bias, args) + + def forward(self, x, x_size): + if self.proj_type == "separable_conv": + x = blc_to_bchw(x, x_size) + x = self.body(x) + if self.proj_type == "separable_conv": + x = bchw_to_blc(x) + return x + + +class PatchMerging(nn.Module): + r"""Patch Merging Layer. + Args: + dim (int): Number of input channels. + """ + + def __init__(self, in_dim, out_dim): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.reduction = nn.Linear(4 * in_dim, out_dim, bias=False) + + def forward(self, x, x_size): + """ + x: B, H*W, C + """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.reduction(x) + + return x + + +class AnchorLinear(nn.Module): + r"""Linear anchor projection layer + Args: + dim (int): Number of input channels. 
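+
+        Shape note (informal): given an input of shape (B, H*W, in_channels)
+        and ``x_size = (H, W)``, the output has shape
+        (B, H // down_factor, W // down_factor, out_channels).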
+ """ + + def __init__(self, in_channels, out_channels, down_factor, pooling_mode, bias): + super().__init__() + self.down_factor = down_factor + if pooling_mode == "maxpool": + self.pooling = nn.MaxPool2d(down_factor, down_factor) + elif pooling_mode == "avgpool": + self.pooling = nn.AvgPool2d(down_factor, down_factor) + self.reduction = nn.Linear(in_channels, out_channels, bias=bias) + + def forward(self, x, x_size): + """ + x: B, H*W, C + """ + x = blc_to_bchw(x, x_size) + x = bchw_to_blc(self.pooling(x)) + x = blc_to_bhwc(self.reduction(x), [s // self.down_factor for s in x_size]) + return x + + +class AnchorProjection(nn.Module): + def __init__(self, dim, proj_type, one_stage, anchor_window_down_factor, args): + super(AnchorProjection, self).__init__() + self.proj_type = proj_type + self.body = nn.ModuleList([]) + if one_stage: + if proj_type == "patchmerging": + m = PatchMerging(dim, dim // 2) + elif proj_type == "conv2d": + kernel_size = anchor_window_down_factor + 1 + stride = anchor_window_down_factor + padding = kernel_size // 2 + m = nn.Conv2d(dim, dim // 2, kernel_size, stride, padding) + elif proj_type == "separable_conv": + kernel_size = anchor_window_down_factor + 1 + stride = anchor_window_down_factor + m = SeparableConv(dim, dim // 2, kernel_size, stride, True, args) + elif proj_type.find("pool") >= 0: + m = AnchorLinear( + dim, dim // 2, anchor_window_down_factor, proj_type, True + ) + self.body.append(m) + else: + for i in range(int(math.log2(anchor_window_down_factor))): + cin = dim if i == 0 else dim // 2 + if proj_type == "patchmerging": + m = PatchMerging(cin, dim // 2) + elif proj_type == "conv2d": + m = nn.Conv2d(cin, dim // 2, 3, 2, 1) + elif proj_type == "separable_conv": + m = SeparableConv(cin, dim // 2, 3, 2, True, args) + self.body.append(m) + + def forward(self, x, x_size): + if self.proj_type.find("conv") >= 0: + x = blc_to_bchw(x, x_size) + for m in self.body: + x = m(x) + x = bchw_to_bhwc(x) + elif self.proj_type.find("pool") >= 0: + for m in self.body: + x = m(x, x_size) + else: + for i, m in enumerate(self.body): + x = m(x, [s // 2**i for s in x_size]) + x = blc_to_bhwc(x, [s // 2 ** (i + 1) for s in x_size]) + return x + + +class MixedAttention(nn.Module): + r"""Mixed window attention and stripe attention + Args: + dim (int): Number of input channels. + stripe_size (tuple[int]): The height and width of the stripe. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training. 
+ """ + + def __init__( + self, + dim, + input_resolution, + num_heads_w, + num_heads_s, + window_size, + window_shift, + stripe_size, + stripe_groups, + stripe_shift, + qkv_bias=True, + qkv_proj_type="linear", + anchor_proj_type="separable_conv", + anchor_one_stage=True, + anchor_window_down_factor=1, + attn_drop=0.0, + proj_drop=0.0, + pretrained_window_size=[0, 0], + pretrained_stripe_size=[0, 0], + args=None, + ): + + super(MixedAttention, self).__init__() + self.dim = dim + self.input_resolution = input_resolution + self.use_anchor = anchor_window_down_factor > 1 + self.args = args + # print(args) + self.qkv = QKVProjection(dim, qkv_bias, qkv_proj_type, args) + if self.use_anchor: + # anchor is only used for stripe attention + self.anchor = AnchorProjection( + dim, anchor_proj_type, anchor_one_stage, anchor_window_down_factor, args + ) + + self.window_attn = WindowAttention( + input_resolution, + window_size, + num_heads_w, + window_shift, + attn_drop, + pretrained_window_size, + args, + ) + + if self.args.double_window: + self.stripe_attn = WindowAttention( + input_resolution, + window_size, + num_heads_w, + window_shift, + attn_drop, + pretrained_window_size, + args, + ) + else: + if self.use_anchor: + self.stripe_attn = AnchorStripeAttention( + input_resolution, + stripe_size, + stripe_groups, + stripe_shift, + num_heads_s, + attn_drop, + pretrained_stripe_size, + anchor_window_down_factor, + args, + ) + else: + if self.args.stripe_square: + self.stripe_attn = StripeAttention( + input_resolution, + window_size, + [None, None], + window_shift, + num_heads_s, + attn_drop, + pretrained_stripe_size, + args, + ) + else: + self.stripe_attn = StripeAttention( + input_resolution, + stripe_size, + stripe_groups, + stripe_shift, + num_heads_s, + attn_drop, + pretrained_stripe_size, + args, + ) + if self.args.out_proj_type == "linear": + self.proj = nn.Linear(dim, dim) + else: + self.proj = nn.Conv2d(dim, dim, 3, 1, 1) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, x_size): + """ + Args: + x: input features with shape of (B, L, C) + stripe_size: use stripe_size to determine whether the relative positional bias table and index + need to be regenerated. + """ + B, L, C = x.shape + + # qkv projection + qkv = self.qkv(x, x_size) + qkv_window, qkv_stripe = torch.split(qkv, C * 3 // 2, dim=-1) + # anchor projection + if self.use_anchor: + anchor = self.anchor(x, x_size) + + # attention + x_window = self.window_attn(qkv_window, x_size) + if self.use_anchor: + x_stripe = self.stripe_attn(qkv_stripe, anchor, x_size) + else: + x_stripe = self.stripe_attn(qkv_stripe, x_size) + x = torch.cat([x_window, x_stripe], dim=-1) + + # output projection + if self.args.out_proj_type == "linear": + x = self.proj(x) + else: + x = blc_to_bchw(x, x_size) + x = bchw_to_blc(self.proj(x)) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}" + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class ChannelAttention(nn.Module): + """Channel attention used in RCAN. + Args: + num_feat (int): Channel number of intermediate features. 
+ reduction (int): Channel reduction factor. Default: 16. + """ + + def __init__(self, num_feat, reduction=16): + super(ChannelAttention, self).__init__() + self.attention = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(num_feat, num_feat // reduction, 1, padding=0), + nn.ReLU(inplace=True), + nn.Conv2d(num_feat // reduction, num_feat, 1, padding=0), + nn.Sigmoid(), + ) + + def forward(self, x): + y = self.attention(x) + return x * y + + +class CAB(nn.Module): + def __init__(self, num_feat, compress_ratio=4, reduction=18): + super(CAB, self).__init__() + + self.cab = nn.Sequential( + nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1), + nn.GELU(), + nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1), + ChannelAttention(num_feat, reduction), + ) + + def forward(self, x, x_size): + x = self.cab(blc_to_bchw(x, x_size).contiguous()) + return bchw_to_blc(x) + + +class MixAttnTransformerBlock(nn.Module): + r"""Mix attention transformer block with shared QKV projection and output projection for mixed attention modules. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pretrained_stripe_size (int): Window size in pre-training. + attn_type (str, optional): Attention type. Default: cwhv. 
+ c: residual blocks + w: window attention + h: horizontal stripe attention + v: vertical stripe attention + """ + + def __init__( + self, + dim, + input_resolution, + num_heads_w, + num_heads_s, + window_size=7, + window_shift=False, + stripe_size=[8, 8], + stripe_groups=[None, None], + stripe_shift=False, + stripe_type="H", + mlp_ratio=4.0, + qkv_bias=True, + qkv_proj_type="linear", + anchor_proj_type="separable_conv", + anchor_one_stage=True, + anchor_window_down_factor=1, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + pretrained_window_size=[0, 0], + pretrained_stripe_size=[0, 0], + res_scale=1.0, + args=None, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads_w = num_heads_w + self.num_heads_s = num_heads_s + self.window_size = window_size + self.window_shift = window_shift + self.stripe_shift = stripe_shift + self.stripe_type = stripe_type + self.args = args + if self.stripe_type == "W": + self.stripe_size = stripe_size[::-1] + self.stripe_groups = stripe_groups[::-1] + else: + self.stripe_size = stripe_size + self.stripe_groups = stripe_groups + self.mlp_ratio = mlp_ratio + self.res_scale = res_scale + + self.attn = MixedAttention( + dim, + input_resolution, + num_heads_w, + num_heads_s, + window_size, + window_shift, + self.stripe_size, + self.stripe_groups, + stripe_shift, + qkv_bias, + qkv_proj_type, + anchor_proj_type, + anchor_one_stage, + anchor_window_down_factor, + attn_drop, + drop, + pretrained_window_size, + pretrained_stripe_size, + args, + ) + self.norm1 = norm_layer(dim) + if self.args.local_connection: + self.conv = CAB(dim) + + # self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + # self.mlp = Mlp( + # in_features=dim, + # hidden_features=int(dim * mlp_ratio), + # act_layer=act_layer, + # drop=drop, + # ) + # self.norm2 = norm_layer(dim) + + def forward(self, x, x_size): + # Mixed attention + if self.args.local_connection: + x = ( + x + + self.res_scale * self.drop_path(self.norm1(self.attn(x, x_size))) + + self.conv(x, x_size) + ) + else: + x = x + self.res_scale * self.drop_path(self.norm1(self.attn(x, x_size))) + # FFN + x = x + self.res_scale * self.drop_path(self.norm2(self.mlp(x))) + + # return x + + def extra_repr(self) -> str: + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads=({self.num_heads_w}, {self.num_heads_s}), " + f"window_size={self.window_size}, window_shift={self.window_shift}, " + f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, self.stripe_type={self.stripe_type}, " + f"mlp_ratio={self.mlp_ratio}, res_scale={self.res_scale}" + ) + + +# def flops(self): +# flops = 0 +# H, W = self.input_resolution +# # norm1 +# flops += self.dim * H * W +# # W-MSA/SW-MSA +# nW = H * W / self.stripe_size[0] / self.stripe_size[1] +# flops += nW * self.attn.flops(self.stripe_size[0] * self.stripe_size[1]) +# # mlp +# flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio +# # norm2 +# flops += self.dim * H * W +# return flops diff --git a/architecture/grl_common/mixed_attn_block_efficient.py b/architecture/grl_common/mixed_attn_block_efficient.py new file mode 100644 index 0000000000000000000000000000000000000000..3cb78c23b79281423d6065f72f307b560543bd1c --- /dev/null +++ b/architecture/grl_common/mixed_attn_block_efficient.py @@ -0,0 +1,568 @@ +import math +from abc import ABC +from math import prod + +import torch +import torch.nn as nn +import 
torch.nn.functional as F +from timm.models.layers import DropPath + + +from architecture.grl_common.mixed_attn_block import ( + AnchorProjection, + CAB, + CPB_MLP, + QKVProjection, +) +from architecture.grl_common.ops import ( + window_partition, + window_reverse, +) +from architecture.grl_common.swin_v1_block import Mlp + + +class AffineTransform(nn.Module): + r"""Affine transformation of the attention map. + The window could be a square window or a stripe window. Supports attention between different window sizes + """ + + def __init__(self, num_heads): + super(AffineTransform, self).__init__() + logit_scale = torch.log(10 * torch.ones((num_heads, 1, 1))) + self.logit_scale = nn.Parameter(logit_scale, requires_grad=True) + + # mlp to generate continuous relative position bias + self.cpb_mlp = CPB_MLP(2, num_heads) + + def forward(self, attn, relative_coords_table, relative_position_index, mask): + B_, H, N1, N2 = attn.shape + # logit scale + attn = attn * torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp() + + bias_table = self.cpb_mlp(relative_coords_table) # 2*Wh-1, 2*Ww-1, num_heads + bias_table = bias_table.view(-1, H) + + bias = bias_table[relative_position_index.view(-1)] + bias = bias.view(N1, N2, -1).permute(2, 0, 1).contiguous() + # nH, Wh*Ww, Wh*Ww + bias = 16 * torch.sigmoid(bias) + attn = attn + bias.unsqueeze(0) + + # W-MSA/SW-MSA + # shift attention mask + if mask is not None: + nW = mask.shape[0] + mask = mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(B_ // nW, nW, H, N1, N2) + mask + attn = attn.view(-1, H, N1, N2) + + return attn + + +def _get_stripe_info(stripe_size_in, stripe_groups_in, stripe_shift, input_resolution): + stripe_size, shift_size = [], [] + for s, g, d in zip(stripe_size_in, stripe_groups_in, input_resolution): + if g is None: + stripe_size.append(s) + shift_size.append(s // 2 if stripe_shift else 0) + else: + stripe_size.append(d // g) + shift_size.append(0 if g == 1 else d // (g * 2)) + return stripe_size, shift_size + + +class Attention(ABC, nn.Module): + def __init__(self): + super(Attention, self).__init__() + + def attn(self, q, k, v, attn_transform, table, index, mask, reshape=True): + # q, k, v: # nW*B, H, wh*ww, dim + # cosine attention map + B_, _, H, head_dim = q.shape + if self.euclidean_dist: + # print("use euclidean distance") + attn = torch.norm(q.unsqueeze(-2) - k.unsqueeze(-3), dim=-1) + else: + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) + attn = attn_transform(attn, table, index, mask) + # attention + attn = self.softmax(attn) + attn = self.attn_drop(attn) + x = attn @ v # B_, H, N1, head_dim + if reshape: + x = x.transpose(1, 2).reshape(B_, -1, H * head_dim) + # B_, N, C + return x + + +class WindowAttention(Attention): + r"""Window attention. QKV is the input to the forward method. + Args: + num_heads (int): Number of attention heads. + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 
+ """ + + def __init__( + self, + input_resolution, + window_size, + num_heads, + window_shift=False, + attn_drop=0.0, + pretrained_window_size=[0, 0], + args=None, + ): + + super(WindowAttention, self).__init__() + self.input_resolution = input_resolution + self.window_size = window_size + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + self.shift_size = window_size[0] // 2 if window_shift else 0 + self.euclidean_dist = args.euclidean_dist + + self.attn_transform = AffineTransform(num_heads) + self.attn_drop = nn.Dropout(attn_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, qkv, x_size, table, index, mask): + """ + Args: + qkv: input QKV features with shape of (B, L, 3C) + x_size: use x_size to determine whether the relative positional bias table and index + need to be regenerated. + """ + H, W = x_size + B, L, C = qkv.shape + qkv = qkv.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + qkv = torch.roll( + qkv, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) + + # partition windows + qkv = window_partition(qkv, self.window_size) # nW*B, wh, ww, C + qkv = qkv.view(-1, prod(self.window_size), C) # nW*B, wh*ww, C + + B_, N, _ = qkv.shape + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # nW*B, H, wh*ww, dim + + # attention + x = self.attn(q, k, v, self.attn_transform, table, index, mask) + + # merge windows + x = x.view(-1, *self.window_size, C // 3) + x = window_reverse(x, self.window_size, x_size) # B, H, W, C/3 + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + x = x.view(B, L, C // 3) + + return x + + def extra_repr(self) -> str: + return ( + f"window_size={self.window_size}, shift_size={self.shift_size}, " + f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}" + ) + + def flops(self, N): + pass + + +class AnchorStripeAttention(Attention): + r"""Stripe attention + Args: + stripe_size (tuple[int]): The height and width of the stripe. + num_heads (int): Number of attention heads. + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training. + """ + + def __init__( + self, + input_resolution, + stripe_size, + stripe_groups, + stripe_shift, + num_heads, + attn_drop=0.0, + pretrained_stripe_size=[0, 0], + anchor_window_down_factor=1, + args=None, + ): + + super(AnchorStripeAttention, self).__init__() + self.input_resolution = input_resolution + self.stripe_size = stripe_size # Wh, Ww + self.stripe_groups = stripe_groups + self.stripe_shift = stripe_shift + self.num_heads = num_heads + self.pretrained_stripe_size = pretrained_stripe_size + self.anchor_window_down_factor = anchor_window_down_factor + self.euclidean_dist = args.euclidean_dist + + self.attn_transform1 = AffineTransform(num_heads) + self.attn_transform2 = AffineTransform(num_heads) + + self.attn_drop = nn.Dropout(attn_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward( + self, qkv, anchor, x_size, table, index_a2w, index_w2a, mask_a2w, mask_w2a + ): + """ + Args: + qkv: input features with shape of (B, L, C) + anchor: + x_size: use stripe_size to determine whether the relative positional bias table and index + need to be regenerated. 
+ """ + H, W = x_size + B, L, C = qkv.shape + qkv = qkv.view(B, H, W, C) + + stripe_size, shift_size = _get_stripe_info( + self.stripe_size, self.stripe_groups, self.stripe_shift, x_size + ) + anchor_stripe_size = [s // self.anchor_window_down_factor for s in stripe_size] + anchor_shift_size = [s // self.anchor_window_down_factor for s in shift_size] + # cyclic shift + if self.stripe_shift: + qkv = torch.roll(qkv, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) + anchor = torch.roll( + anchor, + shifts=(-anchor_shift_size[0], -anchor_shift_size[1]), + dims=(1, 2), + ) + + # partition windows + qkv = window_partition(qkv, stripe_size) # nW*B, wh, ww, C + qkv = qkv.view(-1, prod(stripe_size), C) # nW*B, wh*ww, C + anchor = window_partition(anchor, anchor_stripe_size) + anchor = anchor.view(-1, prod(anchor_stripe_size), C // 3) + + B_, N1, _ = qkv.shape + N2 = anchor.shape[1] + qkv = qkv.reshape(B_, N1, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + anchor = anchor.reshape(B_, N2, self.num_heads, -1).permute(0, 2, 1, 3) + + # attention + x = self.attn( + anchor, k, v, self.attn_transform1, table, index_a2w, mask_a2w, False + ) + x = self.attn(q, anchor, x, self.attn_transform2, table, index_w2a, mask_w2a) + + # merge windows + x = x.view(B_, *stripe_size, C // 3) + x = window_reverse(x, stripe_size, x_size) # B H' W' C + + # reverse the shift + if self.stripe_shift: + x = torch.roll(x, shifts=shift_size, dims=(1, 2)) + + x = x.view(B, H * W, C // 3) + return x + + def extra_repr(self) -> str: + return ( + f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, " + f"pretrained_stripe_size={self.pretrained_stripe_size}, num_heads={self.num_heads}, anchor_window_down_factor={self.anchor_window_down_factor}" + ) + + def flops(self, N): + pass + + +class MixedAttention(nn.Module): + r"""Mixed window attention and stripe attention + Args: + dim (int): Number of input channels. + stripe_size (tuple[int]): The height and width of the stripe. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training. 
+ """ + + def __init__( + self, + dim, + input_resolution, + num_heads_w, + num_heads_s, + window_size, + window_shift, + stripe_size, + stripe_groups, + stripe_shift, + qkv_bias=True, + qkv_proj_type="linear", + anchor_proj_type="separable_conv", + anchor_one_stage=True, + anchor_window_down_factor=1, + attn_drop=0.0, + proj_drop=0.0, + pretrained_window_size=[0, 0], + pretrained_stripe_size=[0, 0], + args=None, + ): + + super(MixedAttention, self).__init__() + self.dim = dim + self.input_resolution = input_resolution + self.args = args + # print(args) + self.qkv = QKVProjection(dim, qkv_bias, qkv_proj_type, args) + # anchor is only used for stripe attention + self.anchor = AnchorProjection( + dim, anchor_proj_type, anchor_one_stage, anchor_window_down_factor, args + ) + + self.window_attn = WindowAttention( + input_resolution, + window_size, + num_heads_w, + window_shift, + attn_drop, + pretrained_window_size, + args, + ) + self.stripe_attn = AnchorStripeAttention( + input_resolution, + stripe_size, + stripe_groups, + stripe_shift, + num_heads_s, + attn_drop, + pretrained_stripe_size, + anchor_window_down_factor, + args, + ) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, x_size, table_index_mask): + """ + Args: + x: input features with shape of (B, L, C) + stripe_size: use stripe_size to determine whether the relative positional bias table and index + need to be regenerated. + """ + B, L, C = x.shape + + # qkv projection + qkv = self.qkv(x, x_size) + qkv_window, qkv_stripe = torch.split(qkv, C * 3 // 2, dim=-1) + # anchor projection + anchor = self.anchor(x, x_size) + + # attention + x_window = self.window_attn( + qkv_window, x_size, *self._get_table_index_mask(table_index_mask, True) + ) + x_stripe = self.stripe_attn( + qkv_stripe, + anchor, + x_size, + *self._get_table_index_mask(table_index_mask, False), + ) + x = torch.cat([x_window, x_stripe], dim=-1) + + # output projection + x = self.proj(x) + x = self.proj_drop(x) + return x + + def _get_table_index_mask(self, table_index_mask, window_attn=True): + if window_attn: + return ( + table_index_mask["table_w"], + table_index_mask["index_w"], + table_index_mask["mask_w"], + ) + else: + return ( + table_index_mask["table_s"], + table_index_mask["index_a2w"], + table_index_mask["index_w2a"], + table_index_mask["mask_a2w"], + table_index_mask["mask_w2a"], + ) + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}" + + def flops(self, N): + pass + + +class EfficientMixAttnTransformerBlock(nn.Module): + r"""Mix attention transformer block with shared QKV projection and output projection for mixed attention modules. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pretrained_stripe_size (int): Window size in pre-training. + attn_type (str, optional): Attention type. Default: cwhv. 
+ c: residual blocks + w: window attention + h: horizontal stripe attention + v: vertical stripe attention + """ + + def __init__( + self, + dim, + input_resolution, + num_heads_w, + num_heads_s, + window_size=7, + window_shift=False, + stripe_size=[8, 8], + stripe_groups=[None, None], + stripe_shift=False, + stripe_type="H", + mlp_ratio=4.0, + qkv_bias=True, + qkv_proj_type="linear", + anchor_proj_type="separable_conv", + anchor_one_stage=True, + anchor_window_down_factor=1, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + pretrained_window_size=[0, 0], + pretrained_stripe_size=[0, 0], + res_scale=1.0, + args=None, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads_w = num_heads_w + self.num_heads_s = num_heads_s + self.window_size = window_size + self.window_shift = window_shift + self.stripe_shift = stripe_shift + self.stripe_type = stripe_type + self.args = args + if self.stripe_type == "W": + self.stripe_size = stripe_size[::-1] + self.stripe_groups = stripe_groups[::-1] + else: + self.stripe_size = stripe_size + self.stripe_groups = stripe_groups + self.mlp_ratio = mlp_ratio + self.res_scale = res_scale + + self.attn = MixedAttention( + dim, + input_resolution, + num_heads_w, + num_heads_s, + window_size, + window_shift, + self.stripe_size, + self.stripe_groups, + stripe_shift, + qkv_bias, + qkv_proj_type, + anchor_proj_type, + anchor_one_stage, + anchor_window_down_factor, + attn_drop, + drop, + pretrained_window_size, + pretrained_stripe_size, + args, + ) + self.norm1 = norm_layer(dim) + if self.args.local_connection: + self.conv = CAB(dim) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) + self.norm2 = norm_layer(dim) + + def _get_table_index_mask(self, all_table_index_mask): + table_index_mask = { + "table_w": all_table_index_mask["table_w"], + "index_w": all_table_index_mask["index_w"], + } + if self.stripe_type == "W": + table_index_mask["table_s"] = all_table_index_mask["table_sv"] + table_index_mask["index_a2w"] = all_table_index_mask["index_sv_a2w"] + table_index_mask["index_w2a"] = all_table_index_mask["index_sv_w2a"] + else: + table_index_mask["table_s"] = all_table_index_mask["table_sh"] + table_index_mask["index_a2w"] = all_table_index_mask["index_sh_a2w"] + table_index_mask["index_w2a"] = all_table_index_mask["index_sh_w2a"] + if self.window_shift: + table_index_mask["mask_w"] = all_table_index_mask["mask_w"] + else: + table_index_mask["mask_w"] = None + if self.stripe_shift: + if self.stripe_type == "W": + table_index_mask["mask_a2w"] = all_table_index_mask["mask_sv_a2w"] + table_index_mask["mask_w2a"] = all_table_index_mask["mask_sv_w2a"] + else: + table_index_mask["mask_a2w"] = all_table_index_mask["mask_sh_a2w"] + table_index_mask["mask_w2a"] = all_table_index_mask["mask_sh_w2a"] + else: + table_index_mask["mask_a2w"] = None + table_index_mask["mask_w2a"] = None + return table_index_mask + + def forward(self, x, x_size, all_table_index_mask): + # Mixed attention + table_index_mask = self._get_table_index_mask(all_table_index_mask) + if self.args.local_connection: + x = ( + x + + self.res_scale + * self.drop_path(self.norm1(self.attn(x, x_size, table_index_mask))) + + self.conv(x, x_size) + ) + else: + x = x + self.res_scale * self.drop_path( + self.norm1(self.attn(x, x_size, table_index_mask)) + ) + # FFN + x = x + 
self.res_scale * self.drop_path(self.norm2(self.mlp(x))) + + return x + + def extra_repr(self) -> str: + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads=({self.num_heads_w}, {self.num_heads_s}), " + f"window_size={self.window_size}, window_shift={self.window_shift}, " + f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, self.stripe_type={self.stripe_type}, " + f"mlp_ratio={self.mlp_ratio}, res_scale={self.res_scale}" + ) + + def flops(self): + pass diff --git a/architecture/grl_common/ops.py b/architecture/grl_common/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..37406bd8795b61781eaca1d4a854547eff1725a0 --- /dev/null +++ b/architecture/grl_common/ops.py @@ -0,0 +1,551 @@ +from math import prod +from typing import Tuple + +import numpy as np +import torch +from timm.models.layers import to_2tuple + + +def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor: + """Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C).""" + return x.permute(0, 2, 3, 1) + + +def bhwc_to_bchw(x: torch.Tensor) -> torch.Tensor: + """Permutes a tensor from the shape (B, H, W, C) to (B, C, H, W).""" + return x.permute(0, 3, 1, 2) + + +def bchw_to_blc(x: torch.Tensor) -> torch.Tensor: + """Rearrange a tensor from the shape (B, C, H, W) to (B, L, C).""" + return x.flatten(2).transpose(1, 2) + + +def blc_to_bchw(x: torch.Tensor, x_size: Tuple) -> torch.Tensor: + """Rearrange a tensor from the shape (B, L, C) to (B, C, H, W).""" + B, L, C = x.shape + return x.transpose(1, 2).view(B, C, *x_size) + + +def blc_to_bhwc(x: torch.Tensor, x_size: Tuple) -> torch.Tensor: + """Rearrange a tensor from the shape (B, L, C) to (B, H, W, C).""" + B, L, C = x.shape + return x.view(B, *x_size, C) + + +def window_partition(x, window_size: Tuple[int, int]): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view( + B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C + ) + windows = ( + x.permute(0, 1, 3, 2, 4, 5) + .contiguous() + .view(-1, window_size[0], window_size[1], C) + ) + return windows + + +def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]): + """ + Args: + windows: (num_windows * B, window_size[0], window_size[1], C) + window_size (Tuple[int, int]): Window size + img_size (Tuple[int, int]): Image size + + Returns: + x: (B, H, W, C) + """ + H, W = img_size + B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) + x = windows.view( + B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +def _fill_window(input_resolution, window_size, shift_size=None): + if shift_size is None: + shift_size = [s // 2 for s in window_size] + + img_mask = torch.zeros((1, *input_resolution, 1)) # 1 H W 1 + h_slices = ( + slice(0, -window_size[0]), + slice(-window_size[0], -shift_size[0]), + slice(-shift_size[0], None), + ) + w_slices = ( + slice(0, -window_size[1]), + slice(-window_size[1], -shift_size[1]), + slice(-shift_size[1], None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, window_size) + # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, prod(window_size)) + return mask_windows + + 
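+
+# Informal round-trip example for window_partition / window_reverse
+# (the sizes below are illustrative only):
+#
+#     x = torch.randn(2, 8, 8, 32)              # (B, H, W, C)
+#     wins = window_partition(x, (4, 4))        # (nW * B, 4, 4, 32) = (8, 4, 4, 32)
+#     y = window_reverse(wins, (4, 4), (8, 8))  # (2, 8, 8, 32)
+#     assert torch.equal(x, y)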
+##################################### +# Different versions of the functions +# 1) Swin Transformer, SwinIR, Square window attention in GRL; +# 2) Early development of the decomposition-based efficient attention mechanism (efficient_win_attn.py); +# 3) GRL. Window-anchor attention mechanism. +# 1) & 3) are still useful +##################################### + + +def calculate_mask(input_resolution, window_size, shift_size): + """ + Use case: 1) + """ + # calculate attention mask for SW-MSA + if isinstance(shift_size, int): + shift_size = to_2tuple(shift_size) + mask_windows = _fill_window(input_resolution, window_size, shift_size) + + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) # nW, window_size**2, window_size**2 + + return attn_mask + + +def calculate_mask_all( + input_resolution, + window_size, + shift_size, + anchor_window_down_factor=1, + window_to_anchor=True, +): + """ + Use case: 3) + """ + # calculate attention mask for SW-MSA + anchor_resolution = [s // anchor_window_down_factor for s in input_resolution] + aws = [s // anchor_window_down_factor for s in window_size] + anchor_shift = [s // anchor_window_down_factor for s in shift_size] + + # mask of window1: nW, Wh**Ww + mask_windows = _fill_window(input_resolution, window_size, shift_size) + # mask of window2: nW, AWh*AWw + mask_anchor = _fill_window(anchor_resolution, aws, anchor_shift) + + if window_to_anchor: + attn_mask = mask_windows.unsqueeze(2) - mask_anchor.unsqueeze(1) + else: + attn_mask = mask_anchor.unsqueeze(2) - mask_windows.unsqueeze(1) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) # nW, Wh**Ww, AWh*AWw + + return attn_mask + + +def calculate_win_mask( + input_resolution1, input_resolution2, window_size1, window_size2 +): + """ + Use case: 2) + """ + # calculate attention mask for SW-MSA + + # mask of window1: nW, Wh**Ww + mask_windows1 = _fill_window(input_resolution1, window_size1) + # mask of window2: nW, AWh*AWw + mask_windows2 = _fill_window(input_resolution2, window_size2) + + attn_mask = mask_windows1.unsqueeze(2) - mask_windows2.unsqueeze(1) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) # nW, Wh**Ww, AWh*AWw + + return attn_mask + + +def _get_meshgrid_coords(start_coords, end_coords): + coord_h = torch.arange(start_coords[0], end_coords[0]) + coord_w = torch.arange(start_coords[1], end_coords[1]) + coords = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")) # 2, Wh, Ww + coords = torch.flatten(coords, 1) # 2, Wh*Ww + return coords + + +def get_relative_coords_table( + window_size, pretrained_window_size=[0, 0], anchor_window_down_factor=1 +): + """ + Use case: 1) + """ + # get relative_coords_table + ws = window_size + aws = [w // anchor_window_down_factor for w in window_size] + pws = pretrained_window_size + paws = [w // anchor_window_down_factor for w in pretrained_window_size] + + ts = [(w1 + w2) // 2 for w1, w2 in zip(ws, aws)] + pts = [(w1 + w2) // 2 for w1, w2 in zip(pws, paws)] + + # TODO: pretrained window size and pretrained anchor window size is only used here. + # TODO: Investigate whether it is really important to use this setting when finetuning large window size + # TODO: based on pretrained weights with small window size. 
+ + coord_h = torch.arange(-(ts[0] - 1), ts[0], dtype=torch.float32) + coord_w = torch.arange(-(ts[1] - 1), ts[1], dtype=torch.float32) + table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute( + 1, 2, 0 + ) + table = table.contiguous().unsqueeze(0) # 1, Wh+AWh-1, Ww+AWw-1, 2 + if pts[0] > 0: + table[:, :, :, 0] /= pts[0] - 1 + table[:, :, :, 1] /= pts[1] - 1 + else: + table[:, :, :, 0] /= ts[0] - 1 + table[:, :, :, 1] /= ts[1] - 1 + table *= 8 # normalize to -8, 8 + table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8) + return table + + +def get_relative_coords_table_all( + window_size, pretrained_window_size=[0, 0], anchor_window_down_factor=1 +): + """ + Use case: 3) + + Support all window shapes. + Args: + window_size: + pretrained_window_size: + anchor_window_down_factor: + + Returns: + + """ + # get relative_coords_table + ws = window_size + aws = [w // anchor_window_down_factor for w in window_size] + pws = pretrained_window_size + paws = [w // anchor_window_down_factor for w in pretrained_window_size] + + # positive table size: (Ww - 1) - (Ww - AWw) // 2 + ts_p = [w1 - 1 - (w1 - w2) // 2 for w1, w2 in zip(ws, aws)] + # negative table size: -(AWw - 1) - (Ww - AWw) // 2 + ts_n = [-(w2 - 1) - (w1 - w2) // 2 for w1, w2 in zip(ws, aws)] + pts = [w1 - 1 - (w1 - w2) // 2 for w1, w2 in zip(pws, paws)] + + # TODO: pretrained window size and pretrained anchor window size is only used here. + # TODO: Investigate whether it is really important to use this setting when finetuning large window size + # TODO: based on pretrained weights with small window size. + + coord_h = torch.arange(ts_n[0], ts_p[0] + 1, dtype=torch.float32) + coord_w = torch.arange(ts_n[1], ts_p[1] + 1, dtype=torch.float32) + table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute( + 1, 2, 0 + ) + table = table.contiguous().unsqueeze(0) # 1, Wh+AWh-1, Ww+AWw-1, 2 + if pts[0] > 0: + table[:, :, :, 0] /= pts[0] + table[:, :, :, 1] /= pts[1] + else: + table[:, :, :, 0] /= ts_p[0] + table[:, :, :, 1] /= ts_p[1] + table *= 8 # normalize to -8, 8 + table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8) + # 1, Wh+AWh-1, Ww+AWw-1, 2 + return table + + +def coords_diff(coords1, coords2, max_diff): + # The coordinates starts from (-start_coord[0], -start_coord[1]) + coords = coords1[:, :, None] - coords2[:, None, :] # 2, Wh*Ww, AWh*AWw + coords = coords.permute(1, 2, 0).contiguous() # Wh*Ww, AWh*AWw, 2 + coords[:, :, 0] += max_diff[0] - 1 # shift to start from 0 + coords[:, :, 1] += max_diff[1] - 1 + coords[:, :, 0] *= 2 * max_diff[1] - 1 + idx = coords.sum(-1) # Wh*Ww, AWh*AWw + return idx + + +def get_relative_position_index( + window_size, anchor_window_down_factor=1, window_to_anchor=True +): + """ + Use case: 1) + """ + # get pair-wise relative position index for each token inside the window + ws = window_size + aws = [w // anchor_window_down_factor for w in window_size] + coords_anchor_end = [(w1 + w2) // 2 for w1, w2 in zip(ws, aws)] + coords_anchor_start = [(w1 - w2) // 2 for w1, w2 in zip(ws, aws)] + + coords = _get_meshgrid_coords((0, 0), window_size) # 2, Wh*Ww + coords_anchor = _get_meshgrid_coords(coords_anchor_start, coords_anchor_end) + # 2, AWh*AWw + + if window_to_anchor: + idx = coords_diff(coords, coords_anchor, max_diff=coords_anchor_end) + else: + idx = coords_diff(coords_anchor, coords, max_diff=coords_anchor_end) + return idx # Wh*Ww, AWh*AWw or AWh*AWw, Wh*Ww + + +def coords_diff_odd(coords1, coords2, start_coord, 
max_diff): + # The coordinates starts from (-start_coord[0], -start_coord[1]) + coords = coords1[:, :, None] - coords2[:, None, :] # 2, Wh*Ww, AWh*AWw + coords = coords.permute(1, 2, 0).contiguous() # Wh*Ww, AWh*AWw, 2 + coords[:, :, 0] += start_coord[0] # shift to start from 0 + coords[:, :, 1] += start_coord[1] + coords[:, :, 0] *= max_diff + idx = coords.sum(-1) # Wh*Ww, AWh*AWw + return idx + + +def get_relative_position_index_all( + window_size, anchor_window_down_factor=1, window_to_anchor=True +): + """ + Use case: 3) + Support all window shapes: + square window - square window + rectangular window - rectangular window + window - anchor + anchor - window + [8, 8] - [8, 8] + [4, 86] - [2, 43] + """ + # get pair-wise relative position index for each token inside the window + ws = window_size + aws = [w // anchor_window_down_factor for w in window_size] + coords_anchor_start = [(w1 - w2) // 2 for w1, w2 in zip(ws, aws)] + coords_anchor_end = [s + w2 for s, w2 in zip(coords_anchor_start, aws)] + + coords = _get_meshgrid_coords((0, 0), window_size) # 2, Wh*Ww + coords_anchor = _get_meshgrid_coords(coords_anchor_start, coords_anchor_end) + # 2, AWh*AWw + + max_horizontal_diff = aws[1] + ws[1] - 1 + if window_to_anchor: + offset = [w2 + s - 1 for s, w2 in zip(coords_anchor_start, aws)] + idx = coords_diff_odd(coords, coords_anchor, offset, max_horizontal_diff) + else: + offset = [w1 - s - 1 for s, w1 in zip(coords_anchor_start, ws)] + idx = coords_diff_odd(coords_anchor, coords, offset, max_horizontal_diff) + return idx # Wh*Ww, AWh*AWw or AWh*AWw, Wh*Ww + + +def get_relative_position_index_simple( + window_size, anchor_window_down_factor=1, window_to_anchor=True +): + """ + Use case: 3) + This is a simplified version of get_relative_position_index_all + The start coordinate of anchor window is also (0, 0) + get pair-wise relative position index for each token inside the window + """ + ws = window_size + aws = [w // anchor_window_down_factor for w in window_size] + + coords = _get_meshgrid_coords((0, 0), window_size) # 2, Wh*Ww + coords_anchor = _get_meshgrid_coords((0, 0), aws) + # 2, AWh*AWw + + max_horizontal_diff = aws[1] + ws[1] - 1 + if window_to_anchor: + offset = [w2 - 1 for w2 in aws] + idx = coords_diff_odd(coords, coords_anchor, offset, max_horizontal_diff) + else: + offset = [w1 - 1 for w1 in ws] + idx = coords_diff_odd(coords_anchor, coords, offset, max_horizontal_diff) + return idx # Wh*Ww, AWh*AWw or AWh*AWw, Wh*Ww + + +# def get_relative_position_index(window_size): +# # This is a very early version +# # get pair-wise relative position index for each token inside the window +# coords = _get_meshgrid_coords(start_coords=(0, 0), end_coords=window_size) + +# coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww +# coords = coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 +# coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 +# coords[:, :, 1] += window_size[1] - 1 +# coords[:, :, 0] *= 2 * window_size[1] - 1 +# idx = coords.sum(-1) # Wh*Ww, Wh*Ww +# return idx + + +def get_relative_win_position_index(window_size, anchor_window_size): + """ + Use case: 2) + """ + # get pair-wise relative position index for each token inside the window + ws = window_size + aws = anchor_window_size + coords_anchor_end = [(w1 + w2) // 2 for w1, w2 in zip(ws, aws)] + coords_anchor_start = [(w1 - w2) // 2 for w1, w2 in zip(ws, aws)] + + coords = _get_meshgrid_coords((0, 0), window_size) # 2, Wh*Ww + coords_anchor = _get_meshgrid_coords(coords_anchor_start, 
coords_anchor_end) + # 2, AWh*AWw + coords = coords[:, :, None] - coords_anchor[:, None, :] # 2, Wh*Ww, AWh*AWw + coords = coords.permute(1, 2, 0).contiguous() # Wh*Ww, AWh*AWw, 2 + coords[:, :, 0] += coords_anchor_end[0] - 1 # shift to start from 0 + coords[:, :, 1] += coords_anchor_end[1] - 1 + coords[:, :, 0] *= 2 * coords_anchor_end[1] - 1 + idx = coords.sum(-1) # Wh*Ww, AWh*AWw + return idx + + +# def get_relative_coords_table(window_size, pretrained_window_size): +# # This is a very early version +# # get relative_coords_table +# ws = window_size +# pws = pretrained_window_size +# coord_h = torch.arange(-(ws[0] - 1), ws[0], dtype=torch.float32) +# coord_w = torch.arange(-(ws[1] - 1), ws[1], dtype=torch.float32) +# table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing='ij')).permute(1, 2, 0) +# table = table.contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 +# if pws[0] > 0: +# table[:, :, :, 0] /= pws[0] - 1 +# table[:, :, :, 1] /= pws[1] - 1 +# else: +# table[:, :, :, 0] /= ws[0] - 1 +# table[:, :, :, 1] /= ws[1] - 1 +# table *= 8 # normalize to -8, 8 +# table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8) +# return table + + +def get_relative_win_coords_table( + window_size, + anchor_window_size, + pretrained_window_size=[0, 0], + pretrained_anchor_window_size=[0, 0], +): + """ + Use case: 2) + """ + # get relative_coords_table + ws = window_size + aws = anchor_window_size + pws = pretrained_window_size + paws = pretrained_anchor_window_size + + # TODO: pretrained window size and pretrained anchor window size is only used here. + # TODO: Investigate whether it is really important to use this setting when finetuning large window size + # TODO: based on pretrained weights with small window size. + + table_size = [(wsi + awsi) // 2 for wsi, awsi in zip(ws, aws)] + table_size_pretrained = [(pwsi + pawsi) // 2 for pwsi, pawsi in zip(pws, paws)] + coord_h = torch.arange(-(table_size[0] - 1), table_size[0], dtype=torch.float32) + coord_w = torch.arange(-(table_size[1] - 1), table_size[1], dtype=torch.float32) + table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute( + 1, 2, 0 + ) + table = table.contiguous().unsqueeze(0) # 1, Wh+AWh-1, Ww+AWw-1, 2 + if table_size_pretrained[0] > 0: + table[:, :, :, 0] /= table_size_pretrained[0] - 1 + table[:, :, :, 1] /= table_size_pretrained[1] - 1 + else: + table[:, :, :, 0] /= table_size[0] - 1 + table[:, :, :, 1] /= table_size[1] - 1 + table *= 8 # normalize to -8, 8 + table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8) + return table + + +if __name__ == "__main__": + table = get_relative_coords_table_all((4, 86), anchor_window_down_factor=2) + table = table.view(-1, 2) + index1 = get_relative_position_index_all((4, 86), 2, False) + index2 = get_relative_position_index_simple((4, 86), 2, False) + print(index2) + index3 = get_relative_position_index_all((4, 86), 2) + index4 = get_relative_position_index_simple((4, 86), 2) + print(index4) + print( + table.shape, + index2.shape, + index2.max(), + index2.min(), + index4.shape, + index4.max(), + index4.min(), + torch.allclose(index1, index2), + torch.allclose(index3, index4), + ) + + table = get_relative_coords_table_all((4, 86), anchor_window_down_factor=1) + table = table.view(-1, 2) + index1 = get_relative_position_index_all((4, 86), 1, False) + index2 = get_relative_position_index_simple((4, 86), 1, False) + # print(index1) + index3 = get_relative_position_index_all((4, 86), 1) + index4 = 
get_relative_position_index_simple((4, 86), 1) + # print(index2) + print( + table.shape, + index2.shape, + index2.max(), + index2.min(), + index4.shape, + index4.max(), + index4.min(), + torch.allclose(index1, index2), + torch.allclose(index3, index4), + ) + + table = get_relative_coords_table_all((8, 8), anchor_window_down_factor=2) + table = table.view(-1, 2) + index1 = get_relative_position_index_all((8, 8), 2, False) + index2 = get_relative_position_index_simple((8, 8), 2, False) + # print(index1) + index3 = get_relative_position_index_all((8, 8), 2) + index4 = get_relative_position_index_simple((8, 8), 2) + # print(index2) + print( + table.shape, + index2.shape, + index2.max(), + index2.min(), + index4.shape, + index4.max(), + index4.min(), + torch.allclose(index1, index2), + torch.allclose(index3, index4), + ) + + table = get_relative_coords_table_all((8, 8), anchor_window_down_factor=1) + table = table.view(-1, 2) + index1 = get_relative_position_index_all((8, 8), 1, False) + index2 = get_relative_position_index_simple((8, 8), 1, False) + # print(index1) + index3 = get_relative_position_index_all((8, 8), 1) + index4 = get_relative_position_index_simple((8, 8), 1) + # print(index2) + print( + table.shape, + index2.shape, + index2.max(), + index2.min(), + index4.shape, + index4.max(), + index4.min(), + torch.allclose(index1, index2), + torch.allclose(index3, index4), + ) diff --git a/architecture/grl_common/resblock.py b/architecture/grl_common/resblock.py new file mode 100644 index 0000000000000000000000000000000000000000..af1999c8d07a99d6aae1fc33bb4fb98670acbf4f --- /dev/null +++ b/architecture/grl_common/resblock.py @@ -0,0 +1,61 @@ +import torch.nn as nn + + +class ResBlock(nn.Module): + """Residual block without BN. + + It has a style of: + + :: + + ---Conv-ReLU-Conv-+- + |________________| + + Args: + num_feats (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Used to scale the residual before addition. + Default: 1.0. + """ + + def __init__(self, num_feats=64, res_scale=1.0, bias=True, shortcut=True): + super().__init__() + self.res_scale = res_scale + self.shortcut = shortcut + self.conv1 = nn.Conv2d(num_feats, num_feats, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(num_feats, num_feats, 3, 1, 1, bias=bias) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. 
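A minimal sanity check for the anchored relative-position helpers above, in the spirit of the `__main__` block they ship with. The module path `architecture.grl_common.ops` is an assumption inferred from the imports used later in `swin_v2_block.py`.

```python
# Sketch: exercise the relative-position helpers on an 8x8 window with a
# 2x anchor down-factor (so the anchor window is 4x4). Module path assumed.
from architecture.grl_common.ops import (
    get_relative_coords_table_all,
    get_relative_position_index_simple,
)

window_size = (8, 8)
down_factor = 2  # anchor window becomes 4x4

table = get_relative_coords_table_all(window_size, anchor_window_down_factor=down_factor)
index = get_relative_position_index_simple(window_size, down_factor)

print(table.shape)  # torch.Size([1, 11, 11, 2]) -> 1, Wh+AWh-1, Ww+AWw-1, 2
print(index.shape)  # torch.Size([64, 16])       -> Wh*Ww, AWh*AWw

# Every index must address a valid entry of the flattened coords table.
assert 0 <= index.min() and index.max() < table.shape[1] * table.shape[2]
```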
+ """ + + identity = x + out = self.conv2(self.relu(self.conv1(x))) + if self.shortcut: + return identity + out * self.res_scale + else: + return out * self.res_scale + + +class ResBlockWrapper(ResBlock): + "Used for transformers" + + def __init__(self, num_feats, bias=True, shortcut=True): + super(ResBlockWrapper, self).__init__( + num_feats=num_feats, bias=bias, shortcut=shortcut + ) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + x = x.view(B, H, W, C).permute(0, 3, 1, 2) + x = super(ResBlockWrapper, self).forward(x) + x = x.flatten(2).permute(0, 2, 1) + return x diff --git a/architecture/grl_common/swin_v1_block.py b/architecture/grl_common/swin_v1_block.py new file mode 100644 index 0000000000000000000000000000000000000000..26ed1e291de57f29cbeea54af3e8af9b119b7476 --- /dev/null +++ b/architecture/grl_common/swin_v1_block.py @@ -0,0 +1,602 @@ +from math import prod + +import torch +import torch.nn as nn +from architecture.grl_common.ops import ( + bchw_to_blc, + blc_to_bchw, + calculate_mask, + window_partition, + window_reverse, +) +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class WindowAttentionV1(nn.Module): + r"""Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + use_pe=True, + ): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.use_pe = use_pe + + if self.use_pe: + # define a parameter table of relative position bias + ws = self.window_size + table = torch.zeros((2 * ws[0] - 1) * (2 * ws[1] - 1), num_heads) + self.relative_position_bias_table = nn.Parameter(table) + # 2*Wh-1 * 2*Ww-1, nH + trunc_normal_(self.relative_position_bias_table, std=0.02) + + self.get_relative_position_index(self.window_size) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def get_relative_position_index(self, window_size): + # get pair-wise relative position index for each token inside the window + coord_h = torch.arange(window_size[0]) + coord_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coord_h, coord_w])) # 2, Wh, Ww + coords = torch.flatten(coords, 1) # 2, Wh*Ww + coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww + coords = coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + coords[:, :, 1] += window_size[1] - 1 + coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + + # qkv projection + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + # attention map + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + # positional encoding + if self.use_pe: + win_dim = prod(self.window_size) + bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ] + bias = bias.view(win_dim, win_dim, -1).permute(2, 0, 1).contiguous() + # nH, Wh*Ww, Wh*Ww + attn = attn + bias.unsqueeze(0) + + # shift attention mask + if mask is not None: + nW = mask.shape[0] + mask = mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask + attn = attn.view(-1, self.num_heads, N, N) + + # attention + attn = self.softmax(attn) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + + # output projection + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}" + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class WindowAttentionWrapperV1(WindowAttentionV1): + def __init__(self, shift_size, input_resolution, **kwargs): + super(WindowAttentionWrapperV1, self).__init__(**kwargs) + self.shift_size = 
shift_size + self.input_resolution = input_resolution + + if self.shift_size > 0: + attn_mask = calculate_mask(input_resolution, self.window_size, shift_size) + else: + attn_mask = None + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + + # partition windows + x = window_partition(x, self.window_size) # nW*B, wh, ww, C + x = x.view(-1, prod(self.window_size), C) # nW*B, wh*ww, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_mask = self.attn_mask + else: + attn_mask = calculate_mask(x_size, self.window_size, self.shift_size) + attn_mask = attn_mask.to(x.device) + + # attention + x = super(WindowAttentionWrapperV1, self).forward(x, mask=attn_mask) + # nW*B, wh*ww, C + + # merge windows + x = x.view(-1, *self.window_size, C) + x = window_reverse(x, self.window_size, x_size) # B, H, W, C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + x = x.view(B, H * W, C) + + return x + + +class SwinTransformerBlockV1(nn.Module): + r"""Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_pe=True, + res_scale=1.0, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" + self.res_scale = res_scale + + self.norm1 = norm_layer(dim) + self.attn = WindowAttentionWrapperV1( + shift_size=self.shift_size, + input_resolution=self.input_resolution, + dim=dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + use_pe=use_pe, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) + + def forward(self, x, x_size): + # Window attention + x = x + self.res_scale * self.drop_path(self.attn(self.norm1(x), x_size)) + # FFN + x = x + self.res_scale * self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}, res_scale={self.res_scale}" + ) + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r"""Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 
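An illustrative, standalone shape walk-through of the 2x2 merge that the remainder of this `forward` performs: neighboring tokens are concatenated along channels and then reduced from `4*C` to `2*C`, halving the token grid. Sizes here are hypothetical.

```python
# Sketch of the PatchMerging arithmetic, independent of the class above.
import torch
import torch.nn as nn

B, H, W, C = 1, 8, 8, 96
x = torch.randn(B, H * W, C).view(B, H, W, C)

x = torch.cat(
    [x[:, 0::2, 0::2, :], x[:, 1::2, 0::2, :], x[:, 0::2, 1::2, :], x[:, 1::2, 1::2, :]],
    dim=-1,
)                                            # B, H/2, W/2, 4*C
x = x.view(B, -1, 4 * C)                     # B, (H/2)*(W/2), 4*C
x = nn.LayerNorm(4 * C)(x)                   # corresponds to self.norm
x = nn.Linear(4 * C, 2 * C, bias=False)(x)   # corresponds to self.reduction
print(x.shape)                               # torch.Size([1, 16, 192])
```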
+ + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class PatchEmbed(nn.Module): + r"""Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__( + self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + img_size[0] // patch_size[0], + img_size[1] // patch_size[1], + ] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r"""Image to Patch Unembedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__( + self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + img_size[0] // patch_size[0], + img_size[1] // patch_size[1], + ] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Linear(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(Linear, self).__init__(in_features, out_features, bias) + + def forward(self, x): + B, C, H, W = x.shape + x = bchw_to_blc(x) + x = super(Linear, self).forward(x) + x = blc_to_bchw(x, (H, W)) + return x + + +def build_last_conv(conv_type, dim): + if conv_type == "1conv": + block = nn.Conv2d(dim, dim, 3, 1, 1) + elif conv_type == "3conv": + # to save parameters and memory + block = nn.Sequential( + nn.Conv2d(dim, dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1), + ) + elif conv_type == "1conv1x1": + block = nn.Conv2d(dim, dim, 1, 1, 0) + elif conv_type == "linear": + block = Linear(dim, dim) + return block + + +# class BasicLayer(nn.Module): +# """A basic Swin Transformer layer for one stage. +# Args: +# dim (int): Number of input channels. +# input_resolution (tuple[int]): Input resolution. +# depth (int): Number of blocks. +# num_heads (int): Number of attention heads. +# window_size (int): Local window size. +# mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. +# qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True +# qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. +# drop (float, optional): Dropout rate. Default: 0.0 +# attn_drop (float, optional): Attention dropout rate. Default: 0.0 +# drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 +# norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm +# downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
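A quick sketch exercising `build_last_conv` above: all four variants ("1conv", "3conv", "1conv1x1", "linear") are shape-preserving on `(B, dim, H, W)` feature maps. The import path follows the diff header for `architecture/grl_common/swin_v1_block.py`; sizes are illustrative.

```python
# Sketch: every last-conv variant keeps the feature map shape unchanged.
import torch
from architecture.grl_common.swin_v1_block import build_last_conv

x = torch.randn(1, 96, 16, 16)
for conv_type in ("1conv", "3conv", "1conv1x1", "linear"):
    block = build_last_conv(conv_type, dim=96)
    print(conv_type, block(x).shape)   # torch.Size([1, 96, 16, 16]) each time
```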
Default: None +# args: Additional arguments +# """ + +# def __init__( +# self, +# dim, +# input_resolution, +# depth, +# num_heads, +# window_size, +# mlp_ratio=4.0, +# qkv_bias=True, +# qk_scale=None, +# drop=0.0, +# attn_drop=0.0, +# drop_path=0.0, +# norm_layer=nn.LayerNorm, +# downsample=None, +# args=None, +# ): + +# super().__init__() +# self.dim = dim +# self.input_resolution = input_resolution +# self.depth = depth + +# # build blocks +# self.blocks = nn.ModuleList( +# [ +# _parse_block( +# dim=dim, +# input_resolution=input_resolution, +# num_heads=num_heads, +# window_size=window_size, +# shift_size=0 +# if args.no_shift +# else (0 if (i % 2 == 0) else window_size // 2), +# mlp_ratio=mlp_ratio, +# qkv_bias=qkv_bias, +# qk_scale=qk_scale, +# drop=drop, +# attn_drop=attn_drop, +# drop_path=drop_path[i] +# if isinstance(drop_path, list) +# else drop_path, +# norm_layer=norm_layer, +# stripe_type="H" if (i % 2 == 0) else "W", +# args=args, +# ) +# for i in range(depth) +# ] +# ) +# # self.blocks = nn.ModuleList( +# # [ +# # STV1Block( +# # dim=dim, +# # input_resolution=input_resolution, +# # num_heads=num_heads, +# # window_size=window_size, +# # shift_size=0 if (i % 2 == 0) else window_size // 2, +# # mlp_ratio=mlp_ratio, +# # qkv_bias=qkv_bias, +# # qk_scale=qk_scale, +# # drop=drop, +# # attn_drop=attn_drop, +# # drop_path=drop_path[i] +# # if isinstance(drop_path, list) +# # else drop_path, +# # norm_layer=norm_layer, +# # ) +# # for i in range(depth) +# # ] +# # ) + +# # patch merging layer +# if downsample is not None: +# self.downsample = downsample( +# input_resolution, dim=dim, norm_layer=norm_layer +# ) +# else: +# self.downsample = None + +# def forward(self, x, x_size): +# for blk in self.blocks: +# x = blk(x, x_size) +# if self.downsample is not None: +# x = self.downsample(x) +# return x + +# def extra_repr(self) -> str: +# return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + +# def flops(self): +# flops = 0 +# for blk in self.blocks: +# flops += blk.flops() +# if self.downsample is not None: +# flops += self.downsample.flops() +# return flops diff --git a/architecture/grl_common/swin_v2_block.py b/architecture/grl_common/swin_v2_block.py new file mode 100644 index 0000000000000000000000000000000000000000..e62f13704ee2fe5e1674cf6316df8137597688c3 --- /dev/null +++ b/architecture/grl_common/swin_v2_block.py @@ -0,0 +1,306 @@ +import math +from math import prod + +import torch +import torch.nn as nn +import torch.nn.functional as F +from architecture.grl_common.ops import ( + calculate_mask, + get_relative_coords_table, + get_relative_position_index, + window_partition, + window_reverse, +) +from architecture.grl_common.swin_v1_block import Mlp +from timm.models.layers import DropPath, to_2tuple + + +class WindowAttentionV2(nn.Module): + r"""Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 
+ """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + attn_drop=0.0, + proj_drop=0.0, + pretrained_window_size=[0, 0], + use_pe=True, + ): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + self.use_pe = use_pe + + self.logit_scale = nn.Parameter( + torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True + ) + + if self.use_pe: + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False), + ) + table = get_relative_coords_table(window_size, pretrained_window_size) + index = get_relative_position_index(window_size) + self.register_buffer("relative_coords_table", table) + self.register_buffer("relative_position_index", index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + # self.qkv = nn.Linear(dim, dim * 3, bias=False) + # if qkv_bias: + # self.q_bias = nn.Parameter(torch.zeros(dim)) + # self.v_bias = nn.Parameter(torch.zeros(dim)) + # else: + # self.q_bias = None + # self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + + # qkv projection + # qkv_bias = None + # if self.q_bias is not None: + # qkv_bias = torch.cat( + # ( + # self.q_bias, + # torch.zeros_like(self.v_bias, requires_grad=False), + # self.v_bias, + # ) + # ) + # qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = self.qkv(x) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + # cosine attention map + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) + logit_scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp() + attn = attn * logit_scale + + # positional encoding + if self.use_pe: + bias_table = self.cpb_mlp(self.relative_coords_table) + bias_table = bias_table.view(-1, self.num_heads) + + win_dim = prod(self.window_size) + bias = bias_table[self.relative_position_index.view(-1)] + bias = bias.view(win_dim, win_dim, -1).permute(2, 0, 1).contiguous() + # nH, Wh*Ww, Wh*Ww + bias = 16 * torch.sigmoid(bias) + attn = attn + bias.unsqueeze(0) + + # shift attention mask + if mask is not None: + nW = mask.shape[0] + mask = mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask + attn = attn.view(-1, self.num_heads, N, N) + + # attention + attn = self.softmax(attn) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + + # output projection + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return ( + f"dim={self.dim}, window_size={self.window_size}, " + f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}" + ) + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * 
self.dim * self.dim + return flops + + +class WindowAttentionWrapperV2(WindowAttentionV2): + def __init__(self, shift_size, input_resolution, **kwargs): + super(WindowAttentionWrapperV2, self).__init__(**kwargs) + self.shift_size = shift_size + self.input_resolution = input_resolution + + if self.shift_size > 0: + attn_mask = calculate_mask(input_resolution, self.window_size, shift_size) + else: + attn_mask = None + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + + # partition windows + x = window_partition(x, self.window_size) # nW*B, wh, ww, C + x = x.view(-1, prod(self.window_size), C) # nW*B, wh*ww, C + + # W-MSA/SW-MSA + if self.input_resolution == x_size: + attn_mask = self.attn_mask + else: + attn_mask = calculate_mask(x_size, self.window_size, self.shift_size) + attn_mask = attn_mask.to(x.device) + + # attention + x = super(WindowAttentionWrapperV2, self).forward(x, mask=attn_mask) + # nW*B, wh*ww, C + + # merge windows + x = x.view(-1, *self.window_size, C) + x = window_reverse(x, self.window_size, x_size) # B, H, W, C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + x = x.view(B, H * W, C) + + return x + + +class SwinTransformerBlockV2(nn.Module): + r"""Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pretrained_window_size (int): Window size in pre-training. 
+ """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + pretrained_window_size=0, + use_pe=True, + res_scale=1.0, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" + self.res_scale = res_scale + + self.attn = WindowAttentionWrapperV2( + shift_size=self.shift_size, + input_resolution=self.input_resolution, + dim=dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + pretrained_window_size=to_2tuple(pretrained_window_size), + use_pe=use_pe, + ) + self.norm1 = norm_layer(dim) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) + self.norm2 = norm_layer(dim) + + def forward(self, x, x_size): + # Window attention + x = x + self.res_scale * self.drop_path(self.norm1(self.attn(x, x_size))) + # FFN + x = x + self.res_scale * self.drop_path(self.norm2(self.mlp(x))) + + return x + + def extra_repr(self) -> str: + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}, res_scale={self.res_scale}" + ) + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops diff --git a/architecture/grl_common/upsample.py b/architecture/grl_common/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..86155d0efeec42abd1ba12ef263c50357709b625 --- /dev/null +++ b/architecture/grl_common/upsample.py @@ -0,0 +1,50 @@ +import math + +import torch.nn as nn + + +class Upsample(nn.Module): + """Upsample module. + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + super(Upsample, self).__init__() + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError( + f"scale {scale} is not supported. " "Supported scales: 2^n and 3." + ) + self.up = nn.Sequential(*m) + + def forward(self, x): + return self.up(x) + + +class UpsampleOneStep(nn.Module): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. 
+ num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat, num_out_ch): + super(UpsampleOneStep, self).__init__() + self.num_feat = num_feat + m = [] + m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + self.up = nn.Sequential(*m) + + def forward(self, x): + return self.up(x) diff --git a/architecture/rrdb.py b/architecture/rrdb.py new file mode 100644 index 0000000000000000000000000000000000000000..afe60f060d92320c2020f4e2a7b12811cdfeaf04 --- /dev/null +++ b/architecture/rrdb.py @@ -0,0 +1,218 @@ +# -*- coding: utf-8 -*- + +# Paper Github Repository: https://github.com/xinntao/Real-ESRGAN +# Code snippet from: https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/archs/rrdbnet_arch.py +# Paper: https://arxiv.org/pdf/2107.10833.pdf + +import os, sys +import torch +from torch import nn as nn +from torch.nn import functional as F +from itertools import repeat +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + + +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. 
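A quick shape check for the `pixel_unshuffle` helper defined above (re-declared here so the snippet is self-contained): spatial dimensions shrink by `scale`, channels grow by `scale**2`, and it exactly inverts `nn.PixelShuffle`.

```python
# Sketch: pixel_unshuffle trades spatial resolution for channels and
# round-trips with nn.PixelShuffle.
import torch
import torch.nn as nn


def pixel_unshuffle(x, scale):
    b, c, hh, hw = x.size()
    assert hh % scale == 0 and hw % scale == 0
    h, w = hh // scale, hw // scale
    x = x.view(b, c, h, scale, w, scale)
    return x.permute(0, 1, 3, 5, 2, 4).reshape(b, c * scale**2, h, w)


x = torch.randn(1, 3, 64, 64)
y = pixel_unshuffle(x, scale=2)
print(y.shape)                                   # torch.Size([1, 12, 32, 32])
print(torch.allclose(nn.PixelShuffle(2)(y), x))  # True: exact round trip
```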
+ """ + + def __init__(self, num_feat=64, num_grow_ch=32): + super(ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Empirically, we use 0.2 to scale the residual for better performance + return x5 * 0.2 + x + + +class RRDB(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat, num_grow_ch=32): + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x): + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Empirically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x + + + +class RRDBNet(nn.Module): + """Networks consisting of Residual in Residual Dense Block, which is used + in ESRGAN. + + ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. + + We extend ESRGAN for scale x2 and scale x1. + Note: This is one option for scale 1, scale 2 in RRDBNet. + We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size + and enlarge the channel size before feeding inputs into the main ESRGAN architecture. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64 + num_block (int): Block number in the trunk network. Defaults: 6 for our Anime training cases + num_grow_ch (int): Channels for each growth. Default: 32. 
+ """ + + def __init__(self, num_in_ch, num_out_ch, scale, num_feat=64, num_block=6, num_grow_ch=32): + + super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + if self.scale == 2: + feat = pixel_unshuffle(x, scale=2) + elif self.scale == 1: + feat = pixel_unshuffle(x, scale=4) + else: + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out + + + +def main(): + root_path = os.path.abspath('.') + sys.path.append(root_path) + + from opt import opt # Manage GPU to choose + from pthflops import count_ops + from torchsummary import summary + import time + + # We use RRDB 6Blocks by default. + model = RRDBNet(3, 3).cuda() + pytorch_total_params = sum(p.numel() for p in model.parameters()) + print(f"RRDB has param {pytorch_total_params//1000} K params") + + + # Count the number of FLOPs to double check + x = torch.randn((1, 3, 180, 180)).cuda() + start = time.time() + x = model(x) + print("output size is ", x.shape) + total = time.time() - start + print(total) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/architecture/swinir.py b/architecture/swinir.py new file mode 100644 index 0000000000000000000000000000000000000000..dc0de9bbe2df1f67ea7c4d0dab9027a68b54797e --- /dev/null +++ b/architecture/swinir.py @@ -0,0 +1,874 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
+# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. 
+ num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of 
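A self-contained sketch of the shifted-window attention mask built by `calculate_mask` above, using a hypothetical 8x8 resolution with `window_size=4` and `shift_size=2`: tokens that come from different shifted regions but land in the same window receive a -100 logit, so softmax effectively ignores those pairs.

```python
# Sketch of the SW-MSA mask construction (values and sizes are illustrative).
import torch

H = W = 8
window_size, shift_size = 4, 2

img_mask = torch.zeros(1, H, W, 1)
slices = (
    slice(0, -window_size),
    slice(-window_size, -shift_size),
    slice(-shift_size, None),
)
cnt = 0
for h in slices:          # label the 3x3 grid of shifted regions
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

# window_partition for a square window, as defined in this file
windows = img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
windows = windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)

attn_mask = windows.unsqueeze(1) - windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)   # torch.Size([4, 16, 16]) -> nW, N, N
```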
window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
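+            # 3x3 conv -> 1x1 conv -> 3x3 conv, with the channel count squeezed to embed_dim // 4
+            # in the middle, so it needs far fewer parameters than a single full-width 3x3 conv.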
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if 
self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = SwinIR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect').cuda() + print(model) + + pytorch_total_params = sum(p.numel() for p in model.parameters()) + print(f"pathGAN has param {pytorch_total_params//1000} K params") + + + # Count the time + import time + x = torch.randn((1, 3, 180, 180)).cuda() + start = time.time() + x = model(x) + total = time.time() - start + print("total time spent is ", total) diff --git a/dataset_curation_pipeline/IC9600/ICNet.py b/dataset_curation_pipeline/IC9600/ICNet.py new file mode 100644 index 0000000000000000000000000000000000000000..155d9688e34bdd4fb649784eb9081b76acc49c9d --- /dev/null +++ b/dataset_curation_pipeline/IC9600/ICNet.py @@ -0,0 +1,151 @@ +import torch +import torchvision +import torch.nn as nn +import torch.nn.functional as F + + + +class slam(nn.Module): + def __init__(self, spatial_dim): + super(slam,self).__init__() + self.spatial_dim = spatial_dim + self.linear = nn.Sequential( + nn.Linear(spatial_dim**2,512), + nn.ReLU(), + nn.Linear(512,1), + nn.Sigmoid() + ) + + def forward(self, feature): + n,c,h,w = feature.shape + if (h != self.spatial_dim): + x = F.interpolate(feature,size=(self.spatial_dim,self.spatial_dim),mode= "bilinear", align_corners=True) + else: + x = feature + + + x = x.view(n,c,-1) + x = self.linear(x) + x = x.unsqueeze(dim =3) + out = x.expand_as(feature)*feature + + return out + + +class to_map(nn.Module): + def __init__(self,channels): + super(to_map,self).__init__() + self.to_map = nn.Sequential( + nn.Conv2d(in_channels=channels,out_channels=1, kernel_size=1,stride=1), + nn.Sigmoid() + ) + + def forward(self,feature): + return self.to_map(feature) + + +class conv_bn_relu(nn.Module): + def __init__(self,in_channels, out_channels, kernel_size = 3, padding = 1, stride = 1): + super(conv_bn_relu,self).__init__() + self.conv = nn.Conv2d(in_channels= in_channels, out_channels= out_channels, kernel_size= kernel_size, padding= padding, stride = stride) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self,x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + + +class up_conv_bn_relu(nn.Module): + def __init__(self,up_size, in_channels, out_channels = 64, kernal_size = 1, padding =0, stride = 1): + super(up_conv_bn_relu,self).__init__() + self.upSample = nn.Upsample(size = (up_size,up_size),mode="bilinear",align_corners=True) + self.conv = 
nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size = kernal_size, stride = stride, padding= padding) + self.bn = nn.BatchNorm2d(num_features=out_channels) + self.act = nn.ReLU() + + def forward(self,x): + x = self.upSample(x) + x = self.conv(x) + x = self.bn(x) + x = self.act(x) + return x + + + +class ICNet(nn.Module): + def __init__(self, is_pretrain = True, size1 = 512, size2 = 256): + super(ICNet,self).__init__() + resnet18Pretrained1 = torchvision.models.resnet18(pretrained= is_pretrain) + resnet18Pretrained2 = torchvision.models.resnet18(pretrained= is_pretrain) + + self.size1 = size1 + self.size2 = size2 + + ## detail branch + self.b1_1 = nn.Sequential(*list(resnet18Pretrained1.children())[:5]) + self.b1_1_slam = slam(32) + + self.b1_2 = list(resnet18Pretrained1.children())[5] + self.b1_2_slam = slam(32) + + ## context branch + self.b2_1 = nn.Sequential(*list(resnet18Pretrained2.children())[:5]) + self.b2_1_slam = slam(32) + + self.b2_2 = list(resnet18Pretrained2.children())[5] + self.b2_2_slam = slam(32) + + self.b2_3 = list(resnet18Pretrained2.children())[6] + self.b2_3_slam = slam(16) + + self.b2_4 = list(resnet18Pretrained2.children())[7] + self.b2_4_slam = slam(8) + + ## upsample + self.upsize = size1 // 8 + self.up1 = up_conv_bn_relu(up_size = self.upsize, in_channels = 128, out_channels = 256) + self.up2 = up_conv_bn_relu(up_size = self.upsize, in_channels = 512, out_channels = 256) + + ## map prediction head + self.to_map_f = conv_bn_relu(256*2,256*2) + self.to_map_f_slam = slam(32) + self.to_map = to_map(256*2) + + ## score prediction head + self.to_score_f = conv_bn_relu(256*2,256*2) + self.to_score_f_slam = slam(32) + self.head = nn.Sequential( + nn.Linear(256*2,512), + nn.ReLU(), + nn.Linear(512,1), + nn.Sigmoid() + ) + self.avgpool = nn.AdaptiveAvgPool2d((1,1)) + + + def forward(self,x1): + assert(x1.shape[2] == x1.shape[3] == self.size1) + x2 = F.interpolate(x1, size= (self.size2,self.size2), mode = "bilinear", align_corners= True) + + x1 = self.b1_2_slam(self.b1_2(self.b1_1_slam(self.b1_1(x1)))) + x2 = self.b2_2_slam(self.b2_2(self.b2_1_slam(self.b2_1(x2)))) + x2 = self.b2_4_slam(self.b2_4(self.b2_3_slam(self.b2_3(x2)))) + + + x1 = self.up1(x1) + x2 = self.up2(x2) + x_cat = torch.cat((x1,x2),dim = 1) + + cly_map = self.to_map(self.to_map_f_slam(self.to_map_f(x_cat))) + + score_feature = self.to_score_f_slam(self.to_score_f(x_cat)) + score_feature = self.avgpool(score_feature) + score_feature = score_feature.squeeze() + score = self.head(score_feature) + score = score.squeeze() + + return score,cly_map diff --git a/dataset_curation_pipeline/IC9600/gene.py b/dataset_curation_pipeline/IC9600/gene.py new file mode 100644 index 0000000000000000000000000000000000000000..42f347f4f157c9232cf0709d87cdf2466278bf35 --- /dev/null +++ b/dataset_curation_pipeline/IC9600/gene.py @@ -0,0 +1,113 @@ +import argparse +import os, sys +import torch +import cv2 +from torchvision import transforms +from PIL import Image +import torch.nn.functional as F +import numpy as np +from matplotlib import pyplot as plt +from tqdm import tqdm + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from opt import opt +from dataset_curation_pipeline.IC9600.ICNet import ICNet + + + +inference_transform = transforms.Compose([ + transforms.Resize((512,512)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + +def blend(ori_img, ic_img, alpha = 0.8, cm = plt.get_cmap("magma")): + 
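+    # Colorize the complexity map with the given matplotlib colormap (an RGBA float array),
+    # drop the alpha channel and reverse to BGR to match cv2's channel order,
+    # then alpha-blend the heatmap over the original frame.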
+    cm_ic_map = cm(ic_img)
+    heatmap = Image.fromarray((cm_ic_map[:, :, -2::-1]*255).astype(np.uint8))
+    ori_img = Image.fromarray(ori_img)
+    blend = Image.blend(ori_img,heatmap,alpha=alpha)
+    blend = np.array(blend)
+    return blend
+
+
+def infer_one_image(model, img_path):
+    with torch.no_grad():
+        ori_img = Image.open(img_path).convert("RGB")
+        ori_height = ori_img.height
+        ori_width = ori_img.width
+        img = inference_transform(ori_img)
+        img = img.cuda()
+        img = img.unsqueeze(0)
+        ic_score, ic_map = model(img)
+        ic_score = ic_score.item()
+
+
+        # ic_map = F.interpolate(ic_map, (ori_height, ori_width), mode = 'bilinear')
+
+        ## gene ic map
+        # ic_map_np = ic_map.squeeze().detach().cpu().numpy()
+        # out_ic_map_name = os.path.basename(img_path).split('.')[0] + '_' + str(ic_score)[:7] + '.npy'
+        # out_ic_map_path = os.path.join(args.output, out_ic_map_name)
+        # np.save(out_ic_map_path, ic_map_np)
+
+        ## gene blend map
+        # ic_map_img = (ic_map * 255).round().squeeze().detach().cpu().numpy().astype('uint8')
+        # blend_img = blend(np.array(ori_img), ic_map_img)
+        # out_blend_img_name = os.path.basename(img_path).split('.')[0] + '.png'
+        # out_blend_img_path = os.path.join(args.output, out_blend_img_name)
+        # cv2.imwrite(out_blend_img_path, blend_img)
+    return ic_score
+
+
+
+def infer_directory(img_dir):
+    imgs = sorted(os.listdir(img_dir))
+    scores = []
+    for img in tqdm(imgs):
+        img_path = os.path.join(img_dir, img)
+        score = infer_one_image(model, img_path)
+
+        scores.append((score, img_path))
+        print(img_path, score)
+
+    scores = sorted(scores, key=lambda x: x[0])
+    scores = scores[::-1]
+
+    for score in scores[:50]:
+        print(score)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-i', '--input', type = str, default = './example')
+    parser.add_argument('-o', '--output', type = str, default = './out')
+    parser.add_argument('-d', '--device', type = int, default=0)
+
+    args = parser.parse_args()
+
+    model = ICNet()
+    model.load_state_dict(torch.load('./checkpoint/ck.pth',map_location=torch.device('cpu')))
+    model.eval()
+    device = torch.device(args.device)
+    model.to(device)
+
+    inference_transform = transforms.Compose([
+        transforms.Resize((512,512)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+    if os.path.isfile(args.input):
+        infer_one_image(model, args.input)
+    else:
+        infer_directory(args.input)
+
+
+
+
+
+
+
+
+
diff --git a/dataset_curation_pipeline/collect.py b/dataset_curation_pipeline/collect.py
new file mode 100644
index 0000000000000000000000000000000000000000..db98ddbff954c184ee8a5d4485c1432f3d297b41
--- /dev/null
+++ b/dataset_curation_pipeline/collect.py
@@ -0,0 +1,222 @@
+'''
+    This file implements the whole dataset curation pipeline: it collects the least compressed and most informative frames from the video sources.
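+
+    Example (run from the repository root; paths follow the argparse defaults at the bottom of this file):
+        python dataset_curation_pipeline/collect.py --video_folder_dir ../anime_videos \
+               --IC9600_pretrained_weight_path pretrained/ck.pth --save_dir APISR_dataset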
+''' +import os, time, sys +import shutil +import cv2 +import torch +import argparse + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from opt import opt +from dataset_curation_pipeline.IC9600.gene import infer_one_image +from dataset_curation_pipeline.IC9600.ICNet import ICNet + + +class video_scoring: + + def __init__(self, IC9600_pretrained_weight_path) -> None: + + # Init the model + self.scorer = ICNet() + self.scorer.load_state_dict(torch.load(IC9600_pretrained_weight_path, map_location=torch.device('cpu'))) + self.scorer.eval().cuda() + + + def select_frame(self, skip_num, img_lists, target_frame_num, save_dir, output_name_head, partition_idx): + ''' Execution of scoring to all I-Frame in img_folder and select target_frame to return back + Args: + skip_num (int): Only 1 in skip_num will be chosen to accelerate. + img_lists (str): The image lists of all files we want to process + target_frame_num (int): The number of frames we need to choose + save_dir (str): The path where we save those images + output_name_head (str): This is the input video name head + partition_idx (int): The partition idx + ''' + + stores = [] + for idx, image_path in enumerate(sorted(img_lists)): + if idx % skip_num != 0: + # We only process 1 in 3 to accelerate and also prevent minor case of repeated scene. + continue + + + # Evaluate the image complexity score for this image + score = infer_one_image(self.scorer, image_path) + + if verbose: + print(image_path, score) + stores.append((score, image_path)) + + if verbose: + print(image_path, score) + + + # Find the top most scores' images + stores.sort(key=lambda x:x[0]) + selected = stores[-target_frame_num:] + # print(len(stores), len(selected)) + if verbose: + print("The lowest selected score is ", selected[0]) # This is a kind of info + + + # Store the selected images + for idx, (score, img_path) in enumerate(selected): + output_name = output_name_head + "_" +str(partition_idx)+ "_" + str(idx) + ".png" + output_path = os.path.join(save_dir, output_name) + shutil.copyfile(img_path, output_path) + + + def run(self, skip_num, img_folder, target_frame_num, save_dir, output_name_head, partition_num): + ''' Execution of scoring to all I-Frame in img_folder and select target_frame to return back + Args: + skip_num (int): Only 1 in skip_num will be chosen to accelerate. 
+ img_folder (str): The image folder of all I-Frames we need to process + target_frame_num (int): The number of frames we need to choose + save_dir (str): The path where we save those images + output_name_head (str): This is the input video name head + partition_num (int): The number of partition we want to crop the video to + ''' + assert(target_frame_num%partition_num == 0) + + img_lists = [] + for img_name in sorted(os.listdir(img_folder)): + path = os.path.join(img_folder, img_name) + img_lists.append(path) + length = len(img_lists) + unit_length = (length // partition_num) + target_partition_num = target_frame_num // partition_num + + # Cut the folder to several partition and select those with the highest score + for idx in range(partition_num): + select_lists = img_lists[unit_length*idx : unit_length*(idx+1)] + self.select_frame(skip_num, select_lists, target_partition_num, save_dir, output_name_head, idx) + + +class frame_collector: + + def __init__(self, IC9600_pretrained_weight_path, verbose) -> None: + + self.scoring = video_scoring(IC9600_pretrained_weight_path) + self.verbose = verbose + + + def video_split_by_IFrame(self, video_path, tmp_path): + ''' Split the video to its I-Frames format + Args: + video_path (str): The directory to a single video + tmp_path (str): A temporary working places to work and will be delete at the end + ''' + + # Prepare the work folder needed + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + os.makedirs(tmp_path) + + + # Split Video I-frame + cmd = "ffmpeg -i " + video_path + " -loglevel error -vf select='eq(pict_type\,I)' -vsync 2 -f image2 -q:v 1 " + tmp_path + "/image-%06d.png" # At most support 100K I-Frames per video + + if self.verbose: + print(cmd) + os.system(cmd) + + + + def collect_frames(self, video_folder_dir, save_dir, tmp_path, skip_num, target_frames, partition_num): + ''' Automatically collect frames from the video dir + Args: + video_folder_dir (str): The directory of all videos input + save_dir (str): The directory we will store the selected frames + tmp_path (str): A temporary working places to work and will be delete at the end + skip_num (int): Only 1 in skip_num will be chosen to accelerate. 
+ target_frames (list): [# of frames for video under 30 min, # of frames for video over 30 min] + partition_num (int): The number of partition we want to crop the video to + ''' + + # Iterate all video under video_folder_dir + for video_name in sorted(os.listdir(video_folder_dir)): + # Sanity check for this video file format + info = video_name.split('.') + if info[-1] not in ['mp4', 'mkv', '']: + continue + output_name_head, extension = info + + + # Get info of this video + video_path = os.path.join(video_folder_dir, video_name) + duration = get_duration(video_path) # unit in minutes + print("We are processing " + video_path + " with duration " + str(duration) + " min") + + + # Split the video to I-frame + self.video_split_by_IFrame(video_path, tmp_path) + + + # Score the frames and select those top scored frames we need + if duration <= 30: + target_frame_num = target_frames[0] + else: + target_frame_num = target_frames[1] + + self.scoring.run(skip_num, tmp_path, target_frame_num, save_dir, output_name_head, partition_num) + + + # Remove folders if needed + + +def get_duration(filename): + video = cv2.VideoCapture(filename) + fps = video.get(cv2.CAP_PROP_FPS) + frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT) + seconds = frame_count / fps + minutes = int(seconds / 60) + return minutes + + +if __name__ == "__main__": + + # Fundamental setting + parser = argparse.ArgumentParser() + parser.add_argument('--video_folder_dir', type = str, default = '../anime_videos', help = "A folder with video sources") + parser.add_argument('--IC9600_pretrained_weight_path', type = str, default = "pretrained/ck.pth", help = "The pretrained IC9600 weight") + parser.add_argument('--save_dir', type = str, default = 'APISR_dataset', help = "The folder to store filtered dataset") + parser.add_argument('--skip_num', type = int, default = 5, help = "Only 1 in skip_num will be chosen in sequential I-frames to accelerate.") + parser.add_argument('--target_frames', type = list, default = [16, 24], help = "[# of frames for video under 30 min, # of frames for video over 30 min]") + parser.add_argument('--partition_num', type = int, default = 8, help = "The number of partition we want to crop the video to, to increase diversity of sampling") + parser.add_argument('--verbose', type = bool, default = True, help = "Whether we print log message") + args = parser.parse_args() + + + # Transform to variable + video_folder_dir = args.video_folder_dir + IC9600_pretrained_weight_path = args.IC9600_pretrained_weight_path + save_dir = args.save_dir + skip_num = args.skip_num + target_frames = args.target_frames # [# of frames for video under 30 min, # of frames for video over 30 min] + partition_num = args.partition_num + verbose = args.verbose + + + # Secondary setting + tmp_path = "tmp_dataset" + + + # Prepare + if os.path.exists(save_dir): + shutil.rmtree(save_dir) + os.makedirs(save_dir) + + + # Process + start = time.time() + + obj = frame_collector(IC9600_pretrained_weight_path, verbose) + obj.collect_frames(video_folder_dir, save_dir, tmp_path, skip_num, target_frames, partition_num) + + total_time = (time.time() - start)//60 + print("Total time spent is {} min".format(total_time)) + + shutil.rmtree(tmp_path) \ No newline at end of file diff --git a/degradation/ESR/degradation_esr_shared.py b/degradation/ESR/degradation_esr_shared.py new file mode 100644 index 0000000000000000000000000000000000000000..37449e0c273dc4b7dc5262c94e7030004b2ae6a4 --- /dev/null +++ b/degradation/ESR/degradation_esr_shared.py @@ -0,0 +1,180 @@ +# 
-*- coding: utf-8 -*- + +import argparse +import cv2 +import torch +import numpy as np +import os, shutil, time +import sys, random +from multiprocessing import Pool +from os import path as osp +from tqdm import tqdm +from math import log10, sqrt +import torch.nn.functional as F + +root_path = os.path.abspath('.') +sys.path.append(root_path) +from degradation.ESR.degradations_functionality import * +from degradation.ESR.diffjpeg import * +from degradation.ESR.utils import filter2D +from degradation.image_compression.jpeg import JPEG +from degradation.image_compression.webp import WEBP +from degradation.image_compression.heif import HEIF +from degradation.image_compression.avif import AVIF +from opt import opt + + +def PSNR(original, compressed): + mse = np.mean((original - compressed) ** 2) + if(mse == 0): # MSE is zero means no noise is present in the signal . + # Therefore PSNR have no importance. + return 100 + max_pixel = 255.0 + psnr = 20 * log10(max_pixel / sqrt(mse)) + return psnr + + + +def downsample_1st(out, opt): + # Resize with different mode + updown_type = random.choices(['up', 'down', 'keep'], opt['resize_prob'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, opt['resize_range'][1]) + elif updown_type == 'down': + scale = np.random.uniform(opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(opt['resize_options']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + + return out + + +def downsample_2nd(out, opt, ori_h, ori_w): + # Second Resize for 4x scaling + if opt['scale'] == 4: + updown_type = random.choices(['up', 'down', 'keep'], opt['resize_prob2'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, opt['resize_range2'][1]) + elif updown_type == 'down': + scale = np.random.uniform(opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(opt['resize_options']) + # Resize这边改回来原来的版本,不用连续的resize了 + # out = F.interpolate(out, scale_factor=scale, mode=mode) + out = F.interpolate( + out, size=(int(ori_h / opt['scale'] * scale), int(ori_w / opt['scale'] * scale)), mode=mode + ) + + return out + + +def common_degradation(out, opt, kernels, process_id, verbose = False): + jpeger = DiffJPEG(differentiable=False).cuda() + kernel1, kernel2 = kernels + + + downsample_1st_position = random.choices([0, 1, 2])[0] + if opt['scale'] == 4: + # Only do the second downsample at 4x scale + downsample_2nd_position = random.choices([0, 1, 2])[0] + else: + # print("We don't use the second resize") + downsample_2nd_position = -1 + + + ####---------------------------- Frist Degradation ----------------------------------#### + batch_size, _, ori_h, ori_w = out.size() + + if downsample_1st_position == 0: + out = downsample_1st(out, opt) + + # Bluring kernel + out = filter2D(out, kernel1) + if verbose: print(f"(1st) blur noise") + + + if downsample_1st_position == 1: + out = downsample_1st(out, opt) + + + # Noise effect (gaussian / poisson) + gray_noise_prob = opt['gray_noise_prob'] + if np.random.uniform() < opt['gaussian_noise_prob']: + # Gaussian noise + out = random_add_gaussian_noise_pt( + out, sigma_range=opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) + name = "gaussian_noise" + else: + # Poisson noise + out = random_add_poisson_noise_pt( + out, + scale_range=opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + name = "poisson_noise" + if verbose: print("(1st) " + str(name)) + + + if downsample_1st_position == 2: + out = downsample_1st(out, opt) + + + # Choose an image 
compression codec (All degradation batch use the same codec) + image_codec = random.choices(opt['compression_codec1'], opt['compression_codec_prob1'])[0] # All lower case + if image_codec == "jpeg": + out = JPEG.compress_tensor(out) + elif image_codec == "webp": + try: + out = WEBP.compress_tensor(out, idx=process_id) + except Exception: + print("There is exception again in webp!") + out = WEBP.compress_tensor(out, idx=process_id) + elif image_codec == "heif": + out = HEIF.compress_tensor(out, idx=process_id) + elif image_codec == "avif": + out = AVIF.compress_tensor(out, idx=process_id) + else: + raise NotImplementedError("We don't have such image compression designed!") + # ########################################################################################## + + + # ####---------------------------- Second Degradation ----------------------------------#### + if downsample_2nd_position == 0: + out = downsample_2nd(out, opt, ori_h, ori_w) + + + # Add blur 2nd time + if np.random.uniform() < opt['second_blur_prob']: + # 这个bluring不是必定触发的 + if verbose: print("(2nd) blur noise") + out = filter2D(out, kernel2) + + + if downsample_2nd_position == 1: + out = downsample_2nd(out, opt, ori_h, ori_w) + + + # Add noise 2nd time + gray_noise_prob = opt['gray_noise_prob2'] + if np.random.uniform() < opt['gaussian_noise_prob2']: + # gaussian noise + if verbose: print("(2nd) gaussian noise") + out = random_add_gaussian_noise_pt( + out, sigma_range=opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) + name = "gaussian_noise" + else: + # poisson noise + if verbose: print("(2nd) poisson noise") + out = random_add_poisson_noise_pt( + out, scale_range=opt['poisson_scale_range2'], gray_prob=gray_noise_prob, clip=True, rounds=False) + name = "poisson_noise" + + + if downsample_2nd_position == 2: + out = downsample_2nd(out, opt, ori_h, ori_w) + + + return out \ No newline at end of file diff --git a/degradation/ESR/degradations_functionality.py b/degradation/ESR/degradations_functionality.py new file mode 100644 index 0000000000000000000000000000000000000000..f9c32e7c93297fb14bcff38ebdbc1656a5f33487 --- /dev/null +++ b/degradation/ESR/degradations_functionality.py @@ -0,0 +1,785 @@ +# -*- coding: utf-8 -*- + +import cv2 +import math +import numpy as np +import random +import torch +from scipy import special +from scipy.stats import multivariate_normal +from torchvision.transforms.functional_tensor import rgb_to_grayscale + +# -------------------------------------------------------------------- # +# --------------------------- blur kernels --------------------------- # +# -------------------------------------------------------------------- # + + +# --------------------------- util functions --------------------------- # +def sigma_matrix2(sig_x, sig_y, theta): + """Calculate the rotated sigma matrix (two dimensional matrix). + + Args: + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + + Returns: + ndarray: Rotated sigma matrix. + """ + d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]]) + u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T)) + + +def mesh_grid(kernel_size): + """Generate the mesh grid, centering at zero. 
+ + Args: + kernel_size (int): + + Returns: + xy (ndarray): with the shape (kernel_size, kernel_size, 2) + xx (ndarray): with the shape (kernel_size, kernel_size) + yy (ndarray): with the shape (kernel_size, kernel_size) + """ + ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) + xx, yy = np.meshgrid(ax, ax) + xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size, + 1))).reshape(kernel_size, kernel_size, 2) + return xy, xx, yy + + +def pdf2(sigma_matrix, grid): + """Calculate PDF of the bivariate Gaussian distribution. + + Args: + sigma_matrix (ndarray): with the shape (2, 2) + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + + Returns: + kernel (ndarrray): un-normalized kernel. + """ + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2)) + return kernel + + +def cdf2(d_matrix, grid): + """Calculate the CDF of the standard bivariate Gaussian distribution. + Used in skewed Gaussian distribution. + + Args: + d_matrix (ndarrasy): skew matrix. + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + + Returns: + cdf (ndarray): skewed cdf. + """ + rv = multivariate_normal([0, 0], [[1, 0], [0, 1]]) + grid = np.dot(grid, d_matrix) + cdf = rv.cdf(grid) + return cdf + + +def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True): + """Generate a bivariate isotropic or anisotropic Gaussian kernel. + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + isotropic (bool): + + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + kernel = pdf2(sigma_matrix, grid) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True): + """Generate a bivariate generalized Gaussian kernel. + Described in `Parameter Estimation For Multivariate Generalized + Gaussian Distributions`_ + by Pascal et. al (2013). + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + + Returns: + kernel (ndarray): normalized kernel. + + .. 
_Parameter Estimation For Multivariate Generalized Gaussian + Distributions: https://arxiv.org/abs/1302.6498 + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta)) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True): + """Generate a plateau-like anisotropic kernel. + 1 / (1+x^(beta)) + + Ref: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1) + kernel = kernel / np.sum(kernel) + return kernel + + +def random_bivariate_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=None, + isotropic=True): + """Randomly generate bivariate isotropic or anisotropic Gaussian kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + return kernel + + +def random_bivariate_generalized_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=None, + isotropic=True): + """Randomly generate bivariate generalized Gaussian kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. 
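+
+    The shape parameter `beta` is sampled on either side of 1 with equal probability
+    (beta = 1 recovers the plain Gaussian case).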
+ + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + # assume beta_range[0] < 1 < beta_range[1] + if np.random.uniform() < 0.5: + beta = np.random.uniform(beta_range[0], 1) + else: + beta = np.random.uniform(1, beta_range[1]) + + kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + return kernel + + +def random_bivariate_plateau(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=None, + isotropic=True): + """Randomly generate bivariate plateau kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi/2, math.pi/2] + beta_range (tuple): [1, 4] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + # TODO: this may be not proper + if np.random.uniform() < 0.5: + beta = np.random.uniform(beta_range[0], 1) + else: + beta = np.random.uniform(1, beta_range[1]) + + kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic) + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + + return kernel + + +def random_mixed_kernels(kernel_list, + kernel_prob, + kernel_size=21, + sigma_x_range=(0.6, 5), + sigma_y_range=(0.6, 5), + rotation_range=(-math.pi, math.pi), + betag_range=(0.5, 8), + betap_range=(0.5, 8), + noise_range=None): + """Randomly generate mixed kernels. 
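+
+    The kernel family is first sampled from `kernel_list` according to `kernel_prob`,
+    and the corresponding random_bivariate_* generator above is then called.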
+ + Args: + kernel_list (tuple): a list name of kernel types, + support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', + 'plateau_aniso'] + kernel_prob (tuple): corresponding kernel probability for each + kernel type + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + kernel_type = random.choices(kernel_list, kernel_prob)[0] + if kernel_type == 'iso': + kernel = random_bivariate_Gaussian( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True) + elif kernel_type == 'aniso': + kernel = random_bivariate_Gaussian( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False) + elif kernel_type == 'generalized_iso': + kernel = random_bivariate_generalized_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + betag_range, + noise_range=noise_range, + isotropic=True) + elif kernel_type == 'generalized_aniso': + kernel = random_bivariate_generalized_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + betag_range, + noise_range=noise_range, + isotropic=False) + elif kernel_type == 'plateau_iso': + kernel = random_bivariate_plateau( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True) + elif kernel_type == 'plateau_aniso': + kernel = random_bivariate_plateau( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False) + return kernel + + +np.seterr(divide='ignore', invalid='ignore') + + +def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0): + """2D sinc filter, ref: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter + =====》 这个地方好好调研一下,能做出来的效果决定了后面的上线! + Args: + cutoff (float): cutoff frequency in radians (pi is max) + kernel_size (int): horizontal and vertical size, must be odd. + pad_to (int): pad kernel size to desired size, must be odd or zero. + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + kernel = np.fromfunction( + lambda x, y: cutoff * special.j1(cutoff * np.sqrt( + (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt( + (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size]) + kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi) + kernel = kernel / np.sum(kernel) + if pad_to > kernel_size: + pad_size = (pad_to - kernel_size) // 2 + kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) + return kernel + + +# ------------------------------------------------------------- # +# --------------------------- noise --------------------------- # +# ------------------------------------------------------------- # + +# ----------------------- Gaussian Noise ----------------------- # + + +def generate_gaussian_noise(img, sigma=10, gray_noise=False): + """Generate Gaussian noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + sigma (float): Noise scale (measured in range 255). Default: 10. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + if gray_noise: + noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255. 
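+        # Tile the single-channel noise to 3 channels so the same value is added to R, G and B.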
+ noise = np.expand_dims(noise, axis=2).repeat(3, axis=2) + else: + noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255. + return noise + + +def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False): + """Add Gaussian noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + sigma (float): Noise scale (measured in range 255). Default: 10. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + noise = generate_gaussian_noise(img, sigma, gray_noise) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0): + """Add Gaussian noise (PyTorch version). + + Args: + img (Tensor): Shape (b, c, h, w), range[0, 1], float32. + sigma (float | Tensor): 每一个batch都被分配了一个(share 一个) + gray_noise (float | Tensor): 不是1就是0 + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + b, _, h, w = img.size() + if not isinstance(sigma, (float, int)): + sigma = sigma.view(img.size(0), 1, 1, 1) + if isinstance(gray_noise, (float, int)): + cal_gray_noise = gray_noise > 0 + else: + gray_noise = gray_noise.view(b, 1, 1, 1) + cal_gray_noise = torch.sum(gray_noise) > 0 + + if cal_gray_noise: + noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255. + noise_gray = noise_gray.view(b, 1, h, w) + + # always calculate color noise + noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255. + + if cal_gray_noise: + noise = noise * (1 - gray_noise) + noise_gray * gray_noise + return noise + + +def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False): + """Add Gaussian noise (PyTorch version). + + Args: + img (Tensor): Shape (b, c, h, w), range[0, 1], float32. + scale (float | Tensor): Noise scale. Default: 1.0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + noise = generate_gaussian_noise_pt(img, sigma, gray_noise) # sigma 就是gray_noise的保存率 + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ----------------------- Random Gaussian Noise ----------------------- # +def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0): + sigma = np.random.uniform(sigma_range[0], sigma_range[1]) + if np.random.uniform() < gray_prob: + gray_noise = True + else: + gray_noise = False + return generate_gaussian_noise(img, sigma, gray_noise) + + +def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_gaussian_noise(img, sigma_range, gray_prob) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. 
+ return out + + +def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0): + sigma = torch.rand(img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0] + + gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device) + gray_noise = (gray_noise < gray_prob).float() + return generate_gaussian_noise_pt(img, sigma, gray_noise) + + +def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + # sigma_range 就是noise保存比例 + noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ----------------------- Poisson (Shot) Noise ----------------------- # + + +def generate_poisson_noise(img, scale=1.0, gray_noise=False): + """Generate poisson noise. + + Ref: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219 + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + scale (float): Noise scale. Default: 1.0. + gray_noise (bool): Whether generate gray noise. Default: False. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + if gray_noise: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + # round and clip image for counting vals correctly + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = len(np.unique(img)) + vals = 2**np.ceil(np.log2(vals)) + out = np.float32(np.random.poisson(img * vals) / float(vals)) + noise = out - img + if gray_noise: + noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2) + return noise * scale + + +def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False): + """Add poisson noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + scale (float): Noise scale. Default: 1.0. + gray_noise (bool): Whether generate gray noise. Default: False. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + noise = generate_poisson_noise(img, scale, gray_noise) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0): + """Generate a batch of poisson noise (PyTorch version) + + Args: + img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32. + scale (float | Tensor): Noise scale. Number or Tensor with shape (b). + Default: 1.0. + 可以是个batch形式(Tensor) + gray_noise (float | Tensor): 0-1 number or Tensor with shape (b). + 0 for False, 1 for True. Default: 0. + 可以是个batch形式(Tensor) + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. 
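+
+    Note: the returned tensor is the scaled noise itself (added to the image by
+        `add_poisson_noise_pt`), not the noisy image.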
+ """ + b, _, h, w = img.size() + if isinstance(gray_noise, (float, int)): + cal_gray_noise = gray_noise > 0 + else: + gray_noise = gray_noise.view(b, 1, 1, 1) + # 这下面跟原论文有点小不一样的地方,如果按照我现在128 batch size,基本上每个都会有gray noise + cal_gray_noise = torch.sum(gray_noise) > 0 + if cal_gray_noise: + # 这里实际上我是觉得写的不是很efficient,因为有些地方如果不加那不是完全白计算了吗,现在gray noise的概率低得很 + img_gray = rgb_to_grayscale(img, num_output_channels=1) # 返回的只有luminance这一个channel + # round and clip image for counting vals correctly, ensure that it only has 256 possible floats at the end + img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255. + # use for-loop to get the unique values for each sample + + # Note: 这里加上noise完全看的是本图片(一张)的颜色diversity,这应该就解释了为什么在比较单一的flat图像,他会noise更加明显 + vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)] + vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] + vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1) + + # Since the img is in range [0,1], the noise by possion distribution should also lies in [0,1] + # Note: 这只是我个人的理解,现在对于单调的图片,整体会比较集中poisson noise在一个高点,就不如unique值高的图片会广泛分布(看possison distribution的图都看的出来) + out = torch.poisson(img_gray * vals) / vals + noise_gray = out - img_gray + noise_gray = noise_gray.expand(b, 3, h, w) + + # always calculate color noise + # round and clip image for counting vals correctly + img = torch.clamp((img * 255.0).round(), 0, 255) / 255. + # use for-loop to get the unique values for each sample + vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)] + vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] + vals = img.new_tensor(vals_list).view(b, 1, 1, 1) + out = torch.poisson(img * vals) / vals # output还是正数 + noise = out - img # 这个会导致负值的产生 + if cal_gray_noise: + # Note: 这里noise要么全加,要么不加(换成gray_noise) + noise = noise * (1 - gray_noise) + noise_gray * gray_noise # In this place, I don't know why it sometimes run out of memory + if not isinstance(scale, (float, int)): + scale = scale.view(b, 1, 1, 1) + + # Note: noise这边产出的值都是-0.x ---- +0.x 这个范围: 负的值相当于减弱pixel值的效果 + # print("poisson noise range is ", sorted(torch.unique(noise))[:10]) + # print(sorted(torch.unique(noise))[-10:]) + return noise * scale + + +def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0): + """Add poisson noise to a batch of images (PyTorch version). + + Args: + img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32. + scale (float | Tensor): Noise scale. Number or Tensor with shape (b). + Default: 1.0. + gray_noise (float | Tensor): 0-1 number or Tensor with shape (b). + 0 for False, 1 for True. Default: 0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + noise = generate_poisson_noise_pt(img, scale, gray_noise) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. 
+ return out + + +# ----------------------- Random Poisson (Shot) Noise ----------------------- # + + +def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0): + scale = np.random.uniform(scale_range[0], scale_range[1]) + if np.random.uniform() < gray_prob: + gray_noise = True + else: + gray_noise = False + return generate_poisson_noise(img, scale, gray_noise) + + +def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_poisson_noise(img, scale_range, gray_prob) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0): + # scale_range 还是保存的大小 + # img.size(0) 代表就是batch中的每个图片都有一个自己的scale level + scale = torch.rand(img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0] + + gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device) + gray_noise = (gray_noise < gray_prob).float() + return generate_poisson_noise_pt(img, scale, gray_noise) # scale 和 gray_noise应该都是tensor的batch形式 + + +def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ------------------------------------------------------------------------ # +# --------------------------- JPEG compression --------------------------- # +# ------------------------------------------------------------------------ # + + +def add_jpg_compression(img, quality=90): + """Add JPG compression artifacts. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + quality (float): JPG compression quality. 0 for lowest quality, 100 for + best quality. Default: 90. + + Returns: + (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], + float32. + """ + img = np.clip(img, 0, 1) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] + _, encimg = cv2.imencode('.jpg', img * 255., encode_param) + img = np.float32(cv2.imdecode(encimg, 1)) / 255. + return img + + +def random_add_jpg_compression(img, quality_range=(90, 100)): + """Randomly add JPG compression artifacts. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + quality_range (tuple[float] | list[float]): JPG compression quality + range. 0 for lowest quality, 100 for best quality. + Default: (90, 100). + + Returns: + (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], + float32. 
+ """ + quality = np.random.uniform(quality_range[0], quality_range[1]) + return add_jpg_compression(img, quality) diff --git a/degradation/ESR/diffjpeg.py b/degradation/ESR/diffjpeg.py new file mode 100644 index 0000000000000000000000000000000000000000..0eb19898d6d0e45ab7eaf6670a0e46715767ce47 --- /dev/null +++ b/degradation/ESR/diffjpeg.py @@ -0,0 +1,517 @@ +# -*- coding: utf-8 -*- + +""" +Modified from https://github.com/mlomnitz/DiffJPEG + +For images not divisible by 8 +https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343 +""" +import itertools +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +# ------------------------ utils ------------------------# +y_table = np.array( + [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56], + [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92], + [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]], + dtype=np.float32).T +y_table = nn.Parameter(torch.from_numpy(y_table)) +c_table = np.empty((8, 8), dtype=np.float32) +c_table.fill(99) +c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T +c_table = nn.Parameter(torch.from_numpy(c_table)) + + +def diff_round(x): + """ Differentiable rounding function + """ + return torch.round(x) + (x - torch.round(x))**3 + + +def quality_to_factor(quality): + """ Calculate factor corresponding to quality + + Args: + quality(float): Quality for jpeg compression. + + Returns: + float: Compression factor. + """ + if quality < 50: + quality = 5000. / quality + else: + quality = 200. - quality * 2 + return quality / 100. + + +# ------------------------ compression ------------------------# +class RGB2YCbCrJpeg(nn.Module): + """ Converts RGB image to YCbCr + """ + + def __init__(self): + super(RGB2YCbCrJpeg, self).__init__() + matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]], + dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0., 128., 128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + """ + Args: + image(Tensor): batch x 3 x height x width + + Returns: + Tensor: batch x height x width x 3 + """ + image = image.permute(0, 2, 3, 1) + result = torch.tensordot(image, self.matrix, dims=1) + self.shift + return result.view(image.shape) + + +class ChromaSubsampling(nn.Module): + """ Chroma subsampling on CbCr channels + """ + + def __init__(self): + super(ChromaSubsampling, self).__init__() + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width x 3 + + Returns: + y(tensor): batch x height x width + cb(tensor): batch x height/2 x width/2 + cr(tensor): batch x height/2 x width/2 + """ + image_2 = image.permute(0, 3, 1, 2).clone() + cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) + cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) + cb = cb.permute(0, 2, 3, 1) + cr = cr.permute(0, 2, 3, 1) + return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3) + + +class BlockSplitting(nn.Module): + """ Splitting image into patches + """ + + def __init__(self): + super(BlockSplitting, self).__init__() + self.k = 8 + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x h*w/64 x h x w + """ + height, 
_ = image.shape[1:3] + batch_size = image.shape[0] + image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, -1, self.k, self.k) + + +class DCT8x8(nn.Module): + """ Discrete Cosine Transformation + """ + + def __init__(self): + super(DCT8x8, self).__init__() + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16) + alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float()) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + image = image - 128 + result = self.scale * torch.tensordot(image, self.tensor, dims=2) + result.view(image.shape) + return result + + +class YQuantize(nn.Module): + """ JPEG Quantization for Y channel + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding): + super(YQuantize, self).__init__() + self.rounding = rounding + self.y_table = y_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + image = image.float() / (self.y_table * factor) + else: + b = factor.size(0) + table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + image = image.float() / table + image = self.rounding(image) + return image + + +class CQuantize(nn.Module): + """ JPEG Quantization for CbCr channels + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding): + super(CQuantize, self).__init__() + self.rounding = rounding + self.c_table = c_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + image = image.float() / (self.c_table * factor) + else: + b = factor.size(0) + table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + image = image.float() / table + image = self.rounding(image) + return image + + +class CompressJpeg(nn.Module): + """Full JPEG compression algorithm + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding=torch.round): + super(CompressJpeg, self).__init__() + self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling()) + self.l2 = nn.Sequential(BlockSplitting(), DCT8x8()) + self.c_quantize = CQuantize(rounding=rounding) + self.y_quantize = YQuantize(rounding=rounding) + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x 3 x height x width + + Returns: + dict(tensor): Compressed tensor with batch x h*w/64 x 8 x 8. 
+ """ + y, cb, cr = self.l1(image * 255) + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + comp = self.l2(components[k]) + if k in ('cb', 'cr'): + comp = self.c_quantize(comp, factor=factor) + else: + comp = self.y_quantize(comp, factor=factor) + + components[k] = comp + + return components['y'], components['cb'], components['cr'] + + +# ------------------------ decompression ------------------------# + + +class YDequantize(nn.Module): + """Dequantize Y channel + """ + + def __init__(self): + super(YDequantize, self).__init__() + self.y_table = y_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + out = image * (self.y_table * factor) + else: + b = factor.size(0) + table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + out = image * table + return out + + +class CDequantize(nn.Module): + """Dequantize CbCr channel + """ + + def __init__(self): + super(CDequantize, self).__init__() + self.c_table = c_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + out = image * (self.c_table * factor) + else: + b = factor.size(0) + table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + out = image * table + return out + + +class iDCT8x8(nn.Module): + """Inverse discrete Cosine Transformation + """ + + def __init__(self): + super(iDCT8x8, self).__init__() + alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float()) + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + image = image * self.alpha + result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128 + result.view(image.shape) + return result + + +class BlockMerging(nn.Module): + """Merge patches into image + """ + + def __init__(self): + super(BlockMerging, self).__init__() + + def forward(self, patches, height, width): + """ + Args: + patches(tensor) batch x height*width/64, height x width + height(int) + width(int) + + Returns: + Tensor: batch x height x width + """ + k = 8 + batch_size = patches.shape[0] + image_reshaped = patches.view(batch_size, height // k, width // k, k, k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, height, width) + + +class ChromaUpsampling(nn.Module): + """Upsample chroma layers + """ + + def __init__(self): + super(ChromaUpsampling, self).__init__() + + def forward(self, y, cb, cr): + """ + Args: + y(tensor): y channel image + cb(tensor): cb channel + cr(tensor): cr channel + + Returns: + Tensor: batch x height x width x 3 + """ + + def repeat(x, k=2): + height, width = x.shape[1:3] + x = x.unsqueeze(-1) + x = x.repeat(1, 1, k, k) + x = x.view(-1, height * k, width * k) + return x + + cb = repeat(cb) + cr = repeat(cr) + return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3) + + +class YCbCr2RGBJpeg(nn.Module): + """Converts YCbCr image to RGB JPEG + """ + + def __init__(self): + 
super(YCbCr2RGBJpeg, self).__init__() + + matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0, -128., -128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width x 3 + + Returns: + Tensor: batch x 3 x height x width + """ + result = torch.tensordot(image + self.shift, self.matrix, dims=1) + return result.view(image.shape).permute(0, 3, 1, 2) + + +class DeCompressJpeg(nn.Module): + """Full JPEG decompression algorithm + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding=torch.round): + super(DeCompressJpeg, self).__init__() + self.c_dequantize = CDequantize() + self.y_dequantize = YDequantize() + self.idct = iDCT8x8() + self.merging = BlockMerging() + self.chroma = ChromaUpsampling() + self.colors = YCbCr2RGBJpeg() + + def forward(self, y, cb, cr, imgh, imgw, factor=1): + """ + Args: + compressed(dict(tensor)): batch x h*w/64 x 8 x 8 + imgh(int) + imgw(int) + factor(float) + + Returns: + Tensor: batch x 3 x height x width + """ + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + if k in ('cb', 'cr'): + comp = self.c_dequantize(components[k], factor=factor) + height, width = int(imgh / 2), int(imgw / 2) + else: + comp = self.y_dequantize(components[k], factor=factor) + height, width = imgh, imgw + comp = self.idct(comp) + components[k] = self.merging(comp, height, width) + # + image = self.chroma(components['y'], components['cb'], components['cr']) + image = self.colors(image) + + image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image)) + return image / 255 + + +# ------------------------ main DiffJPEG ------------------------ # + + +class DiffJPEG(nn.Module): + """This JPEG algorithm result is slightly different from cv2. + DiffJPEG supports batch processing. + + Args: + differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round + """ + + def __init__(self, differentiable=True): + super(DiffJPEG, self).__init__() + if differentiable: + rounding = diff_round + else: + rounding = torch.round + + self.compress = CompressJpeg(rounding=rounding) + self.decompress = DeCompressJpeg(rounding=rounding) + + def forward(self, x, quality): + """ + Args: + x (Tensor): Input image, bchw, rgb, [0, 1] + quality(float): Quality factor for jpeg compression scheme. + """ + factor = quality + if isinstance(factor, (int, float)): + factor = quality_to_factor(factor) + else: + for i in range(factor.size(0)): + factor[i] = quality_to_factor(factor[i]) + h, w = x.size()[-2:] + h_pad, w_pad = 0, 0 + # why should use 16 + if h % 16 != 0: + h_pad = 16 - h % 16 + if w % 16 != 0: + w_pad = 16 - w % 16 + x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0) + + y, cb, cr = self.compress(x, factor=factor) + recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor) + recovered = recovered[:, :, 0:h, 0:w] + return recovered + + +if __name__ == '__main__': + import cv2 + + from basicsr.utils import img2tensor, tensor2img + + img_gt = cv2.imread('test.png') / 255. 
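+    # Note (derived from quality_to_factor above): the JPEG quality q used in this demo is
+    # mapped to the table-scaling factor applied in YQuantize/CQuantize as
+    #   q < 50  -> factor = 50 / q        (e.g. q=20 -> 2.5, q=40 -> 1.25)
+    #   q >= 50 -> factor = 2 - q / 50    (e.g. q=90 -> 0.2)
+    # so a larger factor scales the quantization tables up, i.e. coarser quantization and
+    # stronger compression artifacts.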
+ + # -------------- cv2 -------------- # + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20] + _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param) + img_lq = np.float32(cv2.imdecode(encimg, 1)) + cv2.imwrite('cv2_JPEG_20.png', img_lq) + + # -------------- DiffJPEG -------------- # + jpeger = DiffJPEG(differentiable=False).cuda() + img_gt = img2tensor(img_gt) + img_gt = torch.stack([img_gt, img_gt]).cuda() + quality = img_gt.new_tensor([20, 40]) + out = jpeger(img_gt, quality=quality) + + cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0])) + cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1])) diff --git a/degradation/ESR/usm_sharp.py b/degradation/ESR/usm_sharp.py new file mode 100644 index 0000000000000000000000000000000000000000..7246b1a57007404619b8876c7e90ccb909e8b270 --- /dev/null +++ b/degradation/ESR/usm_sharp.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- + +import cv2 +import numpy as np +import torch +from torch.nn import functional as F + +import os, sys +root_path = os.path.abspath('.') +sys.path.append(root_path) +from degradation.ESR.utils import filter2D, np2tensor, tensor2np + + +def usm_sharp_func(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. + + Input image: I; Blurry image: B. + 1. sharp = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * sharp + (1 - Mask) * I + + + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + sharp = img + weight * residual + sharp = np.clip(sharp, 0, 1) + return soft_mask * sharp + (1 - soft_mask) * img + + + +class USMSharp(torch.nn.Module): + + def __init__(self, type, radius=50, sigma=0): + super(USMSharp, self).__init__() + if radius % 2 == 0: + radius += 1 + self.radius = radius + kernel = cv2.getGaussianKernel(radius, sigma) + kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0).cuda() + self.register_buffer('kernel', kernel) + + self.type = type + + + def forward(self, img, weight=0.5, threshold=10, store=False): + + if self.type == "cv2": + # pre-process cv2 type + img = np2tensor(img) + + blur = filter2D(img, self.kernel.cuda()) + if store: + cv2.imwrite("blur.png", tensor2np(blur)) + + residual = img - blur + if store: + cv2.imwrite("residual.png", tensor2np(residual)) + + mask = torch.abs(residual) * 255 > threshold + if store: + cv2.imwrite("mask.png", tensor2np(mask)) + + + mask = mask.float() + soft_mask = filter2D(mask, self.kernel.cuda()) + if store: + cv2.imwrite("soft_mask.png", tensor2np(soft_mask)) + + sharp = img + weight * residual + sharp = torch.clip(sharp, 0, 1) + if store: + cv2.imwrite("sharp.png", tensor2np(sharp)) + + output = soft_mask * sharp + (1 - soft_mask) * img + if self.type == "cv2": + output = tensor2np(output) + + return output + + + +if __name__ == "__main__": + + usm_sharper = USMSharp(type="cv2") + img = cv2.imread("sample3.png") + print(img.shape) + sharp_output = usm_sharper(img, store=False, threshold=10) + cv2.imwrite(os.path.join("output.png"), sharp_output) + + + # dir = r"C:\Users\HikariDawn\Desktop\Real-CUGAN\datasets\sample" + # output_dir = r"C:\Users\HikariDawn\Desktop\Real-CUGAN\datasets\sharp_regular" 
+ # if not os.path.exists(output_dir): + # os.makedirs(output_dir) + + # for file_name in sorted(os.listdir(dir)): + # print(file_name) + # file = os.path.join(dir, file_name) + # img = cv2.imread(file) + # sharp_output = usm_sharper(img) + # cv2.imwrite(os.path.join(output_dir, file_name), sharp_output) diff --git a/degradation/ESR/utils.py b/degradation/ESR/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2cb4adfde53e9127dbad4e8434a3d7b298b704ac --- /dev/null +++ b/degradation/ESR/utils.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- + +''' + From ESRGAN +''' + + +import os, sys +import cv2 +import numpy as np +import torch +from torch.nn import functional as F +from scipy import special +import random +import math +from torchvision.utils import make_grid + +from degradation.ESR.degradations_functionality import * + +root_path = os.path.abspath('.') +sys.path.append(root_path) + + +def np2tensor(np_frame): + return torch.from_numpy(np.transpose(np_frame, (2, 0, 1))).unsqueeze(0).cuda().float()/255 + +def tensor2np(tensor): + # tensor should be batch size1 and cannot be grayscale input + return (np.transpose(tensor.detach().squeeze(0).cpu().numpy(), (1, 2, 0))) * 255 + +def mass_tensor2np(tensor): + ''' The input tensor is massive tensor + ''' + return (np.transpose(tensor.detach().squeeze(0).cpu().numpy(), (0, 2, 3, 1))) * 255 + +def save_img(tensor, save_name): + np_img = tensor2np(tensor)[:,:,16] + # np_img = np.expand_dims(np_img, axis=2) + cv2.imwrite(save_name, np_img) + + +def filter2D(img, kernel): + """PyTorch version of cv2.filter2D + + Args: + img (Tensor): (b, c, h, w) + kernel (Tensor): (b, k, k) + """ + k = kernel.size(-1) + b, c, h, w = img.size() + if k % 2 == 1: + img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') + else: + raise ValueError('Wrong kernel size') + + ph, pw = img.size()[-2:] + + if kernel.size(0) == 1: + # apply the same kernel to all batch images + img = img.view(b * c, 1, ph, pw) + kernel = kernel.view(1, 1, k, k) + return F.conv2d(img, kernel, padding=0).view(b, c, h, w) + else: + img = img.view(1, b * c, ph, pw) + kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) + return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) + + +def generate_kernels(opt): + + kernel_range = [2 * v + 1 for v in range(opt["kernel_range"][0], opt["kernel_range"][1])] + + # ------------------------ Generate kernels (used in the first degradation) ------------------------ # + kernel_size = random.choice(kernel_range) + if np.random.uniform() < opt['sinc_prob']: + # 里面加一层sinc filter,但是10%的概率 + # this sinc filter setting is for kernels ranging from [7, 21] + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel = random_mixed_kernels( + opt['kernel_list'], + opt['kernel_prob'], + kernel_size, + opt['blur_sigma'], + opt['blur_sigma'], [-math.pi, math.pi], + opt['betag_range'], + opt['betap_range'], + noise_range=None) + # pad kernel: -在v2我是直接省略了padding + pad_size = (21 - kernel_size) // 2 + kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) + + + # ------------------------ Generate kernels (used in the second degradation) ------------------------ # + kernel_size = random.choice(kernel_range) + if np.random.uniform() < opt['sinc_prob2']: + # 里面加一层sinc filter,但是10%的概率 + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + 
else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel2 = random_mixed_kernels( + opt['kernel_list2'], + opt['kernel_prob2'], + kernel_size, + opt['blur_sigma2'], + opt['blur_sigma2'], [-math.pi, math.pi], + opt['betag_range2'], + opt['betap_range2'], + noise_range=None) + + # pad kernel + pad_size = (21 - kernel_size) // 2 + kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) + + kernel = torch.FloatTensor(kernel) + kernel2 = torch.FloatTensor(kernel2) + return (kernel, kernel2) + + diff --git a/degradation/degradation_esr.py b/degradation/degradation_esr.py new file mode 100644 index 0000000000000000000000000000000000000000..44b531e81731e742905111cbdba7397899970cd1 --- /dev/null +++ b/degradation/degradation_esr.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +import torch +import os +import sys +import torch.nn.functional as F + +root_path = os.path.abspath('.') +sys.path.append(root_path) +# Import files from the local folder +from opt import opt +from degradation.ESR.utils import generate_kernels, mass_tensor2np, tensor2np +from degradation.ESR.degradations_functionality import * +from degradation.ESR.degradation_esr_shared import common_degradation as regular_common_degradation +from degradation.image_compression.jpeg import JPEG # 这里最好后面用一个继承解决一切 +from degradation.image_compression.webp import WEBP +from degradation.image_compression.heif import HEIF +from degradation.image_compression.avif import AVIF +from degradation.video_compression.h264 import H264 +from degradation.video_compression.h265 import H265 +from degradation.video_compression.mpeg2 import MPEG2 +from degradation.video_compression.mpeg4 import MPEG4 + + +class degradation_v1: + def __init__(self): + self.kernel1, self.kernel2, self.sinc_kernel = None, None, None + self.queue_size = 160 + + # Init the compression instance + self.jpeg_instance = JPEG() + self.webp_instance = WEBP() + # self.heif_instance = HEIF() + self.avif_instance = AVIF() + self.H264_instance = H264() + self.H265_instance = H265() + self.MPEG2_instance = MPEG2() + self.MPEG4_instance = MPEG4() + + + def reset_kernels(self, opt): + kernel1, kernel2 = generate_kernels(opt) + self.kernel1 = kernel1.unsqueeze(0).cuda() + self.kernel2 = kernel2.unsqueeze(0).cuda() + + + @torch.no_grad() + def degradate_process(self, out, opt, store_path, process_id, verbose = False): + ''' ESR Degradation V1 mode (Same as the original paper) + Args: + out (tensor): BxCxHxW All input images as tensor + opt (dict): All configuration we need to process + store_path (str): Store Directory + process_id (int): The id we used to store temporary file + verbose (bool): Whether print some information for auxiliary log (default: False) + ''' + + batch_size, _, ori_h, ori_w = out.size() + + # Shared degradation until the last step + resize_mode = random.choice(opt['resize_options']) + out = regular_common_degradation(out, opt, [self.kernel1, self.kernel2], process_id, verbose=verbose) + + + # Resize back + out = F.interpolate(out, size=(ori_h // opt['scale'], ori_w // opt['scale']), mode = resize_mode) + out = torch.clamp(out, 0, 1) + # TODO: 可能Tensor2Numpy会放在之前,而不是在这里,一起转换节约时间 + + # Tensor2np + np_frame = tensor2np(out) + + # Choose an image compression codec (All degradation batch use the same codec) + compression_codec = random.choices(opt['compression_codec2'], opt['compression_codec_prob2'])[0] # All lower case + + if compression_codec == "jpeg": + 
self.jpeg_instance.compress_and_store(np_frame, store_path, process_id) + + elif compression_codec == "webp": + try: + self.webp_instance.compress_and_store(np_frame, store_path, process_id) + except Exception: + print("There appears to be exception in webp again!") + if os.path.exists(store_path): + os.remove(store_path) + self.webp_instance.compress_and_store(np_frame, store_path, process_id) + + elif compression_codec == "avif": + self.avif_instance.compress_and_store(np_frame, store_path, process_id) + + elif compression_codec == "h264": + self.H264_instance.compress_and_store(np_frame, store_path, process_id) + + elif compression_codec == "h265": + self.H265_instance.compress_and_store(np_frame, store_path, process_id) + + elif compression_codec == "mpeg2": + self.MPEG2_instance.compress_and_store(np_frame, store_path, process_id) + + elif compression_codec == "mpeg4": + self.MPEG4_instance.compress_and_store(np_frame, store_path, process_id) + + else: + raise NotImplementedError("This compression codec is not supported! Please check the implementation!") + + + + + + + diff --git a/degradation/image_compression/avif.py b/degradation/image_compression/avif.py new file mode 100644 index 0000000000000000000000000000000000000000..56737a1465e590877a353aaffed5d0347de61b56 --- /dev/null +++ b/degradation/image_compression/avif.py @@ -0,0 +1,88 @@ +import torch, sys, os, random +import torch.nn.functional as F +import numpy as np +import cv2 +from multiprocessing import Process, Queue +from PIL import Image +import pillow_heif + +root_path = os.path.abspath('.') +sys.path.append(root_path) +# Import files from the local folder +from opt import opt +from degradation.ESR.utils import tensor2np, np2tensor + + + +class AVIF(): + def __init__(self) -> None: + # Choose an image compression degradation + pass + + def compress_and_store(self, np_frames, store_path, idx): + ''' Compress and Store the whole batch as AVIF (~ AV1) + Args: + np_frames (numpy): The numpy format of the data (Shape:?) 
+ store_path (str): The store path + Return: + None + ''' + # Init call for avif + pillow_heif.register_avif_opener() + + + single_frame = np_frames + + # Prepare + essential_name = "tmp/temp_"+str(idx) + + # Choose the quality + quality = random.randint(*opt['avif_quality_range2']) + method = random.randint(*opt['avif_encode_speed2']) + + # Transform to PIL and then compress + PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB') + PIL_image.save(essential_name+'.avif', quality=quality, method=method) + + # Read as png + avif_file = pillow_heif.open_heif(essential_name+'.avif', convert_hdr_to_8bit=False, bgr_mode=True) + np_array = np.asarray(avif_file) + cv2.imwrite(store_path, np_array) + + os.remove(essential_name+'.avif') + + + + @staticmethod + def compress_tensor(tensor_frames, idx=0): + ''' Compress tensor input to AVIF and then return it + Args: + tensor_frame (tensor): Tensor inputs + Returns: + result (tensor): Tensor outputs (same shape as input) + ''' + # Init call for avif + pillow_heif.register_avif_opener() + + # Prepare + single_frame = tensor2np(tensor_frames) + essential_name = "tmp/temp_"+str(idx) + + # Choose the quality + quality = random.randint(*opt['avif_quality_range1']) + method = random.randint(*opt['avif_encode_speed1']) + + # Transform to PIL and then compress + PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB') + PIL_image.save(essential_name+'.avif', quality=quality, method=method) + + # Transform as png format + avif_file = pillow_heif.open_heif(essential_name+'.avif', convert_hdr_to_8bit=False, bgr_mode=True) + decimg = np.asarray(avif_file) + os.remove(essential_name+'.avif') + + # Read back + result = np2tensor(decimg) + + + return result \ No newline at end of file diff --git a/degradation/image_compression/heif.py b/degradation/image_compression/heif.py new file mode 100644 index 0000000000000000000000000000000000000000..74771d8ebc99fabe4ebde2b8f370565cd70c6773 --- /dev/null +++ b/degradation/image_compression/heif.py @@ -0,0 +1,90 @@ +import torch, sys, os, random +import torch.nn.functional as F +import numpy as np +import cv2 +from multiprocessing import Process, Queue +from PIL import Image +from pillow_heif import register_heif_opener +import pillow_heif + +root_path = os.path.abspath('.') +sys.path.append(root_path) +# Import files from the local folder +from opt import opt +from degradation.ESR.utils import tensor2np, np2tensor + + + + +class HEIF(): + def __init__(self) -> None: + # Choose an image compression degradation + pass + + def compress_and_store(self, np_frames, store_path): + ''' Compress and Store the whole batch as HEIF (~ HEVC) + Args: + np_frames (numpy): The numpy format of the data (Shape:?) 
+ store_path (str): The store path + Return: + None + ''' + # Init call for heif + register_heif_opener() + + single_frame = np_frames + + # Prepare + essential_name = store_path.split('.')[0] + + # Choose the quality + quality = random.randint(*opt['heif_quality_range1']) + method = random.randint(*opt['heif_encode_speed1']) + + # Transform to PIL and then compress + PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB') + PIL_image.save(essential_name+'.heic', quality=quality, method=method) + + # Transform as png format + heif_file = pillow_heif.open_heif(essential_name+'.heic', convert_hdr_to_8bit=False, bgr_mode=True) + np_array = np.asarray(heif_file) + cv2.imwrite(store_path, np_array) + + os.remove(essential_name+'.heic') + + + @staticmethod + def compress_tensor(tensor_frames, idx=0): + ''' Compress tensor input to HEIF and then return it + Args: + tensor_frame (tensor): Tensor inputs + Returns: + result (tensor): Tensor outputs (same shape as input) + ''' + + # Init call for heif + register_heif_opener() + + # Prepare + single_frame = tensor2np(tensor_frames) + essential_name = "tmp/temp_"+str(idx) + + # Choose the quality + quality = random.randint(*opt['heif_quality_range1']) + method = random.randint(*opt['heif_encode_speed1']) + + # Transform to PIL and then compress + PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB') + PIL_image.save(essential_name+'.heic', quality=quality, method=method) + + # Transform as png format + heif_file = pillow_heif.open_heif(essential_name+'.heic', convert_hdr_to_8bit=False, bgr_mode=True) + decimg = np.asarray(heif_file) + os.remove(essential_name+'.heic') + + # Read back + result = np2tensor(decimg) + + return result + + diff --git a/degradation/image_compression/jpeg.py b/degradation/image_compression/jpeg.py new file mode 100644 index 0000000000000000000000000000000000000000..86fbe5816dacf200fc69033661853bfb0291d73e --- /dev/null +++ b/degradation/image_compression/jpeg.py @@ -0,0 +1,68 @@ +import sys, os, random +import cv2, torch +from multiprocessing import Process, Queue + +root_path = os.path.abspath('.') +sys.path.append(root_path) +# Import files from the local folder +from opt import opt +from degradation.ESR.utils import tensor2np, np2tensor + + + +class JPEG(): + def __init__(self) -> None: + # Choose an image compression degradation + # self.jpeger = DiffJPEG(differentiable=False).cuda() + pass + + def compress_and_store(self, np_frames, store_path, idx): + ''' Compress and Store the whole batch as JPEG + Args: + np_frames (numpy): The numpy format of the data (Shape:?) 
+ store_path (str): The store path + Return: + None + ''' + + # Preparation + single_frame = np_frames + + # Compress as JPEG + jpeg_quality = random.randint(*opt['jpeg_quality_range2']) + + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality] + _, encimg = cv2.imencode('.jpg', single_frame, encode_param) + decimg = cv2.imdecode(encimg, 1) + + # Store the image with quality + cv2.imwrite(store_path, decimg) + + + + @staticmethod + def compress_tensor(tensor_frames): + ''' Compress tensor input to JPEG and then return it + Args: + tensor_frame (tensor): Tensor inputs + Returns: + result (tensor): Tensor outputs (same shape as input) + ''' + + single_frame = tensor2np(tensor_frames) + + # Compress as JPEG + jpeg_quality = random.randint(*opt['jpeg_quality_range1']) + + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality] + _, encimg = cv2.imencode('.jpg', single_frame, encode_param) + decimg = cv2.imdecode(encimg, 1) + + # Store the image with quality + # cv2.imwrite(store_name, decimg) + result = np2tensor(decimg) + + return result + + + \ No newline at end of file diff --git a/degradation/image_compression/webp.py b/degradation/image_compression/webp.py new file mode 100644 index 0000000000000000000000000000000000000000..ff951e2c136bb4d0227d349a68aa41d8b94024d5 --- /dev/null +++ b/degradation/image_compression/webp.py @@ -0,0 +1,65 @@ +import torch, sys, os, random +import torch.nn.functional as F +import numpy as np +import cv2 +from multiprocessing import Process, Queue +from PIL import Image + +root_path = os.path.abspath('.') +sys.path.append(root_path) +# Import files from the local folder +from opt import opt +from degradation.ESR.utils import tensor2np, np2tensor + + + + +class WEBP(): + def __init__(self) -> None: + # Choose an image compression degradation + pass + + def compress_and_store(self, np_frames, store_path, idx): + ''' Compress and Store the whole batch as WebP (~ VP8) + Args: + np_frames (numpy): The numpy format of the data (Shape:?) 
+ store_path (str): The store path + Return: + None + ''' + single_frame = np_frames + + # Choose the quality + quality = random.randint(*opt['webp_quality_range2']) + method = random.randint(*opt['webp_encode_speed2']) + + # Transform to PIL and then compress + PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB') + PIL_image.save(store_path, 'webp', quality=quality, method=method) + + + @staticmethod + def compress_tensor(tensor_frames, idx = 0): + ''' Compress tensor input to WEBP and then return it + Args: + tensor_frame (tensor): Tensor inputs + Returns: + result (tensor): Tensor outputs (same shape as input) + ''' + single_frame = tensor2np(tensor_frames) + + # Choose the quality + quality = random.randint(*opt['webp_quality_range1']) + method = random.randint(*opt['webp_encode_speed1']) + + # Transform to PIL and then compress + PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB') + store_path = os.path.join("tmp", "temp_"+str(idx)+".webp") + PIL_image.save(store_path, 'webp', quality=quality, method=method) + + # Read back + decimg = cv2.imread(store_path) + result = np2tensor(decimg) + os.remove(store_path) + + return result \ No newline at end of file diff --git a/degradation/video_compression/h264.py b/degradation/video_compression/h264.py new file mode 100644 index 0000000000000000000000000000000000000000..a5ab74073fb9b364d7236c47a6727dd43c11cbeb --- /dev/null +++ b/degradation/video_compression/h264.py @@ -0,0 +1,73 @@ +import torch, sys, os, random +import cv2 +import shutil + +root_path = os.path.abspath('.') +sys.path.append(root_path) +# Import files from the local folder +from opt import opt + + + +class H264(): + def __init__(self) -> None: + # Choose an image compression degradation + pass + + def compress_and_store(self, single_frame, store_path, idx): + ''' Compress and Store the whole batch as H.264 (for 2nd stage) + Args: + single_frame (numpy): The numpy format of the data (Shape:?) 
+ store_path (str): The store path + idx (int): A unique process idx + Return: + None + ''' + + # Prepare + temp_input_path = "tmp/input_"+str(idx) + video_store_dir = "tmp/encoded_"+str(idx)+".mp4" + temp_store_path = "tmp/output_"+str(idx) + os.makedirs(temp_input_path) + os.makedirs(temp_store_path) + + # Move frame + cv2.imwrite(os.path.join(temp_input_path, "1.png"), single_frame) + + + # Decide the quality + crf = str(random.randint(*opt['h264_crf_range2'])) + preset = random.choices(opt['h264_preset_mode2'], opt['h264_preset_prob2'])[0] + + # Encode + ffmpeg_encode_cmd = "ffmpeg -i " + temp_input_path + "/%d.png -vcodec libx264 -crf " + crf + " -preset " + preset + " -pix_fmt yuv420p " + video_store_dir + " -loglevel 0" + os.system(ffmpeg_encode_cmd) + + + # Decode + ffmpeg_decode_cmd = "ffmpeg -i " + video_store_dir + " " + temp_store_path + "/%d.png -loglevel 0" + os.system(ffmpeg_decode_cmd) + if len(os.listdir(temp_store_path)) != 1: + print("This is strange") + assert(len(os.listdir(temp_store_path)) == 1) + + # Move frame to the target places + shutil.copy(os.path.join(temp_store_path, "1.png"), store_path) + + # Clean temp files + os.remove(video_store_dir) + shutil.rmtree(temp_input_path) + shutil.rmtree(temp_store_path) + + + + @staticmethod + def compress_tensor(tensor_frames, idx=0): + ''' Compress tensor input to H.264 and then return it (for 1st stage) + Args: + tensor_frame (tensor): Tensor inputs + Returns: + result (tensor): Tensor outputs (same shape as input) + ''' + + pass \ No newline at end of file diff --git a/degradation/video_compression/h265.py b/degradation/video_compression/h265.py new file mode 100644 index 0000000000000000000000000000000000000000..d68e0526a6947134331639c2eee54da1fef8fca6 --- /dev/null +++ b/degradation/video_compression/h265.py @@ -0,0 +1,71 @@ +import torch, sys, os, random +import cv2 +import shutil + +root_path = os.path.abspath('.') +sys.path.append(root_path) +# Import files from the local folder +from opt import opt + + + +class H265(): + def __init__(self) -> None: + # Choose an image compression degradation + pass + + def compress_and_store(self, single_frame, store_path, idx): + ''' Compress and Store the whole batch as H.265 (for 2nd stage) + Args: + single_frame (numpy): The numpy format of the data (Shape:?) 
+ store_path (str): The store path + idx (int): A unique process idx + Return: + None + ''' + + # Prepare + temp_input_path = "tmp/input_"+str(idx) + video_store_dir = "tmp/encoded_"+str(idx)+".mp4" + temp_store_path = "tmp/output_"+str(idx) + os.makedirs(temp_input_path) + os.makedirs(temp_store_path) + + # Move frame + cv2.imwrite(os.path.join(temp_input_path, "1.png"), single_frame) + + + # Decide the quality + crf = str(random.randint(*opt['h265_crf_range2'])) + preset = random.choices(opt['h265_preset_mode2'], opt['h265_preset_prob2'])[0] + + # Encode + ffmpeg_encode_cmd = "ffmpeg -i " + temp_input_path + "/%d.png -vcodec libx265 -x265-params log-level=error -crf " + crf + " -preset " + preset + " -pix_fmt yuv420p " + video_store_dir + " -loglevel 0" + os.system(ffmpeg_encode_cmd) + + + # Decode + ffmpeg_decode_cmd = "ffmpeg -i " + video_store_dir + " " + temp_store_path + "/%d.png -loglevel 0" + os.system(ffmpeg_decode_cmd) + assert(len(os.listdir(temp_store_path)) == 1) + + # Move frame to the target places + shutil.copy(os.path.join(temp_store_path, "1.png"), store_path) + + # Clean temp files + os.remove(video_store_dir) + shutil.rmtree(temp_input_path) + shutil.rmtree(temp_store_path) + + + + @staticmethod + def compress_tensor(tensor_frames, idx=0): + ''' Compress tensor input to H.265 and then return it (for 1st stage) + Args: + tensor_frame (tensor): Tensor inputs + Returns: + result (tensor): Tensor outputs (same shape as input) + ''' + + pass \ No newline at end of file diff --git a/degradation/video_compression/mpeg2.py b/degradation/video_compression/mpeg2.py new file mode 100644 index 0000000000000000000000000000000000000000..abdcfc6b0d1c499503179133267ed12009c9cb9a --- /dev/null +++ b/degradation/video_compression/mpeg2.py @@ -0,0 +1,71 @@ +import torch, sys, os, random +import cv2 +import shutil + +root_path = os.path.abspath('.') +sys.path.append(root_path) +# Import files from the local folder +from opt import opt + + + +class MPEG2(): + def __init__(self) -> None: + # Choose an image compression degradation + pass + + def compress_and_store(self, single_frame, store_path, idx): + ''' Compress and Store the whole batch as MPEG-2 (for 2nd stage) + Args: + single_frame (numpy): The numpy format of the data (Shape:?) 
+ store_path (str): The store path + idx (int): A unique process idx + Return: + None + ''' + + # Prepare + temp_input_path = "tmp/input_"+str(idx) + video_store_dir = "tmp/encoded_"+str(idx)+".mp4" + temp_store_path = "tmp/output_"+str(idx) + os.makedirs(temp_input_path) + os.makedirs(temp_store_path) + + # Move frame + cv2.imwrite(os.path.join(temp_input_path, "1.png"), single_frame) + + + # Decide the quality + quality = str(random.randint(*opt['mpeg2_quality2'])) + preset = random.choices(opt['mpeg2_preset_mode2'], opt['mpeg2_preset_prob2'])[0] + + # Encode + ffmpeg_encode_cmd = "ffmpeg -i " + temp_input_path + "/%d.png -vcodec mpeg2video -qscale:v " + quality + " -preset " + preset + " -pix_fmt yuv420p " + video_store_dir + " -loglevel 0" + os.system(ffmpeg_encode_cmd) + + + # Decode + ffmpeg_decode_cmd = "ffmpeg -i " + video_store_dir + " " + temp_store_path + "/%d.png -loglevel 0" + os.system(ffmpeg_decode_cmd) + assert(len(os.listdir(temp_store_path)) == 1) + + # Move frame to the target places + shutil.copy(os.path.join(temp_store_path, "1.png"), store_path) + + # Clean temp files + os.remove(video_store_dir) + shutil.rmtree(temp_input_path) + shutil.rmtree(temp_store_path) + + + + @staticmethod + def compress_tensor(tensor_frames, idx=0): + ''' Compress tensor input to H.264 and then return it (for 1st stage) + Args: + tensor_frame (tensor): Tensor inputs + Returns: + result (tensor): Tensor outputs (same shape as input) + ''' + + pass \ No newline at end of file diff --git a/degradation/video_compression/mpeg4.py b/degradation/video_compression/mpeg4.py new file mode 100644 index 0000000000000000000000000000000000000000..517b72b773f6ce9611ed3388ad502ddd5da61632 --- /dev/null +++ b/degradation/video_compression/mpeg4.py @@ -0,0 +1,71 @@ +import torch, sys, os, random +import cv2 +import shutil + +root_path = os.path.abspath('.') +sys.path.append(root_path) +# Import files from the local folder +from opt import opt + + + +class MPEG4(): + def __init__(self) -> None: + # Choose an image compression degradation + pass + + def compress_and_store(self, single_frame, store_path, idx): + ''' Compress and Store the whole batch as MPEG-4 (for 2nd stage) + Args: + single_frame (numpy): The numpy format of the data (Shape:?) 
+ store_path (str): The store path + idx (int): A unique process idx + Return: + None + ''' + + # Prepare + temp_input_path = "tmp/input_"+str(idx) + video_store_dir = "tmp/encoded_"+str(idx)+".mp4" + temp_store_path = "tmp/output_"+str(idx) + os.makedirs(temp_input_path) + os.makedirs(temp_store_path) + + # Move frame + cv2.imwrite(os.path.join(temp_input_path, "1.png"), single_frame) + + + # Decide the quality + quality = str(random.randint(*opt['mpeg4_quality2'])) + preset = random.choices(opt['mpeg4_preset_mode2'], opt['mpeg4_preset_prob2'])[0] + + # Encode + ffmpeg_encode_cmd = "ffmpeg -i " + temp_input_path + "/%d.png -vcodec libxvid -qscale:v " + quality + " -preset " + preset + " -pix_fmt yuv420p " + video_store_dir + " -loglevel 0" + os.system(ffmpeg_encode_cmd) + + + # Decode + ffmpeg_decode_cmd = "ffmpeg -i " + video_store_dir + " " + temp_store_path + "/%d.png -loglevel 0" + os.system(ffmpeg_decode_cmd) + assert(len(os.listdir(temp_store_path)) == 1) + + # Move frame to the target places + shutil.copy(os.path.join(temp_store_path, "1.png"), store_path) + + # Clean temp files + os.remove(video_store_dir) + shutil.rmtree(temp_input_path) + shutil.rmtree(temp_store_path) + + + + @staticmethod + def compress_tensor(tensor_frames, idx=0): + ''' Compress tensor input to MPEG4 and then return it (for 1st stage) + Args: + tensor_frame (tensor): Tensor inputs + Returns: + result (tensor): Tensor outputs (same shape as input) + ''' + + pass \ No newline at end of file diff --git a/docs/model_zoo.md b/docs/model_zoo.md new file mode 100644 index 0000000000000000000000000000000000000000..433b25babe4c7ae1412dd3fb716e0e22a2a2ae76 --- /dev/null +++ b/docs/model_zoo.md @@ -0,0 +1,24 @@ +# :european_castle: Model Zoo + +- [For Paper weight](#for-paper-weight) +- [For Diverse Upscaler](#for-diverse-upscaler) + + + +## For Paper Weight + +| Models | Scale | Description | +| ------------------------------------------------------------------------------------------------------------------------------- | :---- | :------------------------------------------- | +| [4x_APISR_GRL_GAN_generator](https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth) | 4X | 4X GRL model used in the paper | + + +## For Diverse Upscaler + +Actually, I am not that much like GRL. Though they can have the smallest param size with higher numerical results, they are not very memory efficient and the processing speed is slow for Transformer model. One more concern come from the TensorRT deployment, where Transformer architecture is hard to be adapted (needless to say for a modified version of Transformer like GRL). + +Thus, for other weights, I will not train a GRL network and also real-world SR of GRL only supports 4x. 
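+As a quick sanity check after downloading any of the weights listed here, the snippet below is a minimal sketch (not this repository's own loader) that assumes the `.pth` file is an ordinary PyTorch checkpoint; confirm the printed keys against the generator class you instantiate before relying on it.
+
+```python
+import torch
+
+# Minimal sketch: open a downloaded generator weight and inspect its contents.
+# The key layout ("model_state_dict" vs. a bare state_dict) is an assumption,
+# not something guaranteed by this repo.
+ckpt = torch.load("pretrained/2x_APISR_RRDB_GAN_generator.pth", map_location="cpu")
+state_dict = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
+print(len(state_dict), list(state_dict.keys())[:5])
+```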
+ + +| Models | Scale | Description | +| ------------------------------------------------------------------------------------------------------------------------------- | :---- | :------------------------------------------- | +| [2x_APISR_RRDB_GAN_generator](https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth) | 2X | 2X upscaler by RRDB-6blocks | \ No newline at end of file diff --git a/loss/anime_perceptual_loss.py b/loss/anime_perceptual_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..592c46625557e80aa42959e543df215aa1b3ce92 --- /dev/null +++ b/loss/anime_perceptual_loss.py @@ -0,0 +1,497 @@ +import os, sys +from collections import OrderedDict +import cv2 +import torch.nn as nn +import torch +from torchvision import models +import torchvision.transforms as transforms + +''' +---------------------------------------------------------------- + Layer (type) Output Shape Param # +================================================================ + Conv2d-1 [-1, 64, 112, 112] 9,408 + BatchNorm2d-2 [-1, 64, 112, 112] 128 + ReLU-3 [-1, 64, 112, 112] 0 + MaxPool2d-4 [-1, 64, 56, 56] 0 + Conv2d-5 [-1, 64, 56, 56] 4,096 + BatchNorm2d-6 [-1, 64, 56, 56] 128 + ReLU-7 [-1, 64, 56, 56] 0 + Conv2d-8 [-1, 64, 56, 56] 36,864 + BatchNorm2d-9 [-1, 64, 56, 56] 128 + ReLU-10 [-1, 64, 56, 56] 0 + Conv2d-11 [-1, 256, 56, 56] 16,384 + BatchNorm2d-12 [-1, 256, 56, 56] 512 + Conv2d-13 [-1, 256, 56, 56] 16,384 + BatchNorm2d-14 [-1, 256, 56, 56] 512 + ReLU-15 [-1, 256, 56, 56] 0 + Bottleneck-16 [-1, 256, 56, 56] 0 + Conv2d-17 [-1, 64, 56, 56] 16,384 + BatchNorm2d-18 [-1, 64, 56, 56] 128 + ReLU-19 [-1, 64, 56, 56] 0 + Conv2d-20 [-1, 64, 56, 56] 36,864 + BatchNorm2d-21 [-1, 64, 56, 56] 128 + ReLU-22 [-1, 64, 56, 56] 0 + Conv2d-23 [-1, 256, 56, 56] 16,384 + BatchNorm2d-24 [-1, 256, 56, 56] 512 + ReLU-25 [-1, 256, 56, 56] 0 + Bottleneck-26 [-1, 256, 56, 56] 0 + Conv2d-27 [-1, 64, 56, 56] 16,384 + BatchNorm2d-28 [-1, 64, 56, 56] 128 + ReLU-29 [-1, 64, 56, 56] 0 + Conv2d-30 [-1, 64, 56, 56] 36,864 + BatchNorm2d-31 [-1, 64, 56, 56] 128 + ReLU-32 [-1, 64, 56, 56] 0 + Conv2d-33 [-1, 256, 56, 56] 16,384 + BatchNorm2d-34 [-1, 256, 56, 56] 512 + ReLU-35 [-1, 256, 56, 56] 0 + Bottleneck-36 [-1, 256, 56, 56] 0 + Conv2d-37 [-1, 128, 56, 56] 32,768 + BatchNorm2d-38 [-1, 128, 56, 56] 256 + ReLU-39 [-1, 128, 56, 56] 0 + Conv2d-40 [-1, 128, 28, 28] 147,456 + BatchNorm2d-41 [-1, 128, 28, 28] 256 + ReLU-42 [-1, 128, 28, 28] 0 + Conv2d-43 [-1, 512, 28, 28] 65,536 + BatchNorm2d-44 [-1, 512, 28, 28] 1,024 + Conv2d-45 [-1, 512, 28, 28] 131,072 + BatchNorm2d-46 [-1, 512, 28, 28] 1,024 + ReLU-47 [-1, 512, 28, 28] 0 + Bottleneck-48 [-1, 512, 28, 28] 0 + Conv2d-49 [-1, 128, 28, 28] 65,536 + BatchNorm2d-50 [-1, 128, 28, 28] 256 + ReLU-51 [-1, 128, 28, 28] 0 + Conv2d-52 [-1, 128, 28, 28] 147,456 + BatchNorm2d-53 [-1, 128, 28, 28] 256 + ReLU-54 [-1, 128, 28, 28] 0 + Conv2d-55 [-1, 512, 28, 28] 65,536 + BatchNorm2d-56 [-1, 512, 28, 28] 1,024 + ReLU-57 [-1, 512, 28, 28] 0 + Bottleneck-58 [-1, 512, 28, 28] 0 + Conv2d-59 [-1, 128, 28, 28] 65,536 + BatchNorm2d-60 [-1, 128, 28, 28] 256 + ReLU-61 [-1, 128, 28, 28] 0 + Conv2d-62 [-1, 128, 28, 28] 147,456 + BatchNorm2d-63 [-1, 128, 28, 28] 256 + ReLU-64 [-1, 128, 28, 28] 0 + Conv2d-65 [-1, 512, 28, 28] 65,536 + BatchNorm2d-66 [-1, 512, 28, 28] 1,024 + ReLU-67 [-1, 512, 28, 28] 0 + Bottleneck-68 [-1, 512, 28, 28] 0 + Conv2d-69 [-1, 128, 28, 28] 65,536 + BatchNorm2d-70 [-1, 128, 28, 28] 256 + ReLU-71 [-1, 128, 28, 28] 0 + Conv2d-72 [-1, 
128, 28, 28] 147,456 + BatchNorm2d-73 [-1, 128, 28, 28] 256 + ReLU-74 [-1, 128, 28, 28] 0 + Conv2d-75 [-1, 512, 28, 28] 65,536 + BatchNorm2d-76 [-1, 512, 28, 28] 1,024 + ReLU-77 [-1, 512, 28, 28] 0 + Bottleneck-78 [-1, 512, 28, 28] 0 + Conv2d-79 [-1, 256, 28, 28] 131,072 + BatchNorm2d-80 [-1, 256, 28, 28] 512 + ReLU-81 [-1, 256, 28, 28] 0 + Conv2d-82 [-1, 256, 14, 14] 589,824 + BatchNorm2d-83 [-1, 256, 14, 14] 512 + ReLU-84 [-1, 256, 14, 14] 0 + Conv2d-85 [-1, 1024, 14, 14] 262,144 + BatchNorm2d-86 [-1, 1024, 14, 14] 2,048 + Conv2d-87 [-1, 1024, 14, 14] 524,288 + BatchNorm2d-88 [-1, 1024, 14, 14] 2,048 + ReLU-89 [-1, 1024, 14, 14] 0 + Bottleneck-90 [-1, 1024, 14, 14] 0 + Conv2d-91 [-1, 256, 14, 14] 262,144 + BatchNorm2d-92 [-1, 256, 14, 14] 512 + ReLU-93 [-1, 256, 14, 14] 0 + Conv2d-94 [-1, 256, 14, 14] 589,824 + BatchNorm2d-95 [-1, 256, 14, 14] 512 + ReLU-96 [-1, 256, 14, 14] 0 + Conv2d-97 [-1, 1024, 14, 14] 262,144 + BatchNorm2d-98 [-1, 1024, 14, 14] 2,048 + ReLU-99 [-1, 1024, 14, 14] 0 + Bottleneck-100 [-1, 1024, 14, 14] 0 + Conv2d-101 [-1, 256, 14, 14] 262,144 + BatchNorm2d-102 [-1, 256, 14, 14] 512 + ReLU-103 [-1, 256, 14, 14] 0 + Conv2d-104 [-1, 256, 14, 14] 589,824 + BatchNorm2d-105 [-1, 256, 14, 14] 512 + ReLU-106 [-1, 256, 14, 14] 0 + Conv2d-107 [-1, 1024, 14, 14] 262,144 + BatchNorm2d-108 [-1, 1024, 14, 14] 2,048 + ReLU-109 [-1, 1024, 14, 14] 0 + Bottleneck-110 [-1, 1024, 14, 14] 0 + Conv2d-111 [-1, 256, 14, 14] 262,144 + BatchNorm2d-112 [-1, 256, 14, 14] 512 + ReLU-113 [-1, 256, 14, 14] 0 + Conv2d-114 [-1, 256, 14, 14] 589,824 + BatchNorm2d-115 [-1, 256, 14, 14] 512 + ReLU-116 [-1, 256, 14, 14] 0 + Conv2d-117 [-1, 1024, 14, 14] 262,144 + BatchNorm2d-118 [-1, 1024, 14, 14] 2,048 + ReLU-119 [-1, 1024, 14, 14] 0 + Bottleneck-120 [-1, 1024, 14, 14] 0 + Conv2d-121 [-1, 256, 14, 14] 262,144 + BatchNorm2d-122 [-1, 256, 14, 14] 512 + ReLU-123 [-1, 256, 14, 14] 0 + Conv2d-124 [-1, 256, 14, 14] 589,824 + BatchNorm2d-125 [-1, 256, 14, 14] 512 + ReLU-126 [-1, 256, 14, 14] 0 + Conv2d-127 [-1, 1024, 14, 14] 262,144 + BatchNorm2d-128 [-1, 1024, 14, 14] 2,048 + ReLU-129 [-1, 1024, 14, 14] 0 + Bottleneck-130 [-1, 1024, 14, 14] 0 + Conv2d-131 [-1, 256, 14, 14] 262,144 + BatchNorm2d-132 [-1, 256, 14, 14] 512 + ReLU-133 [-1, 256, 14, 14] 0 + Conv2d-134 [-1, 256, 14, 14] 589,824 + BatchNorm2d-135 [-1, 256, 14, 14] 512 + ReLU-136 [-1, 256, 14, 14] 0 + Conv2d-137 [-1, 1024, 14, 14] 262,144 + BatchNorm2d-138 [-1, 1024, 14, 14] 2,048 + ReLU-139 [-1, 1024, 14, 14] 0 + Bottleneck-140 [-1, 1024, 14, 14] 0 + Conv2d-141 [-1, 512, 14, 14] 524,288 + BatchNorm2d-142 [-1, 512, 14, 14] 1,024 + ReLU-143 [-1, 512, 14, 14] 0 + Conv2d-144 [-1, 512, 7, 7] 2,359,296 + BatchNorm2d-145 [-1, 512, 7, 7] 1,024 + ReLU-146 [-1, 512, 7, 7] 0 + Conv2d-147 [-1, 2048, 7, 7] 1,048,576 + BatchNorm2d-148 [-1, 2048, 7, 7] 4,096 + Conv2d-149 [-1, 2048, 7, 7] 2,097,152 + BatchNorm2d-150 [-1, 2048, 7, 7] 4,096 + ReLU-151 [-1, 2048, 7, 7] 0 + Bottleneck-152 [-1, 2048, 7, 7] 0 + Conv2d-153 [-1, 512, 7, 7] 1,048,576 + BatchNorm2d-154 [-1, 512, 7, 7] 1,024 + ReLU-155 [-1, 512, 7, 7] 0 + Conv2d-156 [-1, 512, 7, 7] 2,359,296 + BatchNorm2d-157 [-1, 512, 7, 7] 1,024 + ReLU-158 [-1, 512, 7, 7] 0 + Conv2d-159 [-1, 2048, 7, 7] 1,048,576 + BatchNorm2d-160 [-1, 2048, 7, 7] 4,096 + ReLU-161 [-1, 2048, 7, 7] 0 + Bottleneck-162 [-1, 2048, 7, 7] 0 + Conv2d-163 [-1, 512, 7, 7] 1,048,576 + BatchNorm2d-164 [-1, 512, 7, 7] 1,024 + ReLU-165 [-1, 512, 7, 7] 0 + Conv2d-166 [-1, 512, 7, 7] 2,359,296 + BatchNorm2d-167 [-1, 512, 7, 7] 1,024 + ReLU-168 [-1, 512, 
7, 7] 0 + Conv2d-169 [-1, 2048, 7, 7] 1,048,576 + BatchNorm2d-170 [-1, 2048, 7, 7] 4,096 + ReLU-171 [-1, 2048, 7, 7] 0 + Bottleneck-172 [-1, 2048, 7, 7] 0 +AdaptiveMaxPool2d-173 [-1, 2048, 1, 1] 0 +AdaptiveAvgPool2d-174 [-1, 2048, 1, 1] 0 +AdaptiveConcatPool2d-175 [-1, 4096, 1, 1] 0 + Flatten-176 [-1, 4096] 0 + BatchNorm1d-177 [-1, 4096] 8,192 + Dropout-178 [-1, 4096] 0 + Linear-179 [-1, 512] 2,097,664 + ReLU-180 [-1, 512] 0 + BatchNorm1d-181 [-1, 512] 1,024 + Dropout-182 [-1, 512] 0 + Linear-183 [-1, 6000] 3,078,000 +================================================================ +Total params: 28,692,912 +Trainable params: 28,692,912 +Non-trainable params: 0 +---------------------------------------------------------------- +Input size (MB): 0.57 +Forward/backward pass size (MB): 286.75 +Params size (MB): 109.45 +Estimated Total Size (MB): 396.78 +---------------------------------------------------------------- +''' + + +class AdaptiveConcatPool2d(nn.Module): + """ + Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`. + Source: Fastai. This code was taken from the fastai library at url + https://github.com/fastai/fastai/blob/master/fastai/layers.py#L176 + """ + def __init__(self, sz=None): + "Output will be 2*sz or 2 if sz is None" + super().__init__() + self.output_size = sz or 1 + self.ap = nn.AdaptiveAvgPool2d(self.output_size) + self.mp = nn.AdaptiveMaxPool2d(self.output_size) + + def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1) + + +class Flatten(nn.Module): + """ + Flatten `x` to a single dimension. Adapted from fastai's Flatten() layer, + at https://github.com/fastai/fastai/blob/master/fastai/layers.py#L25 + """ + def __init__(self): super().__init__() + def forward(self, x): return x.view(x.size(0), -1) + + +def bn_drop_lin(n_in:int, n_out:int, bn:bool=True, p:float=0., actn=None): + """ + Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`,`n_out`) layers followed by `actn`. + Adapted from Fastai at https://github.com/fastai/fastai/blob/master/fastai/layers.py#L44 + """ + layers = [nn.BatchNorm1d(n_in)] if bn else [] + if p != 0: layers.append(nn.Dropout(p)) + layers.append(nn.Linear(n_in, n_out)) + if actn is not None: layers.append(actn) + return layers + +def create_head(top_n_tags, nf, ps=0.5): + nc = top_n_tags + + lin_ftrs = [nf, 512, nc] + p1 = 0.25 # dropout for second last layer + p2 = 0.5 # dropout for last layer + + actns = [nn.ReLU(inplace=True),] + [None] + pool = AdaptiveConcatPool2d() + layers = [pool, Flatten()] + + layers += [ + *bn_drop_lin(lin_ftrs[0], lin_ftrs[1], True, p1, nn.ReLU(inplace=True)), + *bn_drop_lin(lin_ftrs[1], lin_ftrs[2], True, p2) + ] + + return nn.Sequential(*layers) + + +def _resnet(base_arch, top_n, **kwargs): + cut = -2 + s = base_arch(pretrained=False, **kwargs) + body = nn.Sequential(*list(s.children())[:cut]) + + if base_arch in [models.resnet18, models.resnet34]: + num_features_model = 512 + elif base_arch in [models.resnet50, models.resnet101]: + num_features_model = 2048 + + nf = num_features_model * 2 + nc = top_n + + # head = create_head(nc, nf) + model = body # nn.Sequential(body, head) + + return model + + +def resnet50(pretrained=True, progress=True, top_n=6000, **kwargs): + r""" + Resnet50 model trained on the full Danbooru2018 dataset's top 6000 tags + + Args: + pretrained (bool): kwargs, load pretrained weights into the model. + top_n (int): kwargs, pick to load the model for predicting the top `n` tags, + currently only supports top_n=6000. 
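+
+    Note:
+        Only the convolutional body is kept here (the classifier head is dropped in
+        `_resnet` and the head weights are discarded when the state dict is loaded),
+        so the returned model outputs a feature map rather than tag predictions.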
+ """ + model = _resnet(models.resnet50, top_n, **kwargs) # Take Resnet without the head (we don't care about final FC layers) + + if pretrained: + if top_n == 6000: + state = torch.hub.load_state_dict_from_url("https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet50-13306192.pth", + progress=progress) + old_keys = [key for key in state] + for old_key in old_keys: + if old_key[0] == '0': + new_key = old_key[2:] + state[new_key] = state[old_key] + del state[old_key] + elif old_key[0] == '1': + del state[old_key] + + model.load_state_dict(state) + else: + raise ValueError("Sorry, the resnet50 model only supports the top-6000 tags \ + at the moment") + + + return model + + + + +class resnet50_Extractor(nn.Module): + """ResNet50 network for feature extraction. + """ + def get_activation(self, name): + def hook(model, input, output): + self.activation[name] = output.detach() + return hook + + + def __init__(self, + model, + layer_labels, + use_input_norm=True, + range_norm=False, + requires_grad=False + ): + super(resnet50_Extractor, self).__init__() + + + self.model = model + self.use_input_norm = use_input_norm + self.range_norm = range_norm + self.layer_labels = layer_labels + self.activation = {} + + + # Extract needed features + for layer_label in layer_labels: + elements = layer_label.split('_') + if len(elements) == 1: + # modified_net[layer_label] = getattr(model, elements[0]) + getattr(self.model, elements[0]).register_forward_hook(self.get_activation(layer_label)) + else: + body_layer = self.model + for element in elements[:-1]: + # Iterate until the last element + assert(isinstance(int(element), int)) + body_layer = body_layer[int(element)] + getattr(body_layer, elements[-1]).register_forward_hook(self.get_activation(layer_label)) + + + # Set as evaluation + if not requires_grad: + self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + + # Execute model first + output = self.model(x) # Zomby input + + # Extract the layers we need + store = {} + for layer_label in self.layer_labels: + store[layer_label] = self.activation[layer_label] + + + return store + + +class Anime_PerceptualLoss(nn.Module): + """Anime Perceptual loss + + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculating losses. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. 
+ """ + + def __init__(self, + layer_weights, + perceptual_weight=1.0, + criterion='l1'): + super(Anime_PerceptualLoss, self).__init__() + + + model = resnet50() + self.perceptual_weight = perceptual_weight + self.layer_weights = layer_weights + self.layer_labels = layer_weights.keys() + self.resnet50 = resnet50_Extractor(model, self.layer_labels).cuda() + + if criterion == 'l1': + self.criterion = torch.nn.L1Loss() + else: + raise NotImplementedError("We don't support such criterion loss in perceptual loss") + + + def forward(self, gen, gt): + """Forward function. + + Args: + gen (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + # extract vgg features + gen_features = self.resnet50(gen) + gt_features = self.resnet50(gt.detach()) + + + temp_store = [] + + # calculate perceptual loss + if self.perceptual_weight > 0: + percep_loss = 0 + for idx, k in enumerate(gen_features.keys()): + raw_comparison = self.criterion(gen_features[k], gt_features[k]) + percep_loss += raw_comparison * self.layer_weights[k] + + # print("layer" + str(idx) + " has loss " + str(raw_comparison.cpu().numpy())) + # temp_store.append(float(raw_comparison.cpu().numpy())) + + percep_loss *= self.perceptual_weight + else: + percep_loss = None + + # 第一个是为了Debug purpose + if len(temp_store) != 0: + return temp_store, percep_loss + else: + return percep_loss + + + + +if __name__ == "__main__": + import torchvision.transforms as transforms + import cv2 + import collections + + + loss = Anime_PerceptualLoss({"0": 0.5, "4_2_conv3": 20, "5_3_conv3": 30, "6_5_conv3": 1, "7_2_conv3": 1}).cuda() + + + store = collections.defaultdict(list) + for img_name in sorted(os.listdir('datasets/train_gen/')): + gen = transforms.ToTensor()(cv2.imread('datasets/train_gen/'+img_name)).cuda() + gt = transforms.ToTensor()(cv2.imread('datasets/train_hr_anime_usm/'+img_name)).cuda() + temp_store, _ = loss(gen, gt) + + for idx in range(len(temp_store)): + store[idx].append(temp_store[idx]) + + for idx in range(len(store)): + print("Average layer" + str(idx) + " has loss " + str(sum(store[idx]) / len(store[idx]))) + + + # model = loss.vgg + # pytorch_total_params = sum(p.numel() for p in model.parameters()) + # print(f"Perceptual VGG has param {pytorch_total_params//1000000} M params") \ No newline at end of file diff --git a/loss/gan_loss.py b/loss/gan_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..2d13f4cbc008f3e6ab2ec4755ab29f8fdcf836ad --- /dev/null +++ b/loss/gan_loss.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- + +import math +import torch +from torch import autograd as autograd +from torch import nn as nn +from torch.nn import functional as F +import cv2 +import numpy as np +import os, sys + +root_path = os.path.abspath('.') +sys.path.append(root_path) + +from loss.perceptual_loss import VGGFeatureExtractor +from degradation.ESR.utils import np2tensor, tensor2np, save_img + +class GANLoss(nn.Module): + """Define GAN loss. + From Real-ESRGAN code + Args: + gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. + real_label_val (float): The value for real label. Default: 1.0. + fake_label_val (float): The value for fake label. Default: 0.0. + loss_weight (float): Loss weight. Default: 1.0. + Note that loss_weight is only for generators; and it is always 1.0 + for discriminators. 
+ """ + + def __init__(self, gan_type="vanilla", real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(GANLoss, self).__init__() + self.loss_weight = loss_weight + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + # gan type is vanilla usually + if gan_type == "vanilla": + self.loss = nn.BCEWithLogitsLoss() + elif gan_type == "lsgan": + self.loss = nn.MSELoss() + else: + raise NotImplementedError("We didn't implement this GAN type") + + + # Skip wgan part here + + + def get_target_label(self, input, target_is_real): + """Get target label. + + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + + target_val = (self.real_label_val if target_is_real else self.fake_label_val) + return input.new_ones(input.size()) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. + target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. + """ + target_label = self.get_target_label(input, target_is_real) + + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight + + +class MultiScaleGANLoss(GANLoss): + """ + MultiScaleGANLoss accepts a list of predictions + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight) + + def forward(self, input, target_is_real, is_disc=False): + """ + The input is a list of tensors, or a list of (a list of tensors) + """ + if isinstance(input, list): + loss = 0 + for pred_i in input: + if isinstance(pred_i, list): + # Only compute GAN loss for the last layer + # in case of multiscale feature matching + pred_i = pred_i[-1] + # Safe operation: 0-dim tensor calling self.mean() does nothing + loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean() + loss += loss_tensor + return loss / len(input) + else: + return super().forward(input, target_is_real, is_disc) \ No newline at end of file diff --git a/loss/perceptual_loss.py b/loss/perceptual_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..eece8a6f4df176f3bc6f123b2e9fac1f4f306adb --- /dev/null +++ b/loss/perceptual_loss.py @@ -0,0 +1,262 @@ +# -*- coding: utf-8 -*- + +import os +import torch +from collections import OrderedDict +from torch import nn as nn +from torchvision.models import vgg as vgg + + +NAMES = { + 'vgg11': [ + 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', + 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', + 'pool5' + ], + 'vgg13': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' + ], + 'vgg16': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 
'conv4_2', + 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', + 'pool5' + ], + 'vgg19': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', + 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', + 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' + ] +} + +def insert_bn(names): + """Insert bn layer after each conv. + + Args: + names (list): The list of layer names. + + Returns: + list: The list of layer names with bn layers. + """ + names_bn = [] + for name in names: + names_bn.append(name) + if 'conv' in name: + position = name.replace('conv', '') + names_bn.append('bn' + position) + return names_bn + + +class VGGFeatureExtractor(nn.Module): + """VGG network for feature extraction. + + In this implementation, we allow users to choose whether use normalization + in the input feature and the type of vgg network. Note that the pretrained + path must fit the vgg type. + + Args: + layer_name_list (list[str]): Forward function returns the corresponding + features according to the layer_name_list. + Example: {'relu1_1', 'relu2_1', 'relu3_1'}. + vgg_type (str): Set the type of vgg network. Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image. Importantly, + the input feature must in the range [0, 1]. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + requires_grad (bool): If true, the parameters of VGG network will be + optimized. Default: False. + remove_pooling (bool): If true, the max pooling operations in VGG net + will be removed. Default: False. + pooling_stride (int): The stride of max pooling operation. Default: 2. 
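+
+        Note:
+            A local copy of the torchvision VGG weights is loaded from the pre-trained
+            path when it exists; otherwise the weights are downloaded through
+            torchvision's `pretrained=True`.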
+ """ + + def __init__(self, + layer_name_list, + vgg_type, + use_input_norm=True, + range_norm=False, + requires_grad=False, + remove_pooling=False, + pooling_stride=2): + super(VGGFeatureExtractor, self).__init__() + + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + self.names = NAMES[vgg_type.replace('_bn', '')] + if 'bn' in vgg_type: + self.names = insert_bn(self.names) + + # only borrow layers that will be used to avoid unused params + max_idx = 0 + for v in layer_name_list: + idx = self.names.index(v) + if idx > max_idx: + max_idx = idx + + VGG_PRETRAIN_PATH = {"vgg19": "pre_trinaed/vgg19-dcbb9e9d.pth", + "vgg16": "pre_trinaed/vgg16-397923af.pth", + "vgg13": "pre_trinaed/vgg13-19584684.pth"} + if os.path.exists(VGG_PRETRAIN_PATH[vgg_type]): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + state_dict = torch.load(VGG_PRETRAIN_PATH[vgg_type], map_location=lambda storage, loc: storage) + vgg_net.load_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + + features = vgg_net.features[:max_idx + 1] + + modified_net = OrderedDict() + for k, v in zip(self.names, features): + if 'pool' in k: + # if remove_pooling is true, pooling operation will be removed + if remove_pooling: + continue + else: + # in some cases, we may want to change the default stride + modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) + else: + modified_net[k] = v + + self.vgg_net = nn.Sequential(modified_net) + + if not requires_grad: + self.vgg_net.eval() + for param in self.parameters(): + param.requires_grad = False + + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + + output = {} + for key, layer in self.vgg_net._modules.items(): + x = layer(x) + if key in self.layer_name_list: + output[key] = x.clone() + + return output + + def get_params_num(self): + inp = torch.rand(1, 3, 400, 400) + + pytorch_total_params = sum(p.numel() for p in self.vgg_net.parameters()) + + # count_ops(self.vgg_net, inp) + print(f"pathGAN has param {pytorch_total_params//1000} K params") + + + + +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculating losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. 
+ """ + + def __init__(self, + layer_weights, + vgg_type, + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0., + criterion='l1'): + super(PerceptualLoss, self).__init__() + self.perceptual_weight = perceptual_weight + self.layer_weights = layer_weights + self.vgg = VGGFeatureExtractor( + layer_name_list=list(layer_weights.keys()), + vgg_type=vgg_type, + use_input_norm=use_input_norm, + range_norm=range_norm).cuda() + + self.criterion_type = criterion + self.criterion = torch.nn.L1Loss() + self.vgg_type = vgg_type + + + def forward(self, x, gt): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + # extract vgg features + x_features = self.vgg(x) + gt_features = self.vgg(gt.detach()) + + # calculate perceptual loss + if self.perceptual_weight > 0: + percep_loss = 0 + for k in x_features.keys(): + # save_img(x_features[k], str(k) + "_out") + # save_img(gt_features[k], str(k) + "_gt") + layer_weight = self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss += layer_weight + percep_loss *= self.perceptual_weight + else: + percep_loss = None + + # No style_loss + + return percep_loss + + + +if __name__ == "__main__": + layer_weights = {'conv1_2': 0.1, 'conv2_2': 0.1, 'conv3_4': 1, 'conv4_4': 1, 'conv5_4': 1} + vgg_type = 'vgg19' + loss = PerceptualLoss(layer_weights, vgg_type, perceptual_weight=1.0).cuda() + + import torchvision.transforms as transforms + import cv2 + gen = transforms.ToTensor()(cv2.imread('datasets/train_gen/img_00002.png')).cuda() + gt = transforms.ToTensor()(cv2.imread('datasets/train_hr_anime_usm_720p/img_00002.png')).cuda() + loss(gen, gt) + + # model = loss.vgg + # pytorch_total_params = sum(p.numel() for p in model.parameters()) + # print(f"Perceptual VGG has param {pytorch_total_params//1000000} M params") \ No newline at end of file diff --git a/loss/pixel_loss.py b/loss/pixel_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..85376f9e3d0734398ba4731634b1fddfc152bacb --- /dev/null +++ b/loss/pixel_loss.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- + +import os +import torch +from torch import nn as nn +import torch.nn.functional as F + + +class PixelLoss(nn.Module): + def __init__(self) -> None: + super(PixelLoss, self).__init__() + + self.criterion = torch.nn.L1Loss().cuda() # its default will take the mean of this batch + + def forward(self, gen_hr, org_hr, batch_idx): + + # Calculate general PSNR + pixel_loss = self.criterion(gen_hr, org_hr) + + return pixel_loss + + +class L1_Charbonnier_loss(nn.Module): + """L1 Charbonnierloss.""" + def __init__(self): + super(L1_Charbonnier_loss, self).__init__() + self.eps = 1e-6 # already use square root + + def forward(self, X, Y, batch_idx): + diff = torch.add(X, -Y) + error = torch.sqrt(diff * diff + self.eps) + loss = torch.mean(error) + return loss + + + +""" +Created on Thu Dec 3 00:28:15 2020 +@author: Yunpeng Li, Tianjin University +""" +class MS_SSIM_L1_LOSS(nn.Module): + # Have to use cuda, otherwise the speed is too slow. 
+ def __init__(self, alpha, + gaussian_sigmas=[0.5, 1.0, 2.0, 4.0, 8.0], + data_range = 1.0, + K=(0.01, 0.4), + compensation=1.0, + cuda_dev=0,): + super(MS_SSIM_L1_LOSS, self).__init__() + self.DR = data_range + self.C1 = (K[0] * data_range) ** 2 + self.C2 = (K[1] * data_range) ** 2 + self.pad = int(2 * gaussian_sigmas[-1]) + self.alpha = alpha + self.compensation=compensation + filter_size = int(4 * gaussian_sigmas[-1] + 1) + g_masks = torch.zeros((3*len(gaussian_sigmas), 1, filter_size, filter_size)) + for idx, sigma in enumerate(gaussian_sigmas): + # r0,g0,b0,r1,g1,b1,...,rM,gM,bM + g_masks[3*idx+0, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma) + g_masks[3*idx+1, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma) + g_masks[3*idx+2, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma) + self.g_masks = g_masks.cuda(cuda_dev) + + from torch.utils.tensorboard import SummaryWriter + self.writer = SummaryWriter() + + def _fspecial_gauss_1d(self, size, sigma): + """Create 1-D gauss kernel + Args: + size (int): the size of gauss kernel + sigma (float): sigma of normal distribution + + Returns: + torch.Tensor: 1D kernel (size) + """ + coords = torch.arange(size).to(dtype=torch.float) + coords -= size // 2 + g = torch.exp(-(coords ** 2) / (2 * sigma ** 2)) + g /= g.sum() + return g.reshape(-1) + + def _fspecial_gauss_2d(self, size, sigma): + """Create 2-D gauss kernel + Args: + size (int): the size of gauss kernel + sigma (float): sigma of normal distribution + + Returns: + torch.Tensor: 2D kernel (size x size) + """ + gaussian_vec = self._fspecial_gauss_1d(size, sigma) + return torch.outer(gaussian_vec, gaussian_vec) + + def forward(self, x, y, batch_idx): + ''' + Args: + x (tensor): the input for a tensor + y (tensor): the input for another tensor + batch_idx (int): the iteration now + Returns: + combined_loss (torch): loss value of L1 with MS-SSIM loss + ''' + + # b, c, h, w = x.shape + mux = F.conv2d(x, self.g_masks, groups=3, padding=self.pad) + muy = F.conv2d(y, self.g_masks, groups=3, padding=self.pad) + + mux2 = mux * mux + muy2 = muy * muy + muxy = mux * muy + + sigmax2 = F.conv2d(x * x, self.g_masks, groups=3, padding=self.pad) - mux2 + sigmay2 = F.conv2d(y * y, self.g_masks, groups=3, padding=self.pad) - muy2 + sigmaxy = F.conv2d(x * y, self.g_masks, groups=3, padding=self.pad) - muxy + + # l(j), cs(j) in MS-SSIM + l = (2 * muxy + self.C1) / (mux2 + muy2 + self.C1) # [B, 15, H, W] + cs = (2 * sigmaxy + self.C2) / (sigmax2 + sigmay2 + self.C2) + + lM = l[:, -1, :, :] * l[:, -2, :, :] * l[:, -3, :, :] + PIcs = cs.prod(dim=1) + + loss_ms_ssim = 1 - lM*PIcs # [B, H, W] + + loss_l1 = F.l1_loss(x, y, reduction='none') # [B, 3, H, W] + # average l1 loss in 3 channels + gaussian_l1 = F.conv2d(loss_l1, self.g_masks.narrow(dim=0, start=-3, length=3), + groups=3, padding=self.pad).mean(1) # [B, H, W] + + loss_mix = self.alpha * loss_ms_ssim + (1 - self.alpha) * gaussian_l1 / self.DR + loss_mix = self.compensation*loss_mix # Currently, we set compensation to 1.0 + + combined_loss = loss_mix.mean() + + self.writer.add_scalar('Loss/ms_ssim_loss-iteration', loss_ms_ssim.mean(), batch_idx) + self.writer.add_scalar('Loss/l1_loss-iteration', gaussian_l1.mean(), batch_idx) + + return combined_loss diff --git a/opt.py b/opt.py new file mode 100644 index 0000000000000000000000000000000000000000..e13404394a46d8f8448aa608768bf70a3a1ecae9 --- /dev/null +++ b/opt.py @@ -0,0 +1,251 @@ +# -*- coding: utf-8 -*- +import os + + +opt = {} +##################################################### 
Frequently Changed Setting ########################################################### +opt['description'] = "4x_GRL_paper" # Description to add to the log + +opt['architecture'] = "GRL" # "ESRNET" || "ESRGAN" || "GRL" || "GRLGAN" (GRL only support 4x) + + +# Essential Setting +opt['scale'] = 4 # In default, this is 4x +opt["full_patch_source"] = "../datasets_anime/APISR_dataset" # The HR image without cropping +opt["degrade_hr_dataset_path"] = "datasets/train_hr" # The cropped GT images +opt["train_hr_dataset_path"] = "datasets/train_hr_enhanced" # The cropped Pseudo-GT path (after hand-drawn line enhancement) +################################################################################################################################ + +# GPU setting +opt['CUDA_VISIBLE_DEVICES'] = '0' # '0' / '1' based on different GPU you have. +os.environ['CUDA_VISIBLE_DEVICES'] = opt['CUDA_VISIBLE_DEVICES'] + + +##################################################### Setting for General Training ############################################# + +# Dataset Setting +opt["lr_dataset_path"] = "datasets/train_lr" # Where you temporally store the LR synthetic images +opt['hr_size'] = 256 + + +# Loss function +opt['pixel_loss'] = "L1" # Usually it is "L1" + + +# Adam optimizer setting +opt["adam_beta1"] = 0.9 +opt["adam_beta2"] = 0.99 +opt['decay_gamma'] = 0.5 # Decay the learning rate per decay_iteration + + +# Miscellaneous Setting +opt['degradate_generation_freq'] = 1 # How frequent we degradate HR to LR (1: means Real-Time Degrade) [No need to change this] +opt['train_dataloader_workers'] = 5 # Number of workers for DataLoader +opt['checkpoints_freq'] = 50 # frequency to store checkpoints in the folder (unit: epoch) + + +################################################################################################################################# + + +# Add setting for different architecture (Please go through the model architecture you want!) 
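+# Each branch below sets the training schedule (iterations, batch size, learning-rate
+# decay) for that architecture and, for the GAN variants, the perceptual and GAN loss
+# weights as well.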
+if opt['architecture'] == "ESRNET": + + # Setting for ESRNET Training + opt['ESR_blocks_num'] = 6 # How many RRDB blocks you need + opt['train_iterations'] = 500000 # Training Iterations (500K for large resolution large dataset overlap training) + opt['train_batch_size'] = 32 # + + # Learning Rate + opt["start_learning_rate"] = 0.0002 # Training Epoch, use the as Real-ESRGAN: 0.0001 - 0.0002 is ok, based on your need + opt['decay_iteration'] = 100000 # Decay iteration + opt['double_milestones'] = [] # Iteration based time you double your learning rate + + +elif opt['architecture'] == "ESRGAN": + + # Setting for ESRGAN Training + opt['ESR_blocks_num'] = 6 # How many RRDB blocks you need + opt['train_iterations'] = 200000 # Training Iterations + opt['train_batch_size'] = 32 # + + # Learning Rate + opt["start_learning_rate"] = 0.0001 # Training Epoch, use the as Real-ESRGAN: 0.0001 - 0.0002 is ok, based on your need + opt['decay_iteration'] = 100000 # Fixed decay gap + opt['double_milestones'] = [] # Just put this empty + + # Perceptual loss + opt["danbooru_perceptual_loss_weight"] = 0.5 # ResNet50 Danbooru Perceptual loss weight scale + opt["vgg_perceptual_loss_weight"] = 0.5 # VGG PhotoRealistic Perceptual loss weight scale + opt['train_perceptual_vgg_type'] = 'vgg19' # VGG16/19 (Just use 19 by default) + opt['train_perceptual_layer_weights'] = {'conv1_2': 0.1, 'conv2_2': 0.1, 'conv3_4': 1, 'conv4_4': 1, 'conv5_4': 1} # Middle-Layer weight for VGG + opt['Danbooru_layer_weights'] = {"0": 0.1, "4_2_conv3": 20, "5_3_conv3": 25, "6_5_conv3": 1, "7_2_conv3": 1} # Middle-Layer weight for ResNet + + # GAN loss + opt["discriminator_type"] = "PatchDiscriminator" # "PatchDiscriminator" || "UNetDiscriminator" + opt["gan_loss_weight"] = 0.2 # + + + +elif opt['architecture'] == "CUNET": + # Setting for CUNET Training + opt['train_iterations'] = 500000 # Training Iterations (700K for large resolution large dataset overlap training) + opt['train_batch_size'] = 16 + + opt["start_learning_rate"] = 0.0002 # Training Epoch, use the as Real-ESRGAN: 0.0001 - 0.0002 is ok, based on your need + opt['decay_iteration'] = 100000 # Decay iteration + opt['double_milestones'] = [] # Iteration based time you double your learning rate + + +elif opt['architecture'] == "CUGAN": + # Setting for ESRGAN Training + opt['ESR_blocks_num'] = 6 # How many RRDB blocks you need + opt['train_iterations'] = 200000 # Training Iterations + opt['train_batch_size'] = 16 + opt["start_learning_rate"] = 0.0001 # Training Epoch, use the as Real-ESRGAN: 0.0001 - 0.0002 is ok, based on your need + + opt["perceptual_loss_weight"] = 1.0 + opt['train_perceptual_vgg_type'] = 'vgg19' + opt['train_perceptual_layer_weights'] = {'conv1_2': 0.1, 'conv2_2': 0.1, 'conv3_4': 1, 'conv4_4': 1, 'conv5_4': 1} + opt['Danbooru_layer_weights'] = {"0": 0.1, "4_2_conv3": 20, "5_3_conv3": 25, "6_5_conv3": 1, "7_2_conv3": 1} # Middle-Layer weight for ResNet + opt["gan_loss_weight"] = 0.2 # This one is very important, Don't neglect it. Based on the paper, it should be 0.1 scale + + opt['decay_iteration'] = 100000 # Decay iteration + opt['double_milestones'] = [] # Iteration based time you double your learning rate + + +elif opt['architecture'] == "GRL": # L1 loss training version + # Setting for GRL Training + opt['model_size'] = "tiny2" # "tiny2" in default + + opt['train_iterations'] = 300000 # Training Iterations + opt['train_batch_size'] = 32 # 4x: 32 (256x256); 2x: 4? 
+ + # Learning Rate + opt["start_learning_rate"] = 0.0002 # Training Epoch, use the as Real-ESRGAN: 0.0001 - 0.0002 is ok, based on your need + opt['decay_iteration'] = 100000 # Decay iteration + opt['double_milestones'] = [] # Iteration based time you double your learning rate (Just ignore this one) + + +elif opt['architecture'] == "GRLGAN": # L1 + Preceptual + Discriminator Loss version + # Setting for GRL Training + opt['model_size'] = "tiny2" # "small" || "tiny" || "tiny2" (Use tiny2 by default, No need to change) + + # Setting for GRL-GAN Traning + opt['train_iterations'] = 300000 # Training Iterations + opt['train_batch_size'] = 32 # 4x: 32 batch size (for 256x256); 2x: 4 + + # Learning Rate + opt["start_learning_rate"] = 0.0001 # Training Epoch, use the as Real-ESRGAN: 0.0001 - 0.0002 is ok, based on your need + opt['decay_iteration'] = 100000 # Fixed decay gap + opt['double_milestones'] = [] # Just put this empty + + # Perceptual loss + opt["danbooru_perceptual_loss_weight"] = 0.5 # ResNet50 Danbooru Perceptual loss weight scale + opt["vgg_perceptual_loss_weight"] = 0.5 # VGG PhotoRealistic Perceptual loss weight scale + opt['train_perceptual_vgg_type'] = 'vgg19' # VGG16/19 (Just use 19 by default) + opt['train_perceptual_layer_weights'] = {'conv1_2': 0.1, 'conv2_2': 0.1, 'conv3_4': 1, 'conv4_4': 1, 'conv5_4': 1} # Middle-Layer weight for VGG + opt['Danbooru_layer_weights'] = {"0": 0.1, "4_2_conv3": 20, "5_3_conv3": 25, "6_5_conv3": 1, "7_2_conv3": 1} # Middle-Layer weight for ResNet + + # GAN loss + opt["discriminator_type"] = "PatchDiscriminator" # "PatchDiscriminator" || "UNetDiscriminator" + opt["gan_loss_weight"] = 0.2 # + +else: + raise NotImplementedError("Please check you architecture option setting!") + + + + +# Basic setting for degradation +opt["degradation_batch_size"] = 128 # Degradation batch size +opt["augment_prob"] = 0.5 # Probability of augmenting (Flip, Rotate) the HR and LR dataset in dataset loading part + + +if opt['architecture'] in ["ESRNET", "ESRGAN", "GRL", "GRLGAN", "CUNET", "CUGAN"]: + # Parallel Process + opt['parallel_num'] = 8 # Multi-Processing num; Recommend 6 + + # Blur kernel1 + opt['kernel_range'] = [3, 11] + opt['kernel_list'] = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + opt['kernel_prob'] = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + opt['sinc_prob'] = 0.1 + opt['blur_sigma'] = [0.2, 3] + opt['betag_range'] = [0.5, 4] + opt['betap_range'] = [1, 2] + + # Blur kernel2 + opt['kernel_list2'] = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + opt['kernel_prob2'] = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + opt['sinc_prob2'] = 0.1 + opt['blur_sigma2'] = [0.2, 1.5] + opt['betag_range2'] = [0.5, 4] + opt['betap_range2'] = [1, 2] + + # The first degradation process + opt['resize_prob'] = [0.2, 0.7, 0.1] + opt['resize_range'] = [0.1, 1.2] # Was [0.15, 1.5] in Real-ESRGAN + opt['gaussian_noise_prob'] = 0.5 + opt['noise_range'] = [1, 30] + opt['poisson_scale_range'] = [0.05, 3] + opt['gray_noise_prob'] = 0.4 + opt['jpeg_range'] = [30, 95] + + # The second degradation process + opt['second_blur_prob'] = 0.8 + opt['resize_prob2'] = [0.2, 0.7, 0.1] # [up, down, keep] Resize Probability + opt['resize_range2'] = [0.15, 1.2] + opt['gaussian_noise_prob2'] = 0.5 + opt['noise_range2'] = [1, 25] + opt['poisson_scale_range2'] = [0.05, 2.5] + opt['gray_noise_prob2'] = 0.4 + + # Other common settings + opt['resize_options'] = ['area', 'bilinear', 'bicubic'] # Should be supported by 
F.interpolate + + + # First image compression + opt['compression_codec1'] = ["jpeg", "webp", "heif", "avif"] # Compression codec: heif is the intra frame version of HEVC (H.265) and avif is the intra frame version of AV1 + opt['compression_codec_prob1'] = [0.4, 0.6, 0.0, 0.0] + + # Specific Setting + opt["jpeg_quality_range1"] = [20, 95] + opt["webp_quality_range1"] = [20, 95] + opt["webp_encode_speed1"] = [0, 6] + opt["heif_quality_range1"] = [30, 100] + opt["heif_encode_speed1"] = [0, 6] # Useless now + opt["avif_quality_range1"] = [30, 100] + opt["avif_encode_speed1"] = [0, 6] # Useless now + + + ######################################## Setting for Degradation with Intra-Prediction ######################################## + opt['compression_codec2'] = ["jpeg", "webp", "avif", "mpeg2", "mpeg4", "h264", "h265"] # Compression codec: similar to VCISR but more intense degradation settings + opt['compression_codec_prob2'] = [0.06, 0.1, 0.1, 0.12, 0.12, 0.3, 0.2] + + # Image compression setting + opt["jpeg_quality_range2"] = [20, 95] + + opt["webp_quality_range2"] = [20, 95] + opt["webp_encode_speed2"] = [0, 6] + + opt["avif_quality_range2"] = [20, 95] + opt["avif_encode_speed2"] = [0, 6] # Useless now + + # Video compression I-Frame setting + opt['h264_crf_range2'] = [23, 38] + opt['h264_preset_mode2'] = ["slow", "medium", "fast", "faster", "superfast"] + opt['h264_preset_prob2'] = [0.05, 0.35, 0.3, 0.2, 0.1] + + opt['h265_crf_range2'] = [28, 42] + opt['h265_preset_mode2'] = ["slow", "medium", "fast", "faster", "superfast"] + opt['h265_preset_prob2'] = [0.05, 0.35, 0.3, 0.2, 0.1] + + opt['mpeg2_quality2'] = [8, 31] # linear scale 2-31 (the lower the higher quality) + opt['mpeg2_preset_mode2'] = ["slow", "medium", "fast", "faster", "superfast"] + opt['mpeg2_preset_prob2'] = [0.05, 0.35, 0.3, 0.2, 0.1] + + opt['mpeg4_quality2'] = [8, 31] # should be the same as mpeg2_quality2 + opt['mpeg4_preset_mode2'] = ["slow", "medium", "fast", "faster", "superfast"] + opt['mpeg4_preset_prob2'] = [0.05, 0.35, 0.3, 0.2, 0.1] + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d03c4887fa6374a90a1ea9f777c7b3251b285e91 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,21 @@ +# Relatively static library (These libraries are comparatively stable, so the version here can be versatile, we attached the version we used in the experiments) +tqdm==4.66.1 +numpy==1.26.0 +torchsummary==1.5.1 +opencv-python==4.8.1.78 +scipy==1.11.3 +omegaconf==2.3.0 +fairscale==0.4.13 +timm==0.9.7 +pandas==2.1.1 +pillow==10.0.1 +requests==2.31.0 +pyyaml==6.0.1 +kornia==0.7.0 +gradio + + +# Relatively dynamic library (We think that these libraries may frequently modify their API, so it is better to use the same version as below) +pyiqa==0.1.7 +pthflops==0.4.2 +pillow-heif==0.13.0 diff --git a/scripts/anime_strong_usm.py b/scripts/anime_strong_usm.py new file mode 100644 index 0000000000000000000000000000000000000000..26b721a7827d64931dfd4d1e919f9480483bd238 --- /dev/null +++ b/scripts/anime_strong_usm.py @@ -0,0 +1,389 @@ +import cv2 +import argparse +import numpy as np +import copy +import os, sys, copy, shutil +from kornia import morphology as morph +import math +import gc, time +import torch +import torch.multiprocessing as mp +from torch.nn import functional as F +from multiprocessing import set_start_method + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from opt import opt +from degradation.ESR.utils import 
filter2D, np2tensor, tensor2np + + + +# This config is found by the author +# modify if not the desired output +XDoG_config = dict( + size=0, + sigma=0.6, + eps=-15, + phi=10e8, + k=2.5, + gamma=0.97 +) + +# I wanted the gamma between [0.97, 0.98], but it depends on the image so I made it move randomly comment out if this is not needed +# In our case, black means background information; white means hand-drawn line +XDoG_config['gamma'] += 0.01 * np.random.rand(1) +dilation_kernel = torch.tensor([[1, 1, 1],[1, 1, 1],[1, 1, 1]]).cuda() +white_color_value = 1 # In binary map, 0 stands for black and 1 stands for white + + + +def DoG(image, size, sigma, k=1.6, gamma=1.): + g1 = cv2.GaussianBlur(image, (size, size), sigma) + g2 = cv2.GaussianBlur(image, (size, size), sigma*k) + return g1 - gamma * g2 + + +def XDoG(image, size, sigma, eps, phi, k=1.6, gamma=1.): + eps /= 255 + d = DoG(image, size, sigma, k, gamma) + d /= d.max() + e = 1 + np.tanh(phi * (d - eps)) + e[e >= 1] = 1 + return e + + + +class USMSharp(torch.nn.Module): + ''' + Basically, the same as Real-ESRGAN + ''' + + def __init__(self, type, radius=50, sigma=0): + # 感觉radius有点大 + super(USMSharp, self).__init__() + if radius % 2 == 0: + radius += 1 + self.radius = radius + kernel = cv2.getGaussianKernel(radius, sigma) + kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0).cuda() + self.register_buffer('kernel', kernel) + + self.type = type + + + def forward(self, img, weight=0.5, threshold=10, store=False): + # weight=0.5, threshold=10 + + if self.type == "cv2": + # pre-process cv2 type + img = np2tensor(img) + + blur = filter2D(img, self.kernel.cuda()) + if store: + cv2.imwrite("blur.png", tensor2np(blur)) + + residual = img - blur + if store: + cv2.imwrite("residual.png", tensor2np(residual)) + + mask = torch.abs(residual) * 255 > threshold + if store: + cv2.imwrite("mask.png", tensor2np(mask)) + + + mask = mask.float() + soft_mask = filter2D(mask, self.kernel.cuda()) + if store: + cv2.imwrite("soft_mask.png", tensor2np(soft_mask)) + + sharp = img + weight * residual + sharp = torch.clip(sharp, 0, 1) + if store: + cv2.imwrite("sharp.png", tensor2np(sharp)) + + output = soft_mask * sharp + (1 - soft_mask) * img + if self.type == "cv2": + output = tensor2np(output) + + return output + + + +def get_xdog_sketch_map(img_bgr, outlier_threshold): + + gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) + sketch_map = gen_xdog_image(gray, outlier_threshold) + sketch_map = np.stack((sketch_map, sketch_map, sketch_map), axis=2) # concatenate to 3 dim + + return np.uint8(sketch_map) + + +def process_single_img(queue, usm_sharper, extra_sharpen_time, outlier_threshold): + + counter = 0 + while True: + counter += 1 + if counter == 10: + counter = 0 + gc.collect() + print("We will sleep here to clear memory") + time.sleep(5) + info = queue[0] + queue = queue[1:] + if info == None: + break + + img_dir, store_path = info + print("We are processing ", img_dir) + img = cv2.imread(img_dir) + + img = usm_sharper(img, store=False, threshold=10) + first_sharpened_img = copy.deepcopy(img) + + for _ in range(extra_sharpen_time): + # sketch_map = get_xdog_sketch_map(img_temp) + img = usm_sharper(img, store=False, threshold=10) + # img = (sharpened_img * sketch_map) + (org_img * (1-sketch_map)) + + sketch_map = get_xdog_sketch_map(img, outlier_threshold) + img = (img * sketch_map) + (first_sharpened_img * (1-sketch_map)) + + + cv2.imwrite(store_path, img) + + print("Finish all program") + + + +def outlier_removal(img, outlier_threshold): 
+ ''' Remove outlier pixel after finding the sketch + Here, black(0) means background information; white(1) means hand-drawn line + ''' + + global_list = set() + h,w = img.shape + + def dfs(i, j): + ''' + Using Depth First Search to find the full area of mapping + ''' + if (i,j) in visited: + # If this is an already visited pixel, return + return + + if (i,j) in global_list: + # If it is already existed in the global list, return + return + + if i >= h or j >= w or i < 0 or j < 0: + # If it is out of boundary, return + return + + if img[i][j] == white_color_value: + visited.add((i,j)) + + # If it is over threshold, we won't remove them + if len(visited) >= 100: + return + + dfs(i+1, j) + dfs(i, j+1) + dfs(i-1, j) + dfs(i, j-1) + dfs(i-1, j-1) + dfs(i+1, j+1) + dfs(i-1, j+1) + dfs(i+1, j-1) + + return + + def bfs(i, j): + ''' + Using Breadth First Search to find the full area of mapping + ''' + if (i,j) in visited: + # If this is an already visited pixel, return + return + + if (i,j) in global_list: + # If it is already existed in the global list, return + return + + visited.add((i,j)) + if img[i][j] != white_color_value: + return + + queue = [(i, j)] + while queue: + base_row, base_col = queue.pop(0) + + for dx, dy in [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]: + row, col = base_row+dx, base_col+dy + + if (row, col) in visited: + # If this is an already visited pixel, continue + continue + + if (row, col) in global_list: + # If it is already existed in the global list, continue + continue + + if row >= h or col >= w or row < 0 or col < 0: + # If it is out of boundary, continue + continue + + if img[row][col] == white_color_value: + visited.add((row, col)) + queue.append((row, col)) + + + temp = np.copy(img) + for i in range(h): + for j in range(w): + if (i,j) in global_list: + continue + if temp[i][j] != white_color_value: + # We only consider white color (hand-drawn line) situation + continue + + global visited + visited = set() + + # USE depth/breadth first search to find neighbor white value + bfs(i, j) + + if len(visited) < outlier_threshold: + # If the number of white pixels counting all neighbors are less than the outlier_threshold, paint the whole region to black (0:background symbol) + for u, v in visited: + temp[u][v] = 0 + + # Add those searched line to global_list to speed up + for u, v in visited: + global_list.add((u, v)) + + return temp + + +def active_dilate(img): + def np2tensor(np_frame): + return torch.from_numpy(np.transpose(np_frame, (2, 0, 1))).unsqueeze(0).cuda().float()/255 + def tensor2np(tensor): + # tensor should be batch size1 and cannot be grayscale input + return (np.transpose(tensor.detach().cpu().numpy(), (1, 2, 0))) * 255 + + dilated_edge_map = morph.dilation(np2tensor(np.expand_dims(img, 2)), dilation_kernel) + + return tensor2np(dilated_edge_map[0]).squeeze(2) + + +def passive_dilate(img): + # IF there is 3 white pixel in 9 block, we will fill in + h,w = img.shape + + def detect_fill(i, j): + if img[i][j] == white_color_value: + return False + + def sanity_check(i, j): + if i >= h or j >= w or i < 0 or j < 0: + return False + + if img[i][j] == white_color_value: + return True + return False + + + num_white = sanity_check(i-1,j-1) + sanity_check(i-1,j) + sanity_check(i-1,j+1) + sanity_check(i,j-1) + sanity_check(i,j+1) + sanity_check(i+1,j-1) + sanity_check(i+1,j) + sanity_check(i+1,j+1) + if num_white >= 3: + return True + + + temp = np.copy(img) + for i in range(h): + for j in range(w): + global visited + visited = set() + + 
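+            # Fill this background pixel with white when at least 3 of its
+            # 8 neighbours are white (hand-drawn line) pixels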
should_fill = detect_fill(i, j) + if should_fill: + temp[i][j] = 1 + + # return True to say that we need to remove it; else, we don't need to remove it + return temp + + +def gen_xdog_image(gray, outlier_threshold): + ''' + Returns: + dogged (numpy): binary map in range (1 stands for white pixel) + ''' + + dogged = XDoG(gray, **XDoG_config) + dogged = 1 - dogged # black white transform + + + # Remove unnecessary outlier + dogged = outlier_removal(dogged, outlier_threshold) + + # Dilate the image + dogged = passive_dilate(dogged) + + + return dogged + + + +if __name__ == "__main__": + + + # Parse variables available + parser = argparse.ArgumentParser() + parser.add_argument('-i', '--input_dir', type = str) + parser.add_argument('-o', '--store_dir', type = str) + parser.add_argument('--outlier_threshold', type = int, default=32) + args = parser.parse_args() + + input_dir = args.input_dir + store_dir = args.store_dir + outlier_threshold = args.outlier_threshold + + + print("We are handling Strong USM sharpening on hand-drawn line for Anime images!") + + + num_workers = 8 + extra_sharpen_time = 2 + + + if os.path.exists(store_dir): + shutil.rmtree(store_dir) + os.makedirs(store_dir) + + + dir_list = [] + for img_name in sorted(os.listdir(input_dir)): + input_path = os.path.join(input_dir, img_name) + output_path = os.path.join(store_dir, img_name) + dir_list.append((input_path, output_path)) + + length = len(dir_list) + + + # USM sharpener preparation + usm_sharper = USMSharp(type="cv2").cuda() + usm_sharper.share_memory() + + for idx in range(num_workers): + set_start_method('spawn', force=True) + + num = math.ceil(length / num_workers) + request_list = dir_list[:num] + request_list.append(None) + dir_list = dir_list[num:] + + # process_single_img(request_list, usm_sharper, extra_sharpen_time) # This is for debug purpose + p = mp.Process(target=process_single_img, args=(request_list, usm_sharper, extra_sharpen_time, outlier_threshold)) + p.start() + + print("Submitted all jobs!") \ No newline at end of file diff --git a/scripts/crop_images.py b/scripts/crop_images.py new file mode 100644 index 0000000000000000000000000000000000000000..0336337fb78906768bada6abcfe6f472ebad76d0 --- /dev/null +++ b/scripts/crop_images.py @@ -0,0 +1,174 @@ +# -*- coding: utf-8 -*- +import argparse +import cv2 +import numpy as np +import os, shutil +import sys +from multiprocessing import Pool +from os import path as osp +from tqdm import tqdm +import random +from collections import namedtuple + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from degradation.ESR.usm_sharp import USMSharp + + +class worker: + def __init__(self, start_index=1): + # The index you want to start with + self.output_index = start_index + + def process(self, path, opt, usm_sharper): + ''' crop the image here (also do usm here) + Args: + path (str): path of the image + opt (dict): all setting in a dictionary + usm_sharper (class): usm sharpener + + Returns: + cropped_num (int): how many cropped images you have for this path + ''' + + crop_size = opt['crop_size'] # usually 400 + + # read image + img = cv2.imread(path) + height, width = img.shape[0:2] + + res_store = [] + crop_num = (height//crop_size)*(width//crop_size) + random_num = opt['crop_num_per_img'] + + # Use shift offset to make image more cover origional image size + shift_offset_h, shift_offset_w = 0, 0 + + if random_num == -1: + # We should select all sub-frames order by order (not randomly select here) + choices = [i for i 
in range(crop_num)] + shift_offset_h = 0 #random.randint(0, height - crop_size * (height//crop_size)) + shift_offset_w = 0 #random.randint(0, width - crop_size * (width//crop_size)) + else: + # Divide imgs by crop_size x crop_size and choose opt['crop_num_per_img'] num of them to avoid overlap + num = min(random_num, crop_num) + choices = random.sample(range(crop_num), num) + + for choice in choices: + row_num = (width//crop_size) + x, y = crop_size * (choice // row_num), crop_size * (choice % row_num) + # add offset + res_store.append((x, y)) + + + # Sharp the image before selection + if opt['usm_save_folder'] != None: + sharpened_img = usm_sharper(img) + + + for (h, w) in res_store: + cropped_img = img[h+shift_offset_h : h+crop_size+shift_offset_h, w+shift_offset_w : w+crop_size+shift_offset_w, ...] + cropped_img = np.ascontiguousarray(cropped_img) + cv2.imwrite(osp.join(opt['save_folder'], f'img_{self.output_index:06d}.png'), cropped_img, [cv2.IMWRITE_PNG_COMPRESSION, 0]) # Save in lossless mode + + # store the sharpened cropped image + if opt['usm_save_folder'] != None: + cropped_sharpened_img = sharpened_img[h+shift_offset_h : h+crop_size+shift_offset_h, w+shift_offset_w : w+crop_size+shift_offset_w, ...] + cropped_sharpened_img = np.ascontiguousarray(cropped_sharpened_img) + cv2.imwrite(osp.join(opt['usm_save_folder'], f'img_{self.output_index:06d}.png'), cropped_sharpened_img, [cv2.IMWRITE_PNG_COMPRESSION, 0]) + + self.output_index += 1 + + + cropped_num = len(res_store) + return cropped_num + + +def extract_subimages(opt): + + # Input + input_folders = opt['input_folders'] + + # Make folders + save_folder = opt['save_folder'] + usm_save_folder = opt['usm_save_folder'] + + if osp.exists(save_folder): + print(f'Folder {save_folder} already exists. Program will delete this folder!') + shutil.rmtree(save_folder) + + os.makedirs(save_folder) + if usm_save_folder != None: + if osp.exists(usm_save_folder): + print(f'Folder {usm_save_folder} already exists. 
Program will delete this folder!') + shutil.rmtree(usm_save_folder) + + print("Use usm sharp") + os.makedirs(usm_save_folder) + + # USM + usm_sharper = USMSharp(type="cv2") + + # Iterate all datasets' folders + start_index = 1 + for input_folder in input_folders: + print(input_folder, start_index) + + # Scan all images + img_list = [] + for file in sorted(os.listdir(input_folder)): + if file.split(".")[-1] in ["png", "jpg"]: + img_list.append(osp.join(input_folder, file)) + + # Iterate can crop + obj = worker(start_index=start_index) # The start_index determines where you will start your naming your image (usually start from 0) + for path in img_list: + if random.random() < opt['select_rate']: + cropped_num = obj.process(path, opt, usm_sharper) + start_index += cropped_num + print(start_index, path) + else: + print("SKIP") + + + print('All processes done.') + + +def main(args): + opt = {} + + input_folders = [] + if type(args.input_folder) == str: + input_folders.append(args.input_folder) + else: + for input_folder in args.input_folder: + input_folders.append(input_folder) + print("input folders have ", input_folders) + + + opt['input_folders'] = input_folders + opt['save_folder'] = args.save_folder + opt['usm_save_folder'] = args.output_usm + opt['crop_size'] = args.crop_size + opt['crop_num_per_img'] = args.crop_num_per_img + opt['select_rate'] = args.select_rate + + # Extract subimages + extract_subimages(opt) + + +if __name__ == '__main__': + random.seed(777) # We setup a random seed such that all program get the same cropped images + + parser = argparse.ArgumentParser() + # Try to split image after default + parser.add_argument('-i', '--input_folder', nargs='+', type=str, default='datasets/all_Anime_hq_frames_resize', help='Input folder') # TODO: support multiple image input + parser.add_argument('-o', '--save_folder', type=str, default='datasets/train_hr', help='Output folder') + parser.add_argument('--output_usm', type=str, help='usm sharpened hr folder') + parser.add_argument('--crop_size', type=int, default=360, help='Crop size') + parser.add_argument('--select_rate', type=float, default=1, help='(0-1): Proportion to keep; 1 means to keep them all') + parser.add_argument('--crop_num_per_img', type=int, default=-1, help='Crop size (int); -1 means use all possible sub-frames') + args = parser.parse_args() + + main(args) \ No newline at end of file diff --git a/scripts/generate_lr_esr.py b/scripts/generate_lr_esr.py new file mode 100644 index 0000000000000000000000000000000000000000..981d12dd5d958d4864ba54af1c2e5e0874c4190a --- /dev/null +++ b/scripts/generate_lr_esr.py @@ -0,0 +1,189 @@ +# -*- coding: utf-8 -*- +import argparse +import cv2 +import torch +import os, shutil, time +import sys +from multiprocessing import Process, Queue +from os import path as osp +from tqdm import tqdm +import copy +import warnings +import gc + +warnings.filterwarnings("ignore") + +# import same folder files # +root_path = os.path.abspath('.') +sys.path.append(root_path) +from degradation.ESR.utils import np2tensor +from degradation.ESR.degradations_functionality import * +from degradation.ESR.diffjpeg import * +from degradation.degradation_esr import degradation_v1 +from opt import opt +os.environ['CUDA_VISIBLE_DEVICES'] = opt['CUDA_VISIBLE_DEVICES'] #'0,1' + + + +def crop_process(path, crop_size, lr_dataset_path, output_index): + ''' crop the image here (also do usm here) + Args: + path (str): Path of the image + crop_size (int): Crop size + lr_dataset_path (str): LR dataset path folder name + 
output_index (int): The index we used to store images + Returns: + output_index (int): The next index we need to use to store images + ''' + + # read image + img = cv2.imread(path) + height, width = img.shape[0:2] + + res_store = [] + crop_num = (height//crop_size)*(width//crop_size) + + # Use shift offset to make image more cover origional image size + shift_offset_h, shift_offset_w = 0, 0 + + + # Select all sub-frames order by order (not randomly select here) + choices = [i for i in range(crop_num)] + shift_offset_h = 0 #random.randint(0, height - crop_size * (height//crop_size)) + shift_offset_w = 0 #random.randint(0, width - crop_size * (width//crop_size)) + + + for choice in choices: + row_num = (width//crop_size) + x, y = crop_size * (choice // row_num), crop_size * (choice % row_num) + # add offset + res_store.append((x, y)) + + + + for (h, w) in res_store: + cropped_img = img[h+shift_offset_h : h+crop_size+shift_offset_h, w+shift_offset_w : w+crop_size+shift_offset_w, ...] + cropped_img = np.ascontiguousarray(cropped_img) + cv2.imwrite(osp.join(lr_dataset_path, f'img_{output_index:06d}.png'), cropped_img, [cv2.IMWRITE_PNG_COMPRESSION, 0]) # Save in lossless mode + + output_index += 1 + + return output_index + + + +def single_process(queue, opt, process_id): + ''' Multi Process instance + Args: + queue (multiprocessing.Queue): The input queue + opt (dict): The setting we need to use + process_id (int): The id we used to store temporary file + ''' + + # Initialization + obj_img = degradation_v1() + + while True: + items = queue.get() + if items == None: + break + input_path, store_path = items + + # Reset kernels in every degradation batch for ESR + obj_img.reset_kernels(opt) + + # Read all images and transform them to tensor + img_bgr = cv2.imread(input_path) + + out = np2tensor(img_bgr) # tensor + + # ESR Degradation execution + obj_img.degradate_process(out, opt, store_path, process_id, verbose = False) + + + +@torch.no_grad() +def generate_low_res_esr(org_opt, verbose=False): + ''' Generate LR dataset from HR ones by ESR degradation + Args: + org_opt (dict): The setting we will use + verbose (bool): Whether we print out some information + ''' + + # Prepare folders and files + input_folder = org_opt['input_folder'] + save_folder = org_opt['save_folder'] + if osp.exists(save_folder): + shutil.rmtree(save_folder) + if osp.exists("tmp"): + shutil.rmtree("tmp") + os.makedirs(save_folder) + os.makedirs("tmp") + if os.path.exists("datasets/degradation_log.txt"): + os.remove("datasets/degradation_log.txt") + + + # Scan all images + input_img_lists, output_img_lists = [], [] + for file in sorted(os.listdir(input_folder)): + input_img_lists.append(osp.join(input_folder, file)) + output_img_lists.append(osp.join("tmp", file)) + assert(len(input_img_lists) == len(output_img_lists)) + + + # Multi-Process Preparation + parallel_num = opt['parallel_num'] + queue = Queue() + + + # Save all files in the Queue + for idx in range(len(input_img_lists)): + # Find the needed img lists + queue.put((input_img_lists[idx], output_img_lists[idx])) + + + # Start the process + Processes = [] + for process_id in range(parallel_num): + p1 = Process(target=single_process, args =(queue, opt, process_id, )) + p1.start() + Processes.append(p1) + for _ in range(parallel_num): + queue.put(None) # Used to end the process + # print("All Process starts") + + # tqdm wait progress + for idx in tqdm(range(0, len(output_img_lists)), desc ="Degradation"): + while True: + if os.path.exists(output_img_lists[idx]): + break + 
time.sleep(0.1) + + # Merge all processes + for process in Processes: + process.join() + + + + # Crop images under folder "tmp" + output_index = 1 + for img_name in sorted(os.listdir("tmp")): + path = os.path.join("tmp", img_name) + output_index = crop_process(path, opt['hr_size']//opt['scale'], opt['save_folder'], output_index) + + + +def main(args): + opt['input_folder'] = args.input + opt['save_folder'] = args.output + + generate_low_res_esr(opt) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--input', type=str, default = opt["full_patch_source"], help='Input folder') + parser.add_argument('--output', type=str, default = opt["lr_dataset_path"], help='Output folder') + args = parser.parse_args() + + main(args) \ No newline at end of file diff --git a/scripts/prepare_datasets.sh b/scripts/prepare_datasets.sh new file mode 100644 index 0000000000000000000000000000000000000000..7f3fb75cd1533be9022ddb03f313ed6506af7768 --- /dev/null +++ b/scripts/prepare_datasets.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# Set up the path (The following three paths will be kept and used in training) +full_patch_source=../APISR_dataset +degrade_hr_dataset_path=datasets/train_hr +train_hr_dataset_path=datasets/train_hr_enhanced + + +# tmp path (No need to change, we will remove them at the end of process) +tmp_dir_720p=APISR_720p_tmp +tmp_dir_720p_4xcrop=APISR_720p_crop_tmp +tmp_enhanced_dir=APISR_sharpen_tmp + + +# Resize images and prepare usm sharpening in Anime +python tools/720P_resize.py -i $full_patch_source -o $tmp_dir_720p +python tools/4x_crop.py -i $tmp_dir_720p -o $tmp_dir_720p_4xcrop +python scripts/anime_strong_usm.py -i $tmp_dir_720p_4xcrop -o $tmp_enhanced_dir --outlier_threshold 32 + + +# Crop images to the target HR and degradate_HR dataset +python scripts/crop_images.py -i $tmp_dir_720p_4xcrop --crop_size 256 -o $degrade_hr_dataset_path +python scripts/crop_images.py -i $tmp_enhanced_dir --crop_size 256 -o $train_hr_dataset_path + + +# Clean unnecessary file +rm -rf $tmp_dir_720p +rm -rf $tmp_dir_720p_4xcrop +rm -rf $tmp_enhanced_dir \ No newline at end of file diff --git a/test_code/inference.py b/test_code/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..0be40b2904f252d921b52bb735d6a81749ab774f --- /dev/null +++ b/test_code/inference.py @@ -0,0 +1,143 @@ +''' + This is file is to execute the inference for a single image or a folder input +''' +import argparse +import os, sys, cv2, shutil, warnings +import torch +from torchvision.transforms import ToTensor +from torchvision.utils import save_image +warnings.simplefilter("default") +os.environ["PYTHONWARNINGS"] = "default" + + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from test_code.test_utils import load_grl, load_rrdb, load_cunet + + + +@torch.no_grad # You must add these time, else it will have Out of Memory +def super_resolve_img(generator, input_path, output_path=None, weight_dtype=torch.float32, crop_for_4x=True): + ''' Super Resolve a low resolution image + Args: + generator (torch): the generator class that is already loaded + input_path (str): the path to the input lr images + output_path (str): the directory to store the generated images + weight_dtype (bool): the weight type (float32/float16) + crop_for_4x (bool): whether we crop the lr images to match 4x scale (needed for some situation) + ''' + print("Processing image {}".format(input_path)) + + # Read the image and do preprocess + img_lr = 
cv2.imread(input_path) + # Crop if needed + if crop_for_4x: + h, w, _ = img_lr.shape + if h % 4 != 0: + img_lr = img_lr[:4*(h//4),:,:] + if w % 4 != 0: + img_lr = img_lr[:,:4*(w//4),:] + + # Transform to tensor + img_lr = cv2.cvtColor(img_lr, cv2.COLOR_BGR2RGB) + img_lr = ToTensor()(img_lr).unsqueeze(0).cuda() # Use tensor format + img_lr = img_lr.to(dtype=weight_dtype) + + + # Model inference + print("lr shape is ", img_lr.shape) + super_resolved_img = generator(img_lr) + + # Store the generated result + with torch.cuda.amp.autocast(): + if output_path is not None: + save_image(super_resolved_img, output_path) + + # Empty the cache every time you finish processing one image + torch.cuda.empty_cache() + + return super_resolved_img + + + + +if __name__ == "__main__": + + # Fundamental setting + parser = argparse.ArgumentParser() + parser.add_argument('--input_dir', type = str, default = '__assets__/lr_inputs', help="Can be either a single image input or a folder input") + parser.add_argument('--model', type = str, default = 'GRL', help=" 'GRL' || 'RRDB' (for ESRNET & ESRGAN) || 'CUNET' (for Real-ESRGAN) ") + parser.add_argument('--scale', type = int, default = 4, help="Upscale factor") + parser.add_argument('--weight_path', type = str, default = 'pretrained/4x_APISR_GRL_GAN_generator.pth', help="Weight path directory, usually under the saved_models folder") + parser.add_argument('--store_dir', type = str, default = 'sample_outputs', help="The folder to store the super-resolved images") + parser.add_argument('--float16_inference', type = bool, default = False, help="Whether to run inference in float16") # Currently, this is only supported in RRDB; there is some bug with the GRL model + args = parser.parse_args() + + # Sample Commands + # 4x GRL (Default): python test_code/inference.py --model GRL --scale 4 --weight_path pretrained/4x_APISR_GRL_GAN_generator.pth + # 2x RRDB: python test_code/inference.py --model RRDB --scale 2 --weight_path pretrained/2x_APISR_RRDB_GAN_generator.pth + + + # Read the arguments and prepare the folders needed + input_dir = args.input_dir + model = args.model + weight_path = args.weight_path + store_dir = args.store_dir + scale = args.scale + float16_inference = args.float16_inference + + + # Check the path of the weight + if not os.path.exists(weight_path): + print("We cannot locate the weight path ", weight_path) + # TODO: I am not sure if I should automatically download weight from github release based on the upscale factor and model name. 
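+        # A possible (untested) fallback sketch before exiting: fetch the missing weight from a release asset. + # The release URL below is a placeholder assumption, not a confirmed download location. + #   url = "<github-release-url>/" + os.path.basename(weight_path) + #   os.makedirs(os.path.dirname(weight_path), exist_ok=True) + #   torch.hub.download_url_to_file(url, weight_path)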
+ os._exit(0) + + + # Prepare the store folder + if os.path.exists(store_dir): + shutil.rmtree(store_dir) + os.makedirs(store_dir) + + + + # Define the weight type + if float16_inference: + torch.backends.cudnn.benchmark = True + weight_dtype = torch.float16 + else: + weight_dtype = torch.float32 + + + # Load the model + if model == "GRL": + generator = load_grl(weight_path, scale=scale) # GRL for Real-World SR only supports 4x upscaling + elif model == "RRDB": + generator = load_rrdb(weight_path, scale=scale) # Can be any size + generator = generator.to(dtype=weight_dtype) + + + # Take the input path and do inference + if os.path.isdir(input_dir): # If the input is a directory, we will iterate over it + for filename in sorted(os.listdir(input_dir)): + input_path = os.path.join(input_dir, filename) + output_path = os.path.join(store_dir, filename) + # By default, we automatically crop to match the 4x scale + super_resolve_img(generator, input_path, output_path, weight_dtype, crop_for_4x=True) + + else: # If the input is a single image, we will process it directly and write it to the store folder + filename = os.path.split(input_dir)[-1].split('.')[0] + output_path = os.path.join(store_dir, filename+"_"+str(scale)+"x.png") + # By default, we automatically crop to match the 4x scale + super_resolve_img(generator, input_dir, output_path, weight_dtype, crop_for_4x=True) + + + + + + + + + \ No newline at end of file diff --git a/test_code/test_utils.py b/test_code/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b2d06aa601714d302a5ab91dff93a1cd947e633f --- /dev/null +++ b/test_code/test_utils.py @@ -0,0 +1,176 @@ +import os, sys +import torch + +# Import files from the same folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from opt import opt +from architecture.rrdb import RRDBNet +from architecture.grl import GRL +from architecture.swinir import SwinIR +from architecture.cunet import UNet_Full + + +def load_rrdb(generator_weight_PATH, scale, print_options=False): + ''' A simpler API to load the RRDB model from Real-ESRGAN + Args: + generator_weight_PATH (str): The path to the weight + scale (int): the scaling factor + print_options (bool): whether to print the options stored in the checkpoint to show what settings were used + Returns: + generator (torch): the generator instance of the model + ''' + + # Load the checkpoint + checkpoint_g = torch.load(generator_weight_PATH) + + # Find the generator weight + if 'params_ema' in checkpoint_g: + # For official ESRNET/ESRGAN weight + weight = checkpoint_g['params_ema'] + generator = RRDBNet(3, 3, scale=scale) # Default blocks num is 6 + + elif 'params' in checkpoint_g: + # For official ESRNET/ESRGAN weight + weight = checkpoint_g['params'] + generator = RRDBNet(3, 3, scale=scale) + + elif 'model_state_dict' in checkpoint_g: + # For my personally trained weight + weight = checkpoint_g['model_state_dict'] + generator = RRDBNet(3, 3, scale=scale) + + else: + print("This weight is not supported") + os._exit(0) + + + # Handle torch.compile weight key rename + old_keys = [key for key in weight] + for old_key in old_keys: + if old_key[:10] == "_orig_mod.": + new_key = old_key[10:] + weight[new_key] = weight[old_key] + del weight[old_key] + + generator.load_state_dict(weight) + generator = generator.eval().cuda() + + + # Print the options stored in the checkpoint to show what settings were used + if print_options: + if 'opt' in checkpoint_g: + for key in checkpoint_g['opt']: + value = checkpoint_g['opt'][key] + print(f'{key} : {value}') + + return generator + 
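+# Minimal usage sketch (not executed anywhere in this repo; the image paths below are placeholders): + #   generator = load_rrdb("pretrained/2x_APISR_RRDB_GAN_generator.pth", scale=2) + #   sr_tensor = super_resolve_img(generator, "some_lr_image.png", "some_output.png")  # from test_code/inference.py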
+ +def load_cunet(generator_weight_PATH, scale, print_options=False): + ''' A simpler API to load CUNET model from Real-CUGAN + Args: + generator_weight_PATH (str): The path to the weight + scale (int): the scaling factor + print_options (bool): whether to print options to show what kinds of setting is used + Returns: + generator (torch): the generator instance of the model + ''' + # This func is deprecated now + + if scale != 2: + raise NotImplementedError("We only support 2x in CUNET") + + # Load the checkpoint + checkpoint_g = torch.load(generator_weight_PATH) + + # Find the generator weight + if 'model_state_dict' in checkpoint_g: + # For my personal trained weight + weight = checkpoint_g['model_state_dict'] + loss = checkpoint_g["lowest_generator_weight"] + if "iteration" in checkpoint_g: + iteration = checkpoint_g["iteration"] + else: + iteration = "NAN" + generator = UNet_Full() + # generator = torch.compile(generator)# torch.compile + print(f"the generator weight is {loss} at iteration {iteration}") + + else: + print("This weight is not supported") + os._exit(0) + + + # Handle torch.compile weight key rename + old_keys = [key for key in weight] + for old_key in old_keys: + if old_key[:10] == "_orig_mod.": + new_key = old_key[10:] + weight[new_key] = weight[old_key] + del weight[old_key] + + generator.load_state_dict(weight) + generator = generator.eval().cuda() + + + # Print options to show what kinds of setting is used + if print_options: + if 'opt' in checkpoint_g: + for key in checkpoint_g['opt']: + value = checkpoint_g['opt'][key] + print(f'{key} : {value}') + + return generator + +def load_grl(generator_weight_PATH, scale=4): + ''' A simpler API to load GRL model + Args: + generator_weight_PATH (str): The path to the weight + scale (int): Scale Factor (Usually Set as 4) + Returns: + generator (torch): the generator instance of the model + ''' + + # Load the checkpoint + checkpoint_g = torch.load(generator_weight_PATH) + + # Find the generator weight + if 'model_state_dict' in checkpoint_g: + weight = checkpoint_g['model_state_dict'] + + # GRL tiny model (Note: tiny2 version) + generator = GRL( + upscale = scale, + img_size = 64, + window_size = 8, + depths = [4, 4, 4, 4], + embed_dim = 64, + num_heads_window = [2, 2, 2, 2], + num_heads_stripe = [2, 2, 2, 2], + mlp_ratio = 2, + qkv_proj_type = "linear", + anchor_proj_type = "avgpool", + anchor_window_down_factor = 2, + out_proj_type = "linear", + conv_type = "1conv", + upsampler = "nearest+conv", # Change + ).cuda() + + else: + print("This weight is not supported") + os._exit(0) + + + generator.load_state_dict(weight) + generator = generator.eval().cuda() + + + num_params = 0 + for p in generator.parameters(): + if p.requires_grad: + num_params += p.numel() + print(f"Number of parameters {num_params / 10 ** 6: 0.2f}") + + + return generator diff --git a/tools/4x_crop.py b/tools/4x_crop.py new file mode 100644 index 0000000000000000000000000000000000000000..8efd9a2d3f88cea50b881a3d7135fc12a9a7ba3f --- /dev/null +++ b/tools/4x_crop.py @@ -0,0 +1,47 @@ +import os, sys, cv2, shutil, argparse + + + +if __name__ == "__main__": + + # Parse variables available + parser = argparse.ArgumentParser() + parser.add_argument('-i', '--input_dir', type = str) + parser.add_argument('-o', '--store_dir', type = str) + args = parser.parse_args() + + input_dir = args.input_dir + store_dir = args.store_dir + + + print("We are cropping the image for 4x scale such that it is suitable for video compression") + + # Check file + if 
os.path.exists(store_dir): + shutil.rmtree(store_dir) + os.makedirs(store_dir) + + # Process + for file_name in sorted(os.listdir(input_dir)): + need_reszie = False + source_path = os.path.join(input_dir, file_name) + destination_path = os.path.join(store_dir, file_name) + + img = cv2.imread(source_path) + h, w, c = img.shape + + if h % 8 != 0: + print("We need vertical resize") + need_reszie = True + img = img[:8 * (h // 8), :, :] + + if w % 8 != 0: + print("We need horizontal resize") + need_reszie = True + img = img[:, :8 * (w // 8), :] + + + if need_reszie: + cv2.imwrite(destination_path, img, [cv2.IMWRITE_PNG_COMPRESSION, 0]) + else: + shutil.copy(source_path, destination_path) \ No newline at end of file diff --git a/tools/720P_resize.py b/tools/720P_resize.py new file mode 100644 index 0000000000000000000000000000000000000000..06e0375b34b25a324fc4691b9155348866ea7c7f --- /dev/null +++ b/tools/720P_resize.py @@ -0,0 +1,44 @@ +import os, cv2, shutil, argparse + + +if __name__ == "__main__": + + # Parse variables available + parser = argparse.ArgumentParser() + parser.add_argument('-i', '--input_dir', type = str) + parser.add_argument('-o', '--store_dir', type = str) + args = parser.parse_args() + + input_dir = args.input_dir + store_dir = args.store_dir + + print("We are doing the 720p Resize check!") + + # File Check + if os.path.exists(store_dir): + shutil.rmtree(store_dir) + os.makedirs(store_dir) + + scale = 4 + num = 0 + for file_name in sorted(os.listdir(input_dir)): + source_path = os.path.join(input_dir, file_name) + destination_path = os.path.join(store_dir, file_name) + img = cv2.imread(source_path) + h,w,c = img.shape + + if h == 720: + # It is already 720P so we directly move them + shutil.copy(source_path, destination_path) + continue + elif h < 720: + print("It is weird that there is an image with height less than 720 ", file_name) + break + + # Else, here we need to resize them (All resize to 720P) + + new_w = int(w*(720/h)) + img_bicubic = cv2.resize(img, (new_w, 720), interpolation=cv2.INTER_CUBIC) + cv2.imwrite(os.path.join(store_dir, file_name), img_bicubic, [cv2.IMWRITE_PNG_COMPRESSION, 0]) + + print("The total resize num is ", num) \ No newline at end of file diff --git a/tools/clean_weight_info.py b/tools/clean_weight_info.py new file mode 100644 index 0000000000000000000000000000000000000000..59dccb9123a99e246307117ace751322a0d934b9 --- /dev/null +++ b/tools/clean_weight_info.py @@ -0,0 +1,34 @@ +''' + Clean uncessary information in the weight (*.pth) +''' +import torch + + +if __name__ == "__main__": + weight_path = "saved_models/esrgan_best_generator.pth" + store_path = "1x_APISR_RRDB_GAN_generator.pth" + + # Load the checkpoint + checkpoint_g = torch.load(weight_path) + keys = [] + for key in checkpoint_g: + keys.append(key) + print(key) + for key in keys: + if key != "model_state_dict": + del checkpoint_g[key] + + + # Access the weight + old_keys = [key for key in checkpoint_g['model_state_dict']] + for old_key in old_keys: + if old_key[:10] == "_orig_mod.": + new_key = old_key[10:] + checkpoint_g['model_state_dict'][new_key] = checkpoint_g['model_state_dict'][old_key] + del checkpoint_g['model_state_dict'][old_key] + + torch.save(checkpoint_g, store_path) + + + + diff --git a/tools/compress.py b/tools/compress.py new file mode 100644 index 0000000000000000000000000000000000000000..0d3435ecada9fc2e3832d7c3911d6a4c3f4db2e9 --- /dev/null +++ b/tools/compress.py @@ -0,0 +1,38 @@ +''' + This file is to help us make figure of CRF vs Preset in Video Compression 
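+    For every input frame folder, it encodes the frames with ffmpeg at each (CRF, preset) combination and then splits the encoded video back into frames for comparison.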
+''' +import os, sys, shutil + +def compress(input_folder, codec, crf, preset): + video_store_name = "compressed.mp4" + store_dir = input_folder + "_crf" + str(crf) + "_" + preset + + if os.path.exists(video_store_name): + os.remove(video_store_name) + if os.path.exists(store_dir): + shutil.rmtree(store_dir) + os.makedirs(store_dir) + + + # Encode + os.system("ffmpeg -r 30 -f image2 -i " + input_folder + "/%d.png -vcodec " + codec + " -crf " + str(crf) + " -preset " + preset + " -pix_fmt yuv420p " + video_store_name) + + # Split to frames + os.system("ffmpeg -i " + video_store_name + " " + store_dir + "/test_%06d.png") + + + +if __name__ == "__main__": + input_folders = ["ReadySetGo", "Jockey"] + codec = "libx264" + crf_ranges = [25 + 5*i for i in range(6)] + preset_ranges = ["ultrafast", "veryfast", "fast", "medium", "slow", "veryslow", "placebo"] + + + for input_folder in input_folders: + for crf in crf_ranges: + for preset in preset_ranges: + print("We are handling {} with crf {} with preset {}".format(input_folder, crf, preset)) + compress(input_folder, codec, crf, preset) + + \ No newline at end of file diff --git a/tools/select_N_frame.py b/tools/select_N_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..b89938332aeccaafa8753e907f583e222776d6b5 --- /dev/null +++ b/tools/select_N_frame.py @@ -0,0 +1,19 @@ +import os, sys, shutil + +parnet_dir = "/media/hikaridawn/w/AVC_train_all" +save_dir = "/media/hikaridawn/w/AVC_train_select5" +if os.path.exists(save_dir): + shutil.rmtree(save_dir) +os.makedirs(save_dir) + +select_num = 5 + +gap = 100//select_num + +for idx, img_name in enumerate(sorted(os.listdir(parnet_dir))): + if idx % gap != 0: + continue + + source_path = os.path.join(parnet_dir, img_name) + destination_path = os.path.join(save_dir, img_name) + shutil.copy(source_path, destination_path) \ No newline at end of file diff --git a/tools/video_dataset_select_frame.py b/tools/video_dataset_select_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..67768c0d005d344e87299939f2df17433132f134 --- /dev/null +++ b/tools/video_dataset_select_frame.py @@ -0,0 +1,67 @@ +''' + The purpose of this file is to select first, second, and the last frame from the video datasets. 
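+    In the implementation below, the first, middle, and last frame of each sequence are kept, and frames larger than roughly 720p are split into two crops before saving.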
+''' + +import os, sys, shutil, cv2 + +dirs = [ + # "../datasets/VideoLQ", + # "../datasets/REDS_blur_MPEG", + "../datasets_real/AVC-RealLQ", +] +store_dirs = [ + # "../datasets/VideoLQ_select", + # "../datasets/REDS_blur_MPEG_select", + "AVC", +] +crop_large_img = True # If the image is larger than 720p, we will first crop them +assert(len(dirs) == len(store_dirs)) + + + +# Iterate each dataset +for idx, parent_dir in enumerate(dirs): + print("This dir is ", parent_dir) + + # Make new dir + store_dir = store_dirs[idx] + if os.path.exists(store_dir): + shutil.rmtree(store_dir) + os.makedirs(store_dir) + + # Iterate to Sub Folder sequence + for sub_folder in sorted(os.listdir(parent_dir)): + folder_dir = os.path.join(parent_dir, sub_folder) + + # Find all image paths + image_paths = [] + for img_name in sorted(os.listdir(folder_dir)): + if img_name.split('.')[-1] in ['jpg', 'png']: + # Sometimes the folder may contain unneeded info, we don't consider them + image_paths.append(img_name) + image_paths = sorted(image_paths) + + # Find three frames (First, Middle, Last) + first, middle, last = image_paths[0], image_paths[len(image_paths)//2], image_paths[-1] + print("First, Middle, Last image name is ", first, middle, last) + + # Save the three images + for img_name in [first, middle, last]: + input_name = os.path.join(folder_dir, img_name) + + img = cv2.imread(input_name) + h, w, _ = img.shape + if crop_large_img and h*w > 720*1080: + # This means that this image is too big we need to crop them + print("We will use cropping for images that is too large") + crop1 = img[:,:w//2,:] + crop2 = img[:,w//2:,:] + + store_name1 = os.path.join(store_dir, sub_folder + "_crop1_"+ img_name) + store_name2 = os.path.join(store_dir, sub_folder + "_crop2_"+ img_name) + + cv2.imwrite(store_name1, crop1) + cv2.imwrite(store_name2, crop2) + else: + store_name = os.path.join(store_dir, sub_folder + "_" + img_name) + shutil.copy(input_name, store_name) diff --git a/tools/video_dataset_select_one.py b/tools/video_dataset_select_one.py new file mode 100644 index 0000000000000000000000000000000000000000..1784dea361eefa4c150a5d087ae98bfbb6559346 --- /dev/null +++ b/tools/video_dataset_select_one.py @@ -0,0 +1,21 @@ +import os, shutil + +dir = "../datasets/VideoLQ" +store_dir = "../datasets/VideoLQ_select_one" +if os.path.exists(store_dir): + shutil.rmtree(store_dir) +os.makedirs(store_dir) + + +search_idx = 0 +for sub_folder_name in sorted(os.listdir(dir)): + sub_folder_dir = os.path.join(dir, sub_folder_name) + for idx, img_name in enumerate(sorted(os.listdir(sub_folder_dir))): + if idx != search_idx: + continue + img_path = os.path.join(sub_folder_dir, img_name) + target_path = os.path.join(store_dir, img_name) + + shutil.copy(img_path, target_path) + + search_idx += 1 \ No newline at end of file diff --git a/train_code/train.py b/train_code/train.py new file mode 100644 index 0000000000000000000000000000000000000000..9496790d0b873a0ded9a1338d3cf4f570cadd58a --- /dev/null +++ b/train_code/train.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- + +import argparse +import os, shutil, sys +import time +import warnings + +warnings.filterwarnings("ignore") + +# import from local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from opt import opt + + +def storage_manage(): + if not os.path.exists("runs_last/"): + os.makedirs("runs_last/") + + # copy to the new address + new_address = "runs_last/"+str(int(time.time()))+"/" + shutil.copytree("runs/", new_address) + + +def parse_args(): + parser = 
argparse.ArgumentParser() + parser.add_argument('--auto_resume_closest', action='store_true') + parser.add_argument('--auto_resume_best', action='store_true') + parser.add_argument('--pretrained_path', type = str, default="") + + global args + args = parser.parse_args() + + + if args.auto_resume_closest and args.auto_resume_best: + print("you could only resume either nearest or best, not both") + os._exit(0) + + + + if not args.auto_resume_closest and not args.auto_resume_best: + # Restart tensorboard (delete all things under ./runs) + if os.path.exists("./runs"): + storage_manage() + shutil.rmtree("./runs") + + +def folder_prepare(): + def _make_folder(folder_name): + if not os.path.exists(folder_name): + os.makedirs(folder_name) + + def _delete_and_make_folder(folder_name): + if os.path.exists(folder_name): + shutil.rmtree(folder_name) + os.makedirs(folder_name) + + # The lists we care about + make_folder_name_lists = ["saved_models/", "saved_models/checkpoints/", "datasets/"] + delete_and_make_folder_name_lists = [] + + for folder_name in make_folder_name_lists: + _make_folder(folder_name) + + for folder_name in delete_and_make_folder_name_lists: + _delete_and_make_folder(folder_name) + + + +def process(options): + print(args) + start = time.time() + + # Switch based on the model architecture + if options['architecture'] == "ESRNET": + from train_esrnet import train_esrnet + obj = train_esrnet(options, args) + elif options['architecture'] == "ESRGAN": + from train_esrgan import train_esrgan + obj = train_esrgan(options, args) + elif options['architecture'] == "GRL": + from train_grl import train_grl + obj = train_grl(options, args) + elif options['architecture'] == "GRLGAN": + from train_grlgan import train_grlgan + obj = train_grlgan(options, args) + elif options['architecture'] == "CUNET": + from train_cunet import train_cunet + obj = train_cunet(options, args) + elif options['architecture'] == "CUGAN": + from train_cugan import train_cugan + obj = train_cugan(options, args) + else: + raise NotImplementedError("This is not a supported model architecture") + + + obj.run() + + total_time = time.time() - start + print("All programs spent {} hour {} min {} s".format(str(total_time//3600), str((total_time%3600)//60), str(total_time%3600))) + + +def main(): + parse_args() + + folder_prepare() + process(opt) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/train_code/train_cugan.py b/train_code/train_cugan.py new file mode 100644 index 0000000000000000000000000000000000000000..9623315261f26ce4f9ae491c9ae77f0e62a86077 --- /dev/null +++ b/train_code/train_cugan.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- + +import sys +import os +import torch + +# import important files +root_path = os.path.abspath('.') +sys.path.append(root_path) +from architecture.cunet import UNet_Full +from architecture.discriminator import UNetDiscriminatorSN +from train_code.train_master import train_master + + + +class train_cugan(train_master): + def __init__(self, options, args) -> None: + super().__init__(options, args, "cugan", True) # Pass a model name unique code + + + def loss_init(self): + + # prepare pixel loss (Generator) + self.pixel_loss_load() + + # prepare perceptual loss + self.GAN_loss_load() + + + def call_model(self): + self.generator = UNet_Full().cuda() + # self.generator = torch.compile(self.generator).cuda() + self.discriminator = UNetDiscriminatorSN(3).cuda() + # self.discriminator = torch.compile(self.discriminator).cuda() + self.generator.train(); 
self.discriminator.train() + + + def run(self): + self.master_run() + + + def calculate_loss(self, gen_hr, imgs_hr): + + ###################### We have 3 losses on Generator ###################### + # Generator Pixel loss (l1 loss): generated vs. GT + l_g_pix = self.cri_pix(gen_hr, imgs_hr) + self.generator_loss += l_g_pix + self.weight_store["pixel_loss"] = l_g_pix + + + # Generator perceptual loss: generated vs. perceptual + l_g_percep_danbooru = self.cri_danbooru_perceptual(gen_hr, imgs_hr) + l_g_percep_vgg = self.cri_vgg_perceptual(gen_hr, imgs_hr) + l_g_percep = l_g_percep_danbooru + l_g_percep_vgg + self.generator_loss += l_g_percep + self.weight_store["perceptual_loss"] = l_g_percep + + + # Generator GAN loss label correction + fake_g_preds = self.discriminator(gen_hr) + l_g_gan = self.cri_gan(fake_g_preds, True, is_disc=False) # loss_weight (self.gan_loss_weight) is included + self.generator_loss += l_g_gan + self.weight_store["gan_loss"] = l_g_gan # Already with gan_loss_weight (0.2/1) + + + def tensorboard_report(self, iteration): + self.writer.add_scalar('Loss/train-Generator_Loss-Iteration', self.generator_loss, iteration) + self.writer.add_scalar('Loss/train-Pixel_Loss-Iteration', self.weight_store["pixel_loss"], iteration) + self.writer.add_scalar('Loss/train-Perceptual_Loss-Iteration', self.weight_store["perceptual_loss"], iteration) + self.writer.add_scalar('Loss/train-Discriminator_Loss-Iteration', self.weight_store["gan_loss"], iteration) + diff --git a/train_code/train_cunet.py b/train_code/train_cunet.py new file mode 100644 index 0000000000000000000000000000000000000000..56123561a649216facdcd0aea76b4729d0ef99e5 --- /dev/null +++ b/train_code/train_cunet.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +import sys +import os +import torch + + +# Import important files +root_path = os.path.abspath('.') +sys.path.append(root_path) +from architecture.cunet import UNet_Full # This place need to adjust for different models +from train_code.train_master import train_master + + + +# Mixed precision training +scaler = torch.cuda.amp.GradScaler() + + +class train_cunet(train_master): + def __init__(self, options, args) -> None: + super().__init__(options, args, "cunet") # Pass a model name unique code + + + def loss_init(self): + # Prepare pixel loss + self.pixel_loss_load() + + + def call_model(self): + # Generator Prepare (Don't formet torch.compile if needed) + self.generator = UNet_Full().cuda() # Cunet only support 2x SR + # self.generator = torch.compile(self.generator).cuda() + self.generator.train() + + + def run(self): + self.master_run() + + + + def calculate_loss(self, gen_hr, imgs_hr): + + # Generator pixel loss (l1 loss): generated vs. 
GT + l_g_pix = self.cri_pix(gen_hr, imgs_hr, self.batch_idx) + self.weight_store["pixel_loss"] = l_g_pix + self.generator_loss += l_g_pix + + + def tensorboard_report(self, iteration): + # self.writer.add_scalar('Loss/train-Generator_Loss-Iteration', self.generator_loss, iteration) + self.writer.add_scalar('Loss/train-Pixel_Loss-Iteration', self.weight_store["pixel_loss"], iteration) diff --git a/train_code/train_esrgan.py b/train_code/train_esrgan.py new file mode 100644 index 0000000000000000000000000000000000000000..d9aaa503e15595bb53659fc527621fe9e727c3c1 --- /dev/null +++ b/train_code/train_esrgan.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- + +import sys +import os +import torch + +# import important files +root_path = os.path.abspath('.') +sys.path.append(root_path) +from architecture.rrdb import RRDBNet +from architecture.discriminator import UNetDiscriminatorSN +from train_code.train_master import train_master + + + +class train_esrgan(train_master): + def __init__(self, options, args) -> None: + super().__init__(options, args, "esrgan", True) # Pass a model name unique code + + + def loss_init(self): + + # prepare pixel loss (Generator) + self.pixel_loss_load() + + # prepare perceptual loss + self.GAN_loss_load() + + + def call_model(self): + # Generator + self.generator = RRDBNet(3, 3, scale=self.options['scale'], num_block=self.options['ESR_blocks_num']).cuda() + # self.generator = torch.compile(self.generator).cuda() + self.discriminator = UNetDiscriminatorSN(3).cuda() + # self.discriminator = torch.compile(self.discriminator).cuda() + self.generator.train(); self.discriminator.train() + + + def run(self): + self.master_run() + + + + def calculate_loss(self, gen_hr, imgs_hr): + + ###################### We have 3 losses on Generator ###################### + # Generator Pixel loss (l1 loss): generated vs. GT + l_g_pix = self.cri_pix(gen_hr, imgs_hr) + self.generator_loss += l_g_pix + self.weight_store["pixel_loss"] = l_g_pix + + + # Generator perceptual loss: generated vs. 
perceptual + l_g_percep_danbooru = self.cri_danbooru_perceptual(gen_hr, imgs_hr) + l_g_percep_vgg = self.cri_vgg_perceptual(gen_hr, imgs_hr) + l_g_percep = l_g_percep_danbooru + l_g_percep_vgg + self.generator_loss += l_g_percep + self.weight_store["perceptual_loss"] = l_g_percep + + + # Generator GAN loss label correction + fake_g_preds = self.discriminator(gen_hr) + l_g_gan = self.cri_gan(fake_g_preds, True, is_disc=False) # loss_weight (self.gan_loss_weight) is included + self.generator_loss += l_g_gan + self.weight_store["gan_loss"] = l_g_gan # Already with gan_loss_weight (0.2/1) + + + def tensorboard_report(self, iteration): + self.writer.add_scalar('Loss/train-Generator_Loss-Iteration', self.generator_loss, iteration) + self.writer.add_scalar('Loss/train-Pixel_Loss-Iteration', self.weight_store["pixel_loss"], iteration) + self.writer.add_scalar('Loss/train-Perceptual_Loss-Iteration', self.weight_store["perceptual_loss"], iteration) + self.writer.add_scalar('Loss/train-Discriminator_Loss-Iteration', self.weight_store["gan_loss"], iteration) + diff --git a/train_code/train_esrnet.py b/train_code/train_esrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a67a99f152988b6139b975a83fa1fe476c333758 --- /dev/null +++ b/train_code/train_esrnet.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +import sys +import os +import torch + + +# Import important files +root_path = os.path.abspath('.') +sys.path.append(root_path) +from architecture.rrdb import RRDBNet +from train_code.train_master import train_master + + + +# Mixed precision training +scaler = torch.cuda.amp.GradScaler() + + +class train_esrnet(train_master): + def __init__(self, options, args) -> None: + super().__init__(options, args, "esrnet") # Pass a model name unique code + + + def loss_init(self): + # Prepare pixel loss + self.pixel_loss_load() + + + def call_model(self): + # Generator Prepare (Don't formet torch.compile if needed) + self.generator = RRDBNet(3, 3, scale=self.options['scale'], num_block=self.options['ESR_blocks_num']).cuda() + # self.generator = torch.compile(self.generator).cuda() + self.generator.train() + + + def run(self): + self.master_run() + + + + def calculate_loss(self, gen_hr, imgs_hr): + + # Generator pixel loss (l1 loss): generated vs. 
GT + l_g_pix = self.cri_pix(gen_hr, imgs_hr, self.batch_idx) + self.weight_store["pixel_loss"] = l_g_pix + self.generator_loss += l_g_pix + + + def tensorboard_report(self, iteration): + # self.writer.add_scalar('Loss/train-Generator_Loss-Iteration', self.generator_loss, iteration) + self.writer.add_scalar('Loss/train-Pixel_Loss-Iteration', self.weight_store["pixel_loss"], iteration) diff --git a/train_code/train_grl.py b/train_code/train_grl.py new file mode 100644 index 0000000000000000000000000000000000000000..0900c63db415a24d6b780d02420db5e0618ea7e7 --- /dev/null +++ b/train_code/train_grl.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +import sys +import os +import torch + + +# Import files from the local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from opt import opt +from architecture.grl import GRL # This place need to adjust for different models +from train_code.train_master import train_master + + + +# Mixed precision training +scaler = torch.cuda.amp.GradScaler() + + +class train_grl(train_master): + def __init__(self, options, args) -> None: + super().__init__(options, args, "grl") # Pass a model name unique code + + + def loss_init(self): + # Prepare pixel loss + self.pixel_loss_load() + + + def call_model(self): + patch_size = 144 + window_size = 8 + + if opt['model_size'] == "small": + # GRL small model + self.generator = GRL( + upscale = opt['scale'], + img_size = patch_size, + window_size = 8, + depths = [4, 4, 4, 4], + embed_dim = 128, + num_heads_window = [2, 2, 2, 2], + num_heads_stripe = [2, 2, 2, 2], + mlp_ratio = 2, + qkv_proj_type = "linear", + anchor_proj_type = "avgpool", + anchor_window_down_factor = 2, + out_proj_type = "linear", + conv_type = "1conv", + upsampler = "pixelshuffle", + ).cuda() + + elif opt['model_size'] == "tiny": + # GRL tiny model + self.generator = GRL( + upscale = opt['scale'], + img_size = 64, + window_size = 8, + depths = [4, 4, 4, 4], + embed_dim = 64, + num_heads_window = [2, 2, 2, 2], + num_heads_stripe = [2, 2, 2, 2], + mlp_ratio = 2, + qkv_proj_type = "linear", + anchor_proj_type = "avgpool", + anchor_window_down_factor = 2, + out_proj_type = "linear", + conv_type = "1conv", + upsampler = "pixelshuffledirect", + ).cuda() + + + elif opt['model_size'] == "tiny2": + # GRL tiny model + self.generator = GRL( + upscale = opt['scale'], + img_size = 64, + window_size = 8, + depths = [4, 4, 4, 4], + embed_dim = 64, + num_heads_window = [2, 2, 2, 2], + num_heads_stripe = [2, 2, 2, 2], + mlp_ratio = 2, + qkv_proj_type = "linear", + anchor_proj_type = "avgpool", + anchor_window_down_factor = 2, + out_proj_type = "linear", + conv_type = "1conv", + upsampler = "nearest+conv", # Change + ).cuda() + + else: + raise NotImplementedError("We don't support such model size in GRL model") + + # self.generator = torch.compile(self.generator).cuda() # Don't use this for 3090Ti + self.generator.train() + + + def run(self): + self.master_run() + + + + def calculate_loss(self, gen_hr, imgs_hr): + # Define the loss function here + + # Generator pixel loss (l1 loss): generated vs. 
GT + l_g_pix = self.cri_pix(gen_hr, imgs_hr, self.batch_idx) + self.weight_store["pixel_loss"] = l_g_pix + self.generator_loss += l_g_pix + + + def tensorboard_report(self, iteration): + # self.writer.add_scalar('Loss/train-Generator_Loss-Iteration', self.generator_loss, iteration) + self.writer.add_scalar('Loss/train-Pixel_Loss-Iteration', self.weight_store["pixel_loss"], iteration) diff --git a/train_code/train_grlgan.py b/train_code/train_grlgan.py new file mode 100644 index 0000000000000000000000000000000000000000..36803704c55ed176f572ff109a3ceb863618091d --- /dev/null +++ b/train_code/train_grlgan.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- + +import sys +import os +import torch + +# import important files +root_path = os.path.abspath('.') +sys.path.append(root_path) +from opt import opt +from architecture.grl import GRL +from architecture.discriminator import UNetDiscriminatorSN, MultiScaleDiscriminator +from train_code.train_master import train_master + + + +class train_grlgan(train_master): + def __init__(self, options, args) -> None: + super().__init__(options, args, "grlgan", True) # Pass a model name unique code + + + def loss_init(self): + + # prepare pixel loss (Generator) + self.pixel_loss_load() + + # prepare perceptual loss + self.GAN_loss_load() + + + def call_model(self): + # Generator: GRL Small + patch_size = 144 + if opt['model_size'] == "small": + # GRL small model + self.generator = GRL( + upscale = opt['scale'], + img_size = patch_size, + window_size = 8, + depths = [4, 4, 4, 4], + embed_dim = 128, + num_heads_window = [2, 2, 2, 2], + num_heads_stripe = [2, 2, 2, 2], + mlp_ratio = 2, + qkv_proj_type = "linear", + anchor_proj_type = "avgpool", + anchor_window_down_factor = 2, + out_proj_type = "linear", + conv_type = "1conv", + upsampler = "pixelshuffle", + ).cuda() + + elif opt['model_size'] == "tiny": + # GRL tiny model + self.generator = GRL( + upscale = opt['scale'], + img_size = 64, + window_size = 8, + depths = [4, 4, 4, 4], + embed_dim = 64, + num_heads_window = [2, 2, 2, 2], + num_heads_stripe = [2, 2, 2, 2], + mlp_ratio = 2, + qkv_proj_type = "linear", + anchor_proj_type = "avgpool", + anchor_window_down_factor = 2, + out_proj_type = "linear", + conv_type = "1conv", + upsampler = "pixelshuffledirect", + ).cuda() + + elif opt['model_size'] == "tiny2": + # GRL tiny model + self.generator = GRL( + upscale = opt['scale'], + img_size = 64, + window_size = 8, + depths = [4, 4, 4, 4], + embed_dim = 64, + num_heads_window = [2, 2, 2, 2], + num_heads_stripe = [2, 2, 2, 2], + mlp_ratio = 2, + qkv_proj_type = "linear", + anchor_proj_type = "avgpool", + anchor_window_down_factor = 2, + out_proj_type = "linear", + conv_type = "1conv", + upsampler = "nearest+conv", # Change + ).cuda() + + else: + raise NotImplementedError("We don't support such model size in GRL model") + # self.generator = torch.compile(self.generator).cuda() + + # Discriminator + if opt['discriminator_type'] == "PatchDiscriminator": + self.discriminator = MultiScaleDiscriminator(3).cuda() + elif opt['discriminator_type'] == "UNetDiscriminator": + self.discriminator = UNetDiscriminatorSN(3).cuda() + + self.generator.train(); self.discriminator.train() + + + def run(self): + self.master_run() + + + + def calculate_loss(self, gen_hr, imgs_hr): + + ###################### We have 3 losses on Generator ###################### + # Generator Pixel loss (l1 loss): generated vs. 
GT + l_g_pix = self.cri_pix(gen_hr, imgs_hr) + self.generator_loss += l_g_pix + self.weight_store["pixel_loss"] = l_g_pix + + + # Generator perceptual loss: generated vs. perceptual + l_g_percep_danbooru = self.cri_danbooru_perceptual(gen_hr, imgs_hr) + l_g_percep_vgg = self.cri_vgg_perceptual(gen_hr, imgs_hr) + l_g_percep = l_g_percep_danbooru + l_g_percep_vgg + self.generator_loss += l_g_percep + self.weight_store["perceptual_loss"] = l_g_percep + + + # Generator GAN loss label correction + fake_g_preds = self.discriminator(gen_hr) + l_g_gan = self.cri_gan(fake_g_preds, True, is_disc=False) # loss_weight (self.gan_loss_weight) is included + self.generator_loss += l_g_gan + self.weight_store["gan_loss"] = l_g_gan # Already with gan_loss_weight (0.2/1) + + + def tensorboard_report(self, iteration): + self.writer.add_scalar('Loss/train-Generator_Loss-Iteration', self.generator_loss, iteration) + self.writer.add_scalar('Loss/train-Pixel_Loss-Iteration', self.weight_store["pixel_loss"], iteration) + self.writer.add_scalar('Loss/train-Perceptual_Loss-Iteration', self.weight_store["perceptual_loss"], iteration) + self.writer.add_scalar('Loss/train-Discriminator_Loss-Iteration', self.weight_store["gan_loss"], iteration) + diff --git a/train_code/train_master.py b/train_code/train_master.py new file mode 100644 index 0000000000000000000000000000000000000000..353fb3c01dee186a6e621de033943ec865f1a0f4 --- /dev/null +++ b/train_code/train_master.py @@ -0,0 +1,368 @@ +# -*- coding: utf-8 -*- + +import os, sys +import torch +import glob +import time, shutil +import math +import gc +from tqdm import tqdm +from collections import defaultdict + +# torch module import +from torch.multiprocessing import Pool, Process, set_start_method +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader + + +try: + set_start_method('spawn') +except RuntimeError: + pass + + +# import files from local folder +root_path = os.path.abspath('.') +sys.path.append(root_path) +from loss.gan_loss import GANLoss, MultiScaleGANLoss +from loss.pixel_loss import PixelLoss, L1_Charbonnier_loss +from loss.perceptual_loss import PerceptualLoss +from loss.anime_perceptual_loss import Anime_PerceptualLoss +from architecture.dataset import ImageDataset +from scripts.generate_lr_esr import generate_low_res_esr + + +# Mixed precision training +scaler = torch.cuda.amp.GradScaler() + +class train_master(object): + def __init__(self, options, args, model_name, has_discriminator=False) -> None: + # General specs setup + self.args = args + self.model_name = model_name + self.options = options + self.has_discriminator = has_discriminator + + # Loss init + self.loss_init() + + # Generator + self.call_model() # generator + discriminator... 
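+        # Both the generator and the (optional) discriminator use plain Adam and share the same scheduled learning rate (see adjust_learning_rate below)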
+ + # Optimizer + self.learning_rate = options['start_learning_rate'] + self.optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=self.learning_rate, betas=(options["adam_beta1"], options["adam_beta2"])) + if self.has_discriminator: + self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.learning_rate, betas=(self.options["adam_beta1"], self.options["adam_beta2"])) + + # Train specs + self.start_iteration = 0 + self.lowest_generator_loss = float("inf") + + # Other auxiliary function + self.writer = SummaryWriter() + self.weight_store = defaultdict(int) + + # Options setting + self.n_iterations = options['train_iterations'] + self.batch_size = options['train_batch_size'] + self.n_cpu = options['train_dataloader_workers'] + + + def adjust_learning_rate(self, iteration_idx): + self.learning_rate = self.options['start_learning_rate'] + end_iteration = self.options['train_iterations'] + + # Calculate a learning rate we need in real-time based on the iteration_idx + for idx in range(min(end_iteration, iteration_idx)//self.options['decay_iteration']): + idx = idx+1 + if idx * self.options['decay_iteration'] in self.options['double_milestones']: + # double the learning rate in milestones + self.learning_rate = self.learning_rate * 2 + else: + # else, try to multiply decay_gamma (when we decay, we won't upscale) + self.learning_rate = self.learning_rate * self.options['decay_gamma'] # should be divisible in all cases + + # Change the learning rate to our target + for param_group in self.optimizer_g.param_groups: + param_group['lr'] = self.learning_rate + + if self.has_discriminator: + # print("We didn't yet handle discriminator, but we think that it should be necessary") + for param_group in self.optimizer_d.param_groups: + param_group['lr'] = self.learning_rate + + assert(self.learning_rate == self.optimizer_g.param_groups[0]['lr']) + + + def pixel_loss_load(self): + if self.options['pixel_loss'] == "L1": + self.cri_pix = PixelLoss().cuda() + elif self.options['pixel_loss'] == "L1_Charbonnier": + self.cri_pix = L1_Charbonnier_loss().cuda() + + print("We are using {} loss".format(self.options['pixel_loss'])) + + + def GAN_loss_load(self): + # parameter init + gan_loss_weight = self.options["gan_loss_weight"] + vgg_type = self.options['train_perceptual_vgg_type'] + + # Preceptual Loss + self.cri_pix = torch.nn.L1Loss().cuda() + self.cri_vgg_perceptual = PerceptualLoss(self.options['train_perceptual_layer_weights'], vgg_type, perceptual_weight=self.options["vgg_perceptual_loss_weight"]).cuda() + self.cri_danbooru_perceptual = Anime_PerceptualLoss(self.options["Danbooru_layer_weights"], perceptual_weight=self.options["danbooru_perceptual_loss_weight"]).cuda() + + # GAN loss + if self.options['discriminator_type'] == "PatchDiscriminator": + self.cri_gan = MultiScaleGANLoss(gan_type="lsgan", loss_weight=gan_loss_weight).cuda() # already put in loss scaler for discriminator + elif self.options['discriminator_type'] == "UNetDiscriminator": + self.cri_gan = GANLoss(gan_type="vanilla", loss_weight=gan_loss_weight).cuda() # already put in loss scaler for discriminator + + def tensorboard_epoch_draw(self, epoch_loss, epoch): + self.writer.add_scalar('Loss/train-Loss-Epoch', epoch_loss, epoch) + + + def master_run(self): + torch.backends.cudnn.benchmark = True + print("options are ", self.options) + + # Generate a new LR dataset before doing anything (Must before Data Loading) + self.generate_lr() + + # Load data + train_lr_paths = glob.glob(self.options["lr_dataset_path"] + 
"/*.*") + degrade_hr_paths = glob.glob(self.options["degrade_hr_dataset_path"] + "/*.*") + train_hr_paths = glob.glob(self.options["train_hr_dataset_path"] + "/*.*") + train_dataloader = DataLoader(ImageDataset(train_lr_paths, degrade_hr_paths, train_hr_paths), batch_size=self.batch_size, shuffle=True, num_workers=self.n_cpu) # ONLY LOAD HALF OF CPU AVAILABLE + dataset_length = len(os.listdir(self.options["train_hr_dataset_path"])) + + + # Check if we need to load weight + if self.args.auto_resume_best or self.args.auto_resume_closest: + self.load_weight(self.model_name) + elif self.args.pretrained_path != "": # If we give a pretrained path, we will use it (Should have in GAN training which uses pretrained L1 loss Network) + self.load_pretrained(self.model_name) + + # Start iterating the epochs + start_epoch = self.start_iteration // math.ceil(dataset_length / self.options['train_batch_size']) + n_epochs = self.n_iterations // math.ceil(dataset_length / self.options['train_batch_size']) + iteration_idx = self.start_iteration # init the iteration index + self.batch_idx = iteration_idx + self.adjust_learning_rate(iteration_idx) # adjust the learning rate to the desired one at the beginning + + for epoch in range(start_epoch, n_epochs): + print("This is epoch {} and the start iteration is {} with learning rate {}".format(epoch, iteration_idx, self.optimizer_g.param_groups[0]['lr'])) + + # Generate new lr degradation image + if epoch != start_epoch and epoch % self.options['degradate_generation_freq'] == 0: + self.generate_lr() + + # Batch training + loss_per_epoch = 0.0 + self.generator.train() + tqdm_bar = tqdm(train_dataloader, total=len(train_dataloader)) + for batch_idx, imgs in enumerate(tqdm_bar): + + imgs_lr = imgs["lr"].cuda() + imgs_degrade_hr = imgs["degrade_hr"].cuda() + imgs_hr = imgs["hr"].cuda() + + # Used for each iteration + self.generator_loss = 0 + self.single_iteration(imgs_lr, imgs_degrade_hr, imgs_hr) + + # tensorboard and updates + self.tensorboard_report(iteration_idx) + loss_per_epoch += self.generator_loss.item() + + ################################# Save model weights and update hyperparameter ######################################## + if self.lowest_generator_loss >= self.generator_loss.item(): + self.lowest_generator_loss = self.generator_loss.item() + print("\nSave model with the lowest generator_loss among all iteartions ", self.lowest_generator_loss) + + # Store the best + self.save_weight(iteration_idx, self.model_name+"_best", self.options) + + self.lowest_tensorboard_report(iteration_idx) + + # Update iteration and learning rate + iteration_idx += 1 + self.batch_idx = iteration_idx + if iteration_idx % self.options['decay_iteration'] == 0: + self.adjust_learning_rate(iteration_idx) # adjust the learning rate to the desired one + print("Update the learning rate to {} at iteration {} ".format(self.optimizer_g.param_groups[0]['lr'], iteration_idx)) + + # Don't clean any memory here, it will dramatically slow down the code + + # Per epoch report + self.tensorboard_epoch_draw( loss_per_epoch/batch_idx, epoch) + + + # Per epoch store weight + self.save_weight(iteration_idx, self.model_name+"_closest", self.options) + # Backup Checkpoint (Per 50 epoch) + if epoch % self.options['checkpoints_freq'] == 0 or epoch == n_epochs-1: + self.save_weight(iteration_idx, "checkpoints/" + self.model_name + "_epoch_" + str(epoch), self.options) + + + # Clean unneeded GPU cache (since we use subprocess for generate_lr(), so we need to kill them all) + torch.cuda.empty_cache() + 
time.sleep(5) # For enough time to clean the cache + + + + def single_iteration(self, imgs_lr, imgs_degrade_hr, imgs_hr): + + ############################################# Generator section ################################################## + self.optimizer_g.zero_grad() + if self.has_discriminator: + for p in self.discriminator.parameters(): + p.requires_grad = False + + with torch.cuda.amp.autocast(): + # generate high res image + gen_hr = self.generator(imgs_lr) + + # all distinct loss will be stored in self.weight_store (per iteration) + self.calculate_loss(gen_hr, imgs_hr) + + # backward needed loss + # self.loss_generator_total.backward() + # self.optimizer_g.step() + scaler.scale(self.generator_loss).backward() # loss backward + scaler.step(self.optimizer_g) + scaler.update() + ################################################################################################################### + + + if self.has_discriminator: + ##################################### Discriminator section ##################################################### + for p in self.discriminator.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + + # discriminator real input + with torch.cuda.amp.autocast(): + # We only need imgs_degrade_hr instead of imgs_hr in discriminator (Thus, we don't want to introduce usm in the discriminator) + real_d_preds = self.discriminator(imgs_degrade_hr) + l_d_real = self.cri_gan(real_d_preds, True, is_disc=True) + scaler.scale(l_d_real).backward() + + + # discriminator fake input + with torch.cuda.amp.autocast(): + fake_d_preds = self.discriminator(gen_hr.detach().clone()) + l_d_fake = self.cri_gan(fake_d_preds, False, is_disc=True) + scaler.scale(l_d_fake).backward() + + # update + scaler.step(self.optimizer_d) + scaler.update() + ################################################################################################################## + + + def load_pretrained(self, name): + # This part will load generator weight here, and it doesn't need to + + weight_dir = self.args.pretrained_path + if not os.path.exists(weight_dir): + print("No such pretrained "+weight_dir+" file exists! We end the program! 
Please check the dir!") + os._exit(0) + + checkpoint_g = torch.load(weight_dir) + if 'model_state_dict' in checkpoint_g: + self.generator.load_state_dict(checkpoint_g['model_state_dict']) + elif 'params_ema' in checkpoint_g: + self.generator.load_state_dict(checkpoint_g['params_ema']) + else: + raise NotImplementedError("We cannot locate the generator weight in this pretrained checkpoint") + + print("We will use the pretrained "+name+" weight!") + + + def load_weight(self, head_prefix): + # Resume the best or the closest weight available + head = head_prefix+"_best" if self.args.auto_resume_best else head_prefix+"_closest" + + if os.path.exists("saved_models/"+head+"_generator.pth"): + print("We need to resume the previous " + head + " weight") + + # Generator + checkpoint_g = torch.load("saved_models/"+head+"_generator.pth") + self.generator.load_state_dict(checkpoint_g['model_state_dict']) + self.optimizer_g.load_state_dict(checkpoint_g['optimizer_state_dict']) + + # Discriminator + if self.has_discriminator: + checkpoint_d = torch.load("saved_models/"+head+"_discriminator.pth") + self.discriminator.load_state_dict(checkpoint_d['model_state_dict']) + self.optimizer_d.load_state_dict(checkpoint_d['optimizer_state_dict']) + assert(checkpoint_g['iteration'] == checkpoint_d['iteration']) # the generator and discriminator iterations must match + + self.start_iteration = checkpoint_g['iteration'] + 1 + + # Prepare the lowest generator loss + if os.path.exists("saved_models/" + head_prefix + "_best_generator.pth"): + checkpoint_g = torch.load("saved_models/" + head_prefix + "_best_generator.pth") # load generator weight + else: + print("There is no best weight yet!") + self.lowest_generator_loss = min(self.lowest_generator_loss, checkpoint_g["lowest_generator_weight"] ) + print("The lowest generator loss at the beginning is ", self.lowest_generator_loss) + else: + print("No saved_models/"+head+"_generator.pth or saved_models/"+head+"_discriminator.pth exists") + + + print(f"We will start from the iteration {self.start_iteration}") + + + + def save_weight(self, iteration, name, opt): + + # Generator + torch.save({ + 'iteration': iteration, + 'model_state_dict': self.generator.state_dict(), + 'optimizer_state_dict': self.optimizer_g.state_dict(), + 'lowest_generator_weight': self.lowest_generator_loss, + 'opt': opt, + }, "saved_models/" + name + "_generator.pth") + # 'pixel_loss': self.weight_store["pixel_loss"], + # 'perceptual_loss': self.weight_store['perceptual_loss'], + # 'gan_loss': self.weight_store["gan_loss"], + + + if self.has_discriminator: + # Discriminator + torch.save({ + 'iteration': iteration, + 'model_state_dict': self.discriminator.state_dict(), + 'optimizer_state_dict': self.optimizer_d.state_dict(), + }, "saved_models/" + name + "_discriminator.pth") + + + def lowest_tensorboard_report(self, iteration): + self.writer.add_scalar('Loss/lowest-weight', self.generator_loss, iteration) + + + @torch.no_grad() + def generate_lr(self): + + # If we directly call the API, PyTorch 2.0 may hit an unknown bug that makes the degradation pipeline extremely slow, so we run it as a subprocess + os.system("python scripts/generate_lr_esr.py") + + + # Assert check + lr_paths = os.listdir(self.options["lr_dataset_path"]) + degrade_hr_paths = os.listdir(self.options["degrade_hr_dataset_path"]) + hr_paths = os.listdir(self.options["train_hr_dataset_path"]) + + assert(len(lr_paths) == len(degrade_hr_paths)) + assert(len(lr_paths) == len(hr_paths)) + + + + 
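+    # Reminder of the three dataset folders checked above (all taken from opt): + #   lr_dataset_path          - degraded LR crops produced by scripts/generate_lr_esr.py + #   degrade_hr_dataset_path  - HR crops without USM sharpening (fed to the discriminator as the "real" samples) + #   train_hr_dataset_path    - USM-sharpened HR crops (the target for the pixel and perceptual losses) + # Typical entry point: python train_code/train.py [--auto_resume_closest | --auto_resume_best | --pretrained_path <path>]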