Spaces: Running on T4
HikariDawn committed
Commit: 561c629
1 Parent(s): 193b9cd

feat: initial push

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +30 -0
- __assets__/logo.png +0 -0
- __assets__/lr_inputs/41.png +0 -0
- __assets__/lr_inputs/f91.jpg +0 -0
- __assets__/lr_inputs/image-00164.jpg +0 -0
- __assets__/lr_inputs/image-00186.png +0 -0
- __assets__/lr_inputs/image-00277.png +0 -0
- __assets__/lr_inputs/image-00440.png +0 -0
- __assets__/lr_inputs/image-00542.png +0 -0
- __assets__/lr_inputs/img_eva.jpeg +0 -0
- __assets__/lr_inputs/screenshot_resize.jpg +0 -0
- __assets__/visual_results/0079_2_visual.png +0 -0
- __assets__/visual_results/0079_visual.png +0 -0
- __assets__/visual_results/eva_visual.png +0 -0
- __assets__/visual_results/f91_visual.png +0 -0
- __assets__/visual_results/kiteret_visual.png +0 -0
- __assets__/visual_results/pokemon2_visual.png +0 -0
- __assets__/visual_results/pokemon_visual.png +0 -0
- __assets__/visual_results/wataru_visual.png +0 -0
- __assets__/workflow.png +0 -0
- app.py +117 -0
- architecture/cunet.py +189 -0
- architecture/dataset.py +106 -0
- architecture/discriminator.py +241 -0
- architecture/grl.py +616 -0
- architecture/grl_common/__init__.py +8 -0
- architecture/grl_common/common_edsr.py +227 -0
- architecture/grl_common/mixed_attn_block.py +1126 -0
- architecture/grl_common/mixed_attn_block_efficient.py +568 -0
- architecture/grl_common/ops.py +551 -0
- architecture/grl_common/resblock.py +61 -0
- architecture/grl_common/swin_v1_block.py +602 -0
- architecture/grl_common/swin_v2_block.py +306 -0
- architecture/grl_common/upsample.py +50 -0
- architecture/rrdb.py +218 -0
- architecture/swinir.py +874 -0
- dataset_curation_pipeline/IC9600/ICNet.py +151 -0
- dataset_curation_pipeline/IC9600/gene.py +113 -0
- dataset_curation_pipeline/collect.py +222 -0
- degradation/ESR/degradation_esr_shared.py +180 -0
- degradation/ESR/degradations_functionality.py +785 -0
- degradation/ESR/diffjpeg.py +517 -0
- degradation/ESR/usm_sharp.py +114 -0
- degradation/ESR/utils.py +126 -0
- degradation/degradation_esr.py +110 -0
- degradation/image_compression/avif.py +88 -0
- degradation/image_compression/heif.py +90 -0
- degradation/image_compression/jpeg.py +68 -0
- degradation/image_compression/webp.py +65 -0
- degradation/video_compression/h264.py +73 -0
.gitignore
ADDED
@@ -0,0 +1,30 @@
+datasets/*
+.ipynb_checkpoints
+.idea
+__pycache__
+
+datasets/
+tmp_imgs
+runs/
+runs_last/
+saved_models/*
+saved_models/
+pre_trained/
+save_log/*
+tmp/*
+
+*.pyc
+*.pth
+*.png
+*.jpg
+*.mp4
+*.txt
+*.json
+*.zip
+*.mp4
+*.csv
+
+!__assets__/lr_inputs/*
+!__assets__/*
+!__assets__/visual_results/*
+!requirements.txt
__assets__/logo.png
ADDED
__assets__/lr_inputs/41.png
ADDED
__assets__/lr_inputs/f91.jpg
ADDED
__assets__/lr_inputs/image-00164.jpg
ADDED
__assets__/lr_inputs/image-00186.png
ADDED
__assets__/lr_inputs/image-00277.png
ADDED
__assets__/lr_inputs/image-00440.png
ADDED
__assets__/lr_inputs/image-00542.png
ADDED
__assets__/lr_inputs/img_eva.jpeg
ADDED
__assets__/lr_inputs/screenshot_resize.jpg
ADDED
__assets__/visual_results/0079_2_visual.png
ADDED
__assets__/visual_results/0079_visual.png
ADDED
__assets__/visual_results/eva_visual.png
ADDED
__assets__/visual_results/f91_visual.png
ADDED
__assets__/visual_results/kiteret_visual.png
ADDED
__assets__/visual_results/pokemon2_visual.png
ADDED
__assets__/visual_results/pokemon_visual.png
ADDED
__assets__/visual_results/wataru_visual.png
ADDED
__assets__/workflow.png
ADDED
app.py
ADDED
@@ -0,0 +1,117 @@
+import os, sys
+import cv2
+import gradio as gr
+import torch
+import numpy as np
+from torchvision.utils import save_image
+
+
+# Import files from the local folder
+root_path = os.path.abspath('.')
+sys.path.append(root_path)
+from test_code.inference import super_resolve_img
+from test_code.test_utils import load_grl, load_rrdb
+
+
+def auto_download_if_needed(weight_path):
+    if os.path.exists(weight_path):
+        return
+
+    if not os.path.exists("pretrained"):
+        os.makedirs("pretrained")
+
+    if weight_path == "pretrained/4x_APISR_GRL_GAN_generator.pth":
+        os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth")
+        os.system("mv 4x_APISR_GRL_GAN_generator.pth pretrained")
+
+    if weight_path == "pretrained/2x_APISR_RRDB_GAN_generator.pth":
+        os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth")
+        os.system("mv 2x_APISR_RRDB_GAN_generator.pth pretrained")
+
+
+
+def inference(img_path, model_name):
+
+    try:
+        weight_dtype = torch.float32
+
+        # Load the model
+        if model_name == "4xGRL":
+            weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
+            auto_download_if_needed(weight_path)
+            generator = load_grl(weight_path, scale=4)  # Directly use default way now
+
+        elif model_name == "2xRRDB":
+            weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
+            auto_download_if_needed(weight_path)
+            generator = load_rrdb(weight_path, scale=2)  # Directly use default way now
+
+        else:
+            raise gr.Error(f"Unsupported model choice: {model_name}")
+
+        generator = generator.to(dtype=weight_dtype)
+
+
+        # By default, automatically crop the input so that it matches the 4x size requirement
+        super_resolved_img = super_resolve_img(generator, img_path, output_path=None, weight_dtype=weight_dtype, crop_for_4x=True)
+        save_image(super_resolved_img, "SR_result.png")
+        outputs = cv2.imread("SR_result.png")
+        outputs = cv2.cvtColor(outputs, cv2.COLOR_RGB2BGR)
+
+        return outputs
+
+
+    except Exception as error:
+        raise gr.Error(f"global exception: {error}")
+
+
+
+if __name__ == '__main__':
+
+    MARKDOWN = \
+    """
+    ## APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024)
+
+    [GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)
+
+    If APISR is helpful for you, please help star the GitHub Repo. Thanks!
+    """
+
+    block = gr.Blocks().queue()
+    with block:
+        with gr.Row():
+            gr.Markdown(MARKDOWN)
+        with gr.Row(elem_classes=["container"]):
+            with gr.Column(scale=2):
+                input_image = gr.Image(type="filepath", label="Input")
+                model_name = gr.Dropdown(
+                    [
+                        "2xRRDB",
+                        "4xGRL"
+                    ],
+                    type="value",
+                    value="4xGRL",
+                    label="model",
+                )
+                run_btn = gr.Button(value="Submit")
+
+            with gr.Column(scale=3):
+                output_image = gr.Image(type="numpy", label="Output image")
+
+        with gr.Row(elem_classes=["container"]):
+            gr.Examples(
+                [
+                    ["__assets__/lr_inputs/image-00277.png"],
+                    ["__assets__/lr_inputs/image-00542.png"],
+                    ["__assets__/lr_inputs/41.png"],
+                    ["__assets__/lr_inputs/f91.jpg"],
+                    ["__assets__/lr_inputs/image-00440.png"],
+                    ["__assets__/lr_inputs/image-00164.png"],
+                    ["__assets__/lr_inputs/img_eva.jpeg"],
+                ],
+                [input_image],
+            )
+
+        run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image])
+
+    block.launch()
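For quick local checks it can help to bypass the Gradio UI and call the same helpers the Space uses directly. The lines below are a minimal sketch only: they assume this repository's test_code helpers (load_grl, super_resolve_img) and the downloaded 4x GRL weights are available locally, and they simply mirror the 4xGRL branch of the inference() function above with a sample image from __assets__/lr_inputs.

# Minimal sketch: run the 4x GRL generator outside the Gradio app (paths assumed to exist locally)
import torch
from torchvision.utils import save_image
from test_code.inference import super_resolve_img
from test_code.test_utils import load_grl

weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
generator = load_grl(weight_path, scale=4).to(dtype=torch.float32)   # same loader call as app.py

sr_tensor = super_resolve_img(generator, "__assets__/lr_inputs/41.png",
                              output_path=None, weight_dtype=torch.float32, crop_for_4x=True)
save_image(sr_tensor, "SR_result.png")    # same post-processing step the app performs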
architecture/cunet.py
ADDED
@@ -0,0 +1,189 @@
+# Github Repository: https://github.com/bilibili/ailab/blob/main/Real-CUGAN/README_EN.md
+# Code snippet (with certain modification) from: https://github.com/bilibili/ailab/blob/main/Real-CUGAN/VapourSynth/upcunet_v3_vs.py
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+import os, sys
+import numpy as np
+from time import time as ttime, sleep
+
+
+class UNet_Full(nn.Module):
+
+    def __init__(self):
+        super(UNet_Full, self).__init__()
+        self.unet1 = UNet1(3, 3, deconv=True)
+        self.unet2 = UNet2(3, 3, deconv=False)
+
+    def forward(self, x):
+        n, c, h0, w0 = x.shape
+
+        ph = ((h0 - 1) // 2 + 1) * 2
+        pw = ((w0 - 1) // 2 + 1) * 2
+        x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), 'reflect')  # Pad so that the spatial size is divisible by 2
+
+        x1 = self.unet1(x)
+        x2 = self.unet2(x1)
+
+        x1 = F.pad(x1, (-20, -20, -20, -20))
+        output = torch.add(x2, x1)
+
+        if (w0 != pw or h0 != ph):
+            output = output[:, :, :h0 * 2, :w0 * 2]
+
+        return output
+
+
+class SEBlock(nn.Module):
+    def __init__(self, in_channels, reduction=8, bias=False):
+        super(SEBlock, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels, in_channels // reduction, 1, 1, 0, bias=bias)
+        self.conv2 = nn.Conv2d(in_channels // reduction, in_channels, 1, 1, 0, bias=bias)
+
+    def forward(self, x):
+        if ("Half" in x.type()):  # torch.HalfTensor/torch.cuda.HalfTensor
+            x0 = torch.mean(x.float(), dim=(2, 3), keepdim=True).half()
+        else:
+            x0 = torch.mean(x, dim=(2, 3), keepdim=True)
+        x0 = self.conv1(x0)
+        x0 = F.relu(x0, inplace=True)
+        x0 = self.conv2(x0)
+        x0 = torch.sigmoid(x0)
+        x = torch.mul(x, x0)
+        return x
+
+class UNetConv(nn.Module):
+    def __init__(self, in_channels, mid_channels, out_channels, se):
+        super(UNetConv, self).__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels, mid_channels, 3, 1, 0),
+            nn.LeakyReLU(0.1, inplace=True),
+            nn.Conv2d(mid_channels, out_channels, 3, 1, 0),
+            nn.LeakyReLU(0.1, inplace=True),
+        )
+        if se:
+            self.seblock = SEBlock(out_channels, reduction=8, bias=True)
+        else:
+            self.seblock = None
+
+    def forward(self, x):
+        z = self.conv(x)
+        if self.seblock is not None:
+            z = self.seblock(z)
+        return z
+
+class UNet1(nn.Module):
+    def __init__(self, in_channels, out_channels, deconv):
+        super(UNet1, self).__init__()
+        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
+        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
+        self.conv2 = UNetConv(64, 128, 64, se=True)
+        self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
+        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
+
+        if deconv:
+            self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
+        else:
+            self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)
+
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x1 = self.conv1(x)
+        x2 = self.conv1_down(x1)
+        x2 = F.leaky_relu(x2, 0.1, inplace=True)
+        x2 = self.conv2(x2)
+        x2 = self.conv2_up(x2)
+        x2 = F.leaky_relu(x2, 0.1, inplace=True)
+
+        x1 = F.pad(x1, (-4, -4, -4, -4))
+        x3 = self.conv3(x1 + x2)
+        x3 = F.leaky_relu(x3, 0.1, inplace=True)
+        z = self.conv_bottom(x3)
+        return z
+
+
+class UNet2(nn.Module):
+    def __init__(self, in_channels, out_channels, deconv):
+        super(UNet2, self).__init__()
+
+        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
+        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
+        self.conv2 = UNetConv(64, 64, 128, se=True)
+        self.conv2_down = nn.Conv2d(128, 128, 2, 2, 0)
+        self.conv3 = UNetConv(128, 256, 128, se=True)
+        self.conv3_up = nn.ConvTranspose2d(128, 128, 2, 2, 0)
+        self.conv4 = UNetConv(128, 64, 64, se=True)
+        self.conv4_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
+        self.conv5 = nn.Conv2d(64, 64, 3, 1, 0)
+
+        if deconv:
+            self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
+        else:
+            self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)
+
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x1 = self.conv1(x)
+        x2 = self.conv1_down(x1)
+        x2 = F.leaky_relu(x2, 0.1, inplace=True)
+        x2 = self.conv2(x2)
+
+        x3 = self.conv2_down(x2)
+        x3 = F.leaky_relu(x3, 0.1, inplace=True)
+        x3 = self.conv3(x3)
+        x3 = self.conv3_up(x3)
+        x3 = F.leaky_relu(x3, 0.1, inplace=True)
+
+        x2 = F.pad(x2, (-4, -4, -4, -4))
+        x4 = self.conv4(x2 + x3)
+        x4 = self.conv4_up(x4)
+        x4 = F.leaky_relu(x4, 0.1, inplace=True)
+
+        x1 = F.pad(x1, (-16, -16, -16, -16))
+        x5 = self.conv5(x1 + x4)
+        x5 = F.leaky_relu(x5, 0.1, inplace=True)
+
+        z = self.conv_bottom(x5)
+        return z
+
+
+
+def main():
+    root_path = os.path.abspath('.')
+    sys.path.append(root_path)
+
+    from opt import opt  # Manage which GPU to choose
+    import time
+
+    model = UNet_Full().cuda()
+    pytorch_total_params = sum(p.numel() for p in model.parameters())
+    print(f"CuNet has {pytorch_total_params//1000} K params")
+
+
+    # Time a single forward pass to double check
+    x = torch.randn((1, 3, 180, 180)).cuda()
+    start = time.time()
+    x = model(x)
+    print("output size is ", x.shape)
+    total = time.time() - start
+    print(total)
+
+
+
+if __name__ == "__main__":
+    main()
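The padding arithmetic in UNet_Full.forward rounds the input height and width up to the next even number, adds 18 pixels of reflection padding on each side for the valid (unpadded) convolutions inside UNet1/UNet2, and crops the 2x result back to exactly twice the original size. The following is a small shape sanity check, a sketch only: it assumes architecture/cunet.py is importable and runs on CPU with random weights.

# Shape sanity check for UNet_Full (random weights, CPU)
import torch
from architecture.cunet import UNet_Full

model = UNet_Full().eval()
with torch.no_grad():
    for h, w in [(180, 180), (181, 203)]:      # one even and one odd input size
        x = torch.randn(1, 3, h, w)
        y = model(x)
        # ph/pw round h/w up to the next even number; the output is cropped back to 2*h x 2*w
        assert y.shape[-2:] == (2 * h, 2 * w)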
architecture/dataset.py
ADDED
@@ -0,0 +1,106 @@
+# -*- coding: utf-8 -*-
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from torchvision.models import vgg19
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader, Dataset
+from torchvision.utils import save_image, make_grid
+from torchvision.transforms import ToTensor
+
+import numpy as np
+import cv2
+import glob
+import random
+from PIL import Image
+from tqdm import tqdm
+
+
+# from degradation.degradation_main import degredate_process, preparation
+from opt import opt
+
+
+class ImageDataset(Dataset):
+    @torch.no_grad()
+    def __init__(self, train_lr_paths, degrade_hr_paths, train_hr_paths):
+        # print("low_res path sample is ", train_lr_paths[0])
+        # print(train_hr_paths[0])
+        # hr_height, hr_width = hr_shape
+        self.transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+            ]
+        )
+
+        self.files_lr = train_lr_paths
+        self.files_degrade_hr = degrade_hr_paths
+        self.files_hr = train_hr_paths
+
+        assert(len(self.files_lr) == len(self.files_hr))
+        assert(len(self.files_lr) == len(self.files_degrade_hr))
+
+
+    def augment(self, imgs, hflip=True, rotation=True):
+        """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
+
+        All the images in the list use the same augmentation.
+
+        Args:
+            imgs (list[ndarray] | ndarray): Images to be augmented. If the input
+                is an ndarray, it will be transformed to a list.
+            hflip (bool): Horizontal flip. Default: True.
+            rotation (bool): Rotation. Default: True.
+
+        Returns:
+            imgs (list[ndarray] | ndarray): Augmented images and flows. If returned
+                results only have one element, just return ndarray.
+
+        """
+        hflip = hflip and random.random() < 0.5
+        vflip = rotation and random.random() < 0.5
+        rot90 = rotation and random.random() < 0.5
+
+        def _augment(img):
+            if hflip:  # horizontal
+                cv2.flip(img, 1, img)
+            if vflip:  # vertical
+                cv2.flip(img, 0, img)
+            if rot90:
+                img = img.transpose(1, 0, 2)
+            return img
+
+
+        if not isinstance(imgs, list):
+            imgs = [imgs]
+
+        imgs = [_augment(img) for img in imgs]
+        if len(imgs) == 1:
+            imgs = imgs[0]
+
+
+        return imgs
+
+
+    def __getitem__(self, index):
+
+        # Read File
+        img_lr = cv2.imread(self.files_lr[index % len(self.files_lr)])  # Should be BGR
+        img_degrade_hr = cv2.imread(self.files_degrade_hr[index % len(self.files_degrade_hr)])
+        img_hr = cv2.imread(self.files_hr[index % len(self.files_hr)])
+
+        # Augmentation
+        if random.random() < opt["augment_prob"]:
+            img_lr, img_degrade_hr, img_hr = self.augment([img_lr, img_degrade_hr, img_hr])
+
+        # Transform to Tensor
+        img_lr = self.transform(img_lr)
+        img_degrade_hr = self.transform(img_degrade_hr)
+        img_hr = self.transform(img_hr)  # ToTensor() is already in the range [0, 1]
+
+
+        return {"lr": img_lr, "degrade_hr": img_degrade_hr, "hr": img_hr}
+
+    def __len__(self):
+        assert(len(self.files_hr) == len(self.files_lr))
+        return len(self.files_hr)
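ImageDataset yields one dictionary per index with three aligned tensors ("lr", "degrade_hr", "hr"), each converted to NCHW floats in [0, 1] by ToTensor. The sketch below shows how it might be wrapped in a DataLoader; it is illustrative only, since opt.py (which must define opt["augment_prob"]) is not shown in this 50-file view and the three dataset directories are hypothetical placeholders.

# Illustrative only: the directory names are placeholders and opt.py is assumed to exist
import glob
from torch.utils.data import DataLoader
from architecture.dataset import ImageDataset

lr_paths = sorted(glob.glob("datasets/train_lr/*.png"))               # hypothetical locations
degrade_hr_paths = sorted(glob.glob("datasets/train_degrade_hr/*.png"))
hr_paths = sorted(glob.glob("datasets/train_hr/*.png"))

dataset = ImageDataset(lr_paths, degrade_hr_paths, hr_paths)           # lengths must match
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

batch = next(iter(loader))   # {"lr": ..., "degrade_hr": ..., "hr": ...}, each NCHW in [0, 1]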
architecture/discriminator.py
ADDED
@@ -0,0 +1,241 @@
+# -*- coding: utf-8 -*-
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm
+import torch
+import functools
+
+class UNetDiscriminatorSN(nn.Module):
+    """Defines a U-Net discriminator with spectral normalization (SN).
+
+    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    Args:
+        num_in_ch (int): Channel number of inputs. Default: 3.
+        num_feat (int): Channel number of base intermediate features. Default: 64.
+        skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
+    """
+
+    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
+        super(UNetDiscriminatorSN, self).__init__()
+        self.skip_connection = skip_connection
+        norm = spectral_norm
+        # the first convolution
+        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
+        # downsample
+        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
+        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
+        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
+        # upsample
+        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
+        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
+        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
+        # extra convolutions
+        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
+        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
+        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
+
+    def forward(self, x):
+
+        # downsample
+        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
+        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
+        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
+        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
+
+        # upsample
+        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
+        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
+
+        if self.skip_connection:
+            x4 = x4 + x2
+        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
+        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
+
+        if self.skip_connection:
+            x5 = x5 + x1
+        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
+        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
+
+        if self.skip_connection:
+            x6 = x6 + x0
+
+        # extra convolutions
+        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
+        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
+        out = self.conv9(out)
+
+        return out
+
+
+
+def get_conv_layer(input_nc, ndf, kernel_size, stride, padding, bias=True, use_sn=False):
+    if not use_sn:
+        return nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
+    return spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
+
+
+class PatchDiscriminator(nn.Module):
+    """Defines a PatchGAN discriminator; the receptive field of the default config is 70x70.
+
+    Args:
+        use_sn (bool): Use spectral_norm or not. If use_sn is True, then norm_type should be 'none'.
+    """
+
+    def __init__(self,
+                 num_in_ch,
+                 num_feat=64,
+                 num_layers=3,
+                 max_nf_mult=8,
+                 norm_type='batch',
+                 use_sigmoid=False,
+                 use_sn=False):
+        super(PatchDiscriminator, self).__init__()
+
+        norm_layer = self._get_norm_layer(norm_type)
+        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
+            use_bias = norm_layer.func != nn.BatchNorm2d
+        else:
+            use_bias = norm_layer != nn.BatchNorm2d
+
+        kw = 4
+        padw = 1
+        sequence = [
+            get_conv_layer(num_in_ch, num_feat, kernel_size=kw, stride=2, padding=padw, use_sn=use_sn),
+            nn.LeakyReLU(0.2, True)
+        ]
+        nf_mult = 1
+        nf_mult_prev = 1
+        for n in range(1, num_layers):  # gradually increase the number of filters
+            nf_mult_prev = nf_mult
+            nf_mult = min(2**n, max_nf_mult)
+            sequence += [
+                get_conv_layer(
+                    num_feat * nf_mult_prev,
+                    num_feat * nf_mult,
+                    kernel_size=kw,
+                    stride=2,
+                    padding=padw,
+                    bias=use_bias,
+                    use_sn=use_sn),
+                norm_layer(num_feat * nf_mult),
+                nn.LeakyReLU(0.2, True)
+            ]
+
+        nf_mult_prev = nf_mult
+        nf_mult = min(2**num_layers, max_nf_mult)
+        sequence += [
+            get_conv_layer(
+                num_feat * nf_mult_prev,
+                num_feat * nf_mult,
+                kernel_size=kw,
+                stride=1,
+                padding=padw,
+                bias=use_bias,
+                use_sn=use_sn),
+            norm_layer(num_feat * nf_mult),
+            nn.LeakyReLU(0.2, True)
+        ]
+
+        # output a 1-channel prediction map; I think this is essentially pixel-by-pixel feedback
+        sequence += [get_conv_layer(num_feat * nf_mult, 1, kernel_size=kw, stride=1, padding=padw, use_sn=use_sn)]
+
+        if use_sigmoid:
+            sequence += [nn.Sigmoid()]
+        self.model = nn.Sequential(*sequence)
+
+    def _get_norm_layer(self, norm_type='batch'):
+        if norm_type == 'batch':
+            norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
+        elif norm_type == 'instance':
+            norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
+        elif norm_type == 'batchnorm2d':
+            norm_layer = nn.BatchNorm2d
+        elif norm_type == 'none':
+            norm_layer = nn.Identity
+        else:
+            raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
+
+        return norm_layer
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class MultiScaleDiscriminator(nn.Module):
+    """Define a multi-scale discriminator; each discriminator is an instance of PatchDiscriminator.
+
+    Args:
+        num_layers (int or list): If the type of this variable is int, then degrade to PatchDiscriminator.
+                                  If the type of this variable is list, then the length of the list is
+                                  the number of discriminators.
+        use_downscale (bool): Progressively downscale the input to feed into different discriminators.
+                              If set to True, then the discriminators are usually the same.
+    """
+
+    def __init__(self,
+                 num_in_ch,
+                 num_feat=64,
+                 num_layers=[3, 3, 3],
+                 max_nf_mult=8,
+                 norm_type='none',
+                 use_sigmoid=False,
+                 use_sn=True,
+                 use_downscale=True):
+        super(MultiScaleDiscriminator, self).__init__()
+
+        if isinstance(num_layers, int):
+            num_layers = [num_layers]
+
+        # check whether the discriminators are the same
+        if use_downscale:
+            assert len(set(num_layers)) == 1
+        self.use_downscale = use_downscale
+
+        self.num_dis = len(num_layers)
+        self.dis_list = nn.ModuleList()
+        for nl in num_layers:
+            self.dis_list.append(
+                PatchDiscriminator(
+                    num_in_ch,
+                    num_feat=num_feat,
+                    num_layers=nl,
+                    max_nf_mult=max_nf_mult,
+                    norm_type=norm_type,
+                    use_sigmoid=use_sigmoid,
+                    use_sn=use_sn,
+                ))
+
+    def forward(self, x):
+        outs = []
+        h, w = x.size()[2:]
+
+        y = x
+        for i in range(self.num_dis):
+            if i != 0 and self.use_downscale:
+                y = F.interpolate(y, size=(h // 2, w // 2), mode='bilinear', align_corners=True)
+                h, w = y.size()[2:]
+            outs.append(self.dis_list[i](y))
+
+        return outs
+
+
+def main():
+    from pthflops import count_ops
+    from torchsummary import summary
+
+    model = UNetDiscriminatorSN(3)
+    pytorch_total_params = sum(p.numel() for p in model.parameters())
+
+    # Create a network and a corresponding input
+    device = 'cuda'
+    inp = torch.rand(1, 3, 400, 400)
+
+    # Count the number of FLOPs
+    count_ops(model, inp)
+    summary(model.cuda(), (3, 400, 400), batch_size=1)
+    # print(f"pathGAN has param {pytorch_total_params//1000} K params")
+
+
+if __name__ == "__main__":
+    main()
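Both discriminators above are fully convolutional and return per-location logits rather than a single scalar score: UNetDiscriminatorSN produces a map at the input resolution, while MultiScaleDiscriminator returns one patch-level map per scale. A minimal shape check, assuming architecture/discriminator.py is importable and running on CPU with random weights:

# Shape check for the two discriminators defined above (random weights, CPU)
import torch
from architecture.discriminator import UNetDiscriminatorSN, MultiScaleDiscriminator

x = torch.rand(1, 3, 128, 128)

d_unet = UNetDiscriminatorSN(num_in_ch=3)
print(d_unet(x).shape)         # torch.Size([1, 1, 128, 128]) - one logit per pixel

d_multi = MultiScaleDiscriminator(num_in_ch=3, num_layers=[3, 3, 3])
for out in d_multi(x):
    print(out.shape)           # three patch-level logit maps, one per progressively downscaled input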
architecture/grl.py
ADDED
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Efficient and Explicit Modelling of Image Hierarchies for Image Restoration
|
3 |
+
Image restoration transformers with global, regional, and local modelling
|
4 |
+
A clean version of the.
|
5 |
+
Shared buffers are used for relative_coords_table, relative_position_index, and attn_mask.
|
6 |
+
"""
|
7 |
+
import cv2
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torchvision.transforms import ToTensor
|
12 |
+
from torchvision.utils import save_image
|
13 |
+
from fairscale.nn import checkpoint_wrapper
|
14 |
+
from omegaconf import OmegaConf
|
15 |
+
from timm.models.layers import to_2tuple, trunc_normal_
|
16 |
+
|
17 |
+
# Import files from local folder
|
18 |
+
import os, sys
|
19 |
+
root_path = os.path.abspath('.')
|
20 |
+
sys.path.append(root_path)
|
21 |
+
|
22 |
+
from architecture.grl_common import Upsample, UpsampleOneStep
|
23 |
+
from architecture.grl_common.mixed_attn_block_efficient import (
|
24 |
+
_get_stripe_info,
|
25 |
+
EfficientMixAttnTransformerBlock,
|
26 |
+
)
|
27 |
+
from architecture.grl_common.ops import (
|
28 |
+
bchw_to_blc,
|
29 |
+
blc_to_bchw,
|
30 |
+
calculate_mask,
|
31 |
+
calculate_mask_all,
|
32 |
+
get_relative_coords_table_all,
|
33 |
+
get_relative_position_index_simple,
|
34 |
+
)
|
35 |
+
from architecture.grl_common.swin_v1_block import (
|
36 |
+
build_last_conv,
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
class TransformerStage(nn.Module):
|
41 |
+
"""Transformer stage.
|
42 |
+
Args:
|
43 |
+
dim (int): Number of input channels.
|
44 |
+
input_resolution (tuple[int]): Input resolution.
|
45 |
+
depth (int): Number of blocks.
|
46 |
+
num_heads_window (list[int]): Number of window attention heads in different layers.
|
47 |
+
num_heads_stripe (list[int]): Number of stripe attention heads in different layers.
|
48 |
+
stripe_size (list[int]): Stripe size. Default: [8, 8]
|
49 |
+
stripe_groups (list[int]): Number of stripe groups. Default: [None, None].
|
50 |
+
stripe_shift (bool): whether to shift the stripes. This is used as an ablation study.
|
51 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
52 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
53 |
+
qkv_proj_type (str): QKV projection type. Default: linear. Choices: linear, separable_conv.
|
54 |
+
anchor_proj_type (str): Anchor projection type. Default: avgpool. Choices: avgpool, maxpool, conv2d, separable_conv, patchmerging.
|
55 |
+
anchor_one_stage (bool): Whether to use one operator or multiple progressive operators to reduce feature map resolution. Default: True.
|
56 |
+
anchor_window_down_factor (int): The downscale factor used to get the anchors.
|
57 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
58 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
59 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
60 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
61 |
+
pretrained_window_size (list[int]): pretrained window size. This is actually not used. Default: [0, 0].
|
62 |
+
pretrained_stripe_size (list[int]): pretrained stripe size. This is actually not used. Default: [0, 0].
|
63 |
+
conv_type: The convolutional block before residual connection.
|
64 |
+
init_method: initialization method of the weight parameters used to train large scale models.
|
65 |
+
Choices: n, normal -- Swin V1 init method.
|
66 |
+
l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer.
|
67 |
+
r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1
|
68 |
+
w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1
|
69 |
+
t, trunc_normal_ -- nn.Linear, trunc_normal; nn.Conv2d, weight_rescale
|
70 |
+
fairscale_checkpoint (bool): Whether to use fairscale checkpoint.
|
71 |
+
offload_to_cpu (bool): used by fairscale_checkpoint
|
72 |
+
args:
|
73 |
+
out_proj_type (str): Type of the output projection in the self-attention modules. Default: linear. Choices: linear, conv2d.
|
74 |
+
local_connection (bool): Whether to enable the local modelling module (two convs followed by Channel attention). For GRL base model, this is used. "local_connection": local_connection,
|
75 |
+
euclidean_dist (bool): use Euclidean distance or inner product as the similarity metric. An ablation study.
|
76 |
+
"""
|
77 |
+
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
dim,
|
81 |
+
input_resolution,
|
82 |
+
depth,
|
83 |
+
num_heads_window,
|
84 |
+
num_heads_stripe,
|
85 |
+
window_size,
|
86 |
+
stripe_size,
|
87 |
+
stripe_groups,
|
88 |
+
stripe_shift,
|
89 |
+
mlp_ratio=4.0,
|
90 |
+
qkv_bias=True,
|
91 |
+
qkv_proj_type="linear",
|
92 |
+
anchor_proj_type="avgpool",
|
93 |
+
anchor_one_stage=True,
|
94 |
+
anchor_window_down_factor=1,
|
95 |
+
drop=0.0,
|
96 |
+
attn_drop=0.0,
|
97 |
+
drop_path=0.0,
|
98 |
+
norm_layer=nn.LayerNorm,
|
99 |
+
pretrained_window_size=[0, 0],
|
100 |
+
pretrained_stripe_size=[0, 0],
|
101 |
+
conv_type="1conv",
|
102 |
+
init_method="",
|
103 |
+
fairscale_checkpoint=False,
|
104 |
+
offload_to_cpu=False,
|
105 |
+
args=None,
|
106 |
+
):
|
107 |
+
super().__init__()
|
108 |
+
|
109 |
+
self.dim = dim
|
110 |
+
self.input_resolution = input_resolution
|
111 |
+
self.init_method = init_method
|
112 |
+
|
113 |
+
self.blocks = nn.ModuleList()
|
114 |
+
for i in range(depth):
|
115 |
+
block = EfficientMixAttnTransformerBlock(
|
116 |
+
dim=dim,
|
117 |
+
input_resolution=input_resolution,
|
118 |
+
num_heads_w=num_heads_window,
|
119 |
+
num_heads_s=num_heads_stripe,
|
120 |
+
window_size=window_size,
|
121 |
+
window_shift=i % 2 == 0,
|
122 |
+
stripe_size=stripe_size,
|
123 |
+
stripe_groups=stripe_groups,
|
124 |
+
stripe_type="H" if i % 2 == 0 else "W",
|
125 |
+
stripe_shift=i % 4 in [2, 3] if stripe_shift else False,
|
126 |
+
mlp_ratio=mlp_ratio,
|
127 |
+
qkv_bias=qkv_bias,
|
128 |
+
qkv_proj_type=qkv_proj_type,
|
129 |
+
anchor_proj_type=anchor_proj_type,
|
130 |
+
anchor_one_stage=anchor_one_stage,
|
131 |
+
anchor_window_down_factor=anchor_window_down_factor,
|
132 |
+
drop=drop,
|
133 |
+
attn_drop=attn_drop,
|
134 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
135 |
+
norm_layer=norm_layer,
|
136 |
+
pretrained_window_size=pretrained_window_size,
|
137 |
+
pretrained_stripe_size=pretrained_stripe_size,
|
138 |
+
res_scale=0.1 if init_method == "r" else 1.0,
|
139 |
+
args=args,
|
140 |
+
)
|
141 |
+
# print(fairscale_checkpoint, offload_to_cpu)
|
142 |
+
if fairscale_checkpoint:
|
143 |
+
block = checkpoint_wrapper(block, offload_to_cpu=offload_to_cpu)
|
144 |
+
self.blocks.append(block)
|
145 |
+
|
146 |
+
self.conv = build_last_conv(conv_type, dim)
|
147 |
+
|
148 |
+
def _init_weights(self):
|
149 |
+
for n, m in self.named_modules():
|
150 |
+
if self.init_method == "w":
|
151 |
+
if isinstance(m, (nn.Linear, nn.Conv2d)) and n.find("cpb_mlp") < 0:
|
152 |
+
print("nn.Linear and nn.Conv2d weight initilization")
|
153 |
+
m.weight.data *= 0.1
|
154 |
+
elif self.init_method == "l":
|
155 |
+
if isinstance(m, nn.LayerNorm):
|
156 |
+
print("nn.LayerNorm initialization")
|
157 |
+
nn.init.constant_(m.bias, 0)
|
158 |
+
nn.init.constant_(m.weight, 0)
|
159 |
+
elif self.init_method.find("t") >= 0:
|
160 |
+
scale = 0.1 ** (len(self.init_method) - 1) * int(self.init_method[-1])
|
161 |
+
if isinstance(m, nn.Linear) and n.find("cpb_mlp") < 0:
|
162 |
+
trunc_normal_(m.weight, std=scale)
|
163 |
+
elif isinstance(m, nn.Conv2d):
|
164 |
+
m.weight.data *= 0.1
|
165 |
+
print(
|
166 |
+
"Initialization nn.Linear - trunc_normal; nn.Conv2d - weight rescale."
|
167 |
+
)
|
168 |
+
else:
|
169 |
+
raise NotImplementedError(
|
170 |
+
f"Parameter initialization method {self.init_method} not implemented in TransformerStage."
|
171 |
+
)
|
172 |
+
|
173 |
+
def forward(self, x, x_size, table_index_mask):
|
174 |
+
res = x
|
175 |
+
for blk in self.blocks:
|
176 |
+
res = blk(res, x_size, table_index_mask)
|
177 |
+
res = bchw_to_blc(self.conv(blc_to_bchw(res, x_size)))
|
178 |
+
|
179 |
+
return res + x
|
180 |
+
|
181 |
+
def flops(self):
|
182 |
+
pass
|
183 |
+
|
184 |
+
|
185 |
+
class GRL(nn.Module):
|
186 |
+
r"""Image restoration transformer with global, non-local, and local connections
|
187 |
+
Args:
|
188 |
+
img_size (int | list[int]): Input image size. Default 64
|
189 |
+
in_channels (int): Number of input image channels. Default: 3
|
190 |
+
out_channels (int): Number of output image channels. Default: None
|
191 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
192 |
+
upscale (int): Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
|
193 |
+
img_range (float): Image range. 1. or 255.
|
194 |
+
upsampler (str): The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
195 |
+
depths (list[int]): Depth of each Swin Transformer layer.
|
196 |
+
num_heads_window (list[int]): Number of window attention heads in different layers.
|
197 |
+
num_heads_stripe (list[int]): Number of stripe attention heads in different layers.
|
198 |
+
window_size (int): Window size. Default: 8.
|
199 |
+
stripe_size (list[int]): Stripe size. Default: [8, 8]
|
200 |
+
stripe_groups (list[int]): Number of stripe groups. Default: [None, None].
|
201 |
+
stripe_shift (bool): whether to shift the stripes. This is used as an ablation study.
|
202 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
203 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
204 |
+
qkv_proj_type (str): QKV projection type. Default: linear. Choices: linear, separable_conv.
|
205 |
+
anchor_proj_type (str): Anchor projection type. Default: avgpool. Choices: avgpool, maxpool, conv2d, separable_conv, patchmerging.
|
206 |
+
anchor_one_stage (bool): Whether to use one operator or multiple progressive operators to reduce feature map resolution. Default: True.
|
207 |
+
anchor_window_down_factor (int): The downscale factor used to get the anchors.
|
208 |
+
out_proj_type (str): Type of the output projection in the self-attention modules. Default: linear. Choices: linear, conv2d.
|
209 |
+
local_connection (bool): Whether to enable the local modelling module (two convs followed by Channel attention). For GRL base model, this is used.
|
210 |
+
drop_rate (float): Dropout rate. Default: 0
|
211 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
212 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
213 |
+
pretrained_window_size (list[int]): pretrained window size. This is actually not used. Default: [0, 0].
|
214 |
+
pretrained_stripe_size (list[int]): pretrained stripe size. This is actually not used. Default: [0, 0].
|
215 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
216 |
+
conv_type (str): The convolutional block before residual connection. Default: 1conv. Choices: 1conv, 3conv, 1conv1x1, linear
|
217 |
+
init_method: initialization method of the weight parameters used to train large scale models.
|
218 |
+
Choices: n, normal -- Swin V1 init method.
|
219 |
+
l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer.
|
220 |
+
r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1
|
221 |
+
w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1
|
222 |
+
t, trunc_normal_ -- nn.Linear, trunc_normal; nn.Conv2d, weight_rescale
|
223 |
+
fairscale_checkpoint (bool): Whether to use fairscale checkpoint.
|
224 |
+
offload_to_cpu (bool): used by fairscale_checkpoint
|
225 |
+
euclidean_dist (bool): use Euclidean distance or inner product as the similarity metric. An ablation study.
|
226 |
+
|
227 |
+
"""
|
228 |
+
|
229 |
+
def __init__(
|
230 |
+
self,
|
231 |
+
img_size=64,
|
232 |
+
in_channels=3,
|
233 |
+
out_channels=None,
|
234 |
+
embed_dim=96,
|
235 |
+
upscale=2,
|
236 |
+
img_range=1.0,
|
237 |
+
upsampler="",
|
238 |
+
depths=[6, 6, 6, 6, 6, 6],
|
239 |
+
num_heads_window=[3, 3, 3, 3, 3, 3],
|
240 |
+
num_heads_stripe=[3, 3, 3, 3, 3, 3],
|
241 |
+
window_size=8,
|
242 |
+
stripe_size=[8, 8], # used for stripe window attention
|
243 |
+
stripe_groups=[None, None],
|
244 |
+
stripe_shift=False,
|
245 |
+
mlp_ratio=4.0,
|
246 |
+
qkv_bias=True,
|
247 |
+
qkv_proj_type="linear",
|
248 |
+
anchor_proj_type="avgpool",
|
249 |
+
anchor_one_stage=True,
|
250 |
+
anchor_window_down_factor=1,
|
251 |
+
out_proj_type="linear",
|
252 |
+
local_connection=False,
|
253 |
+
drop_rate=0.0,
|
254 |
+
attn_drop_rate=0.0,
|
255 |
+
drop_path_rate=0.1,
|
256 |
+
norm_layer=nn.LayerNorm,
|
257 |
+
pretrained_window_size=[0, 0],
|
258 |
+
pretrained_stripe_size=[0, 0],
|
259 |
+
conv_type="1conv",
|
260 |
+
init_method="n", # initialization method of the weight parameters used to train large scale models.
|
261 |
+
fairscale_checkpoint=False, # fairscale activation checkpointing
|
262 |
+
offload_to_cpu=False,
|
263 |
+
euclidean_dist=False,
|
264 |
+
**kwargs,
|
265 |
+
):
|
266 |
+
super(GRL, self).__init__()
|
267 |
+
# Process the input arguments
|
268 |
+
out_channels = out_channels or in_channels
|
269 |
+
self.in_channels = in_channels
|
270 |
+
self.out_channels = out_channels
|
271 |
+
num_out_feats = 64
|
272 |
+
self.embed_dim = embed_dim
|
273 |
+
self.upscale = upscale
|
274 |
+
self.upsampler = upsampler
|
275 |
+
self.img_range = img_range
|
276 |
+
if in_channels == 3:
|
277 |
+
rgb_mean = (0.4488, 0.4371, 0.4040)
|
278 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
279 |
+
else:
|
280 |
+
self.mean = torch.zeros(1, 1, 1, 1)
|
281 |
+
|
282 |
+
max_stripe_size = max([0 if s is None else s for s in stripe_size])
|
283 |
+
max_stripe_groups = max([0 if s is None else s for s in stripe_groups])
|
284 |
+
max_stripe_groups *= anchor_window_down_factor
|
285 |
+
self.pad_size = max(window_size, max_stripe_size, max_stripe_groups)
|
286 |
+
# if max_stripe_size >= window_size:
|
287 |
+
# self.pad_size *= anchor_window_down_factor
|
288 |
+
# if stripe_groups[0] is None and stripe_groups[1] is None:
|
289 |
+
# self.pad_size = max(stripe_size)
|
290 |
+
# else:
|
291 |
+
# self.pad_size = window_size
|
292 |
+
self.input_resolution = to_2tuple(img_size)
|
293 |
+
self.window_size = to_2tuple(window_size)
|
294 |
+
self.shift_size = [w // 2 for w in self.window_size]
|
295 |
+
self.stripe_size = stripe_size
|
296 |
+
self.stripe_groups = stripe_groups
|
297 |
+
self.pretrained_window_size = pretrained_window_size
|
298 |
+
self.pretrained_stripe_size = pretrained_stripe_size
|
299 |
+
self.anchor_window_down_factor = anchor_window_down_factor
|
300 |
+
|
301 |
+
# Head of the network. First convolution.
|
302 |
+
self.conv_first = nn.Conv2d(in_channels, embed_dim, 3, 1, 1)
|
303 |
+
|
304 |
+
# Body of the network
|
305 |
+
self.norm_start = norm_layer(embed_dim)
|
306 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
307 |
+
|
308 |
+
# stochastic depth
|
309 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
310 |
+
# stochastic depth decay rule
|
311 |
+
args = OmegaConf.create(
|
312 |
+
{
|
313 |
+
"out_proj_type": out_proj_type,
|
314 |
+
"local_connection": local_connection,
|
315 |
+
"euclidean_dist": euclidean_dist,
|
316 |
+
}
|
317 |
+
)
|
318 |
+
for k, v in self.set_table_index_mask(self.input_resolution).items():
|
319 |
+
self.register_buffer(k, v)
|
320 |
+
|
321 |
+
self.layers = nn.ModuleList()
|
322 |
+
for i in range(len(depths)):
|
323 |
+
layer = TransformerStage(
|
324 |
+
dim=embed_dim,
|
325 |
+
input_resolution=self.input_resolution,
|
326 |
+
depth=depths[i],
|
327 |
+
num_heads_window=num_heads_window[i],
|
328 |
+
num_heads_stripe=num_heads_stripe[i],
|
329 |
+
window_size=self.window_size,
|
330 |
+
stripe_size=stripe_size,
|
331 |
+
stripe_groups=stripe_groups,
|
332 |
+
stripe_shift=stripe_shift,
|
333 |
+
mlp_ratio=mlp_ratio,
|
334 |
+
qkv_bias=qkv_bias,
|
335 |
+
qkv_proj_type=qkv_proj_type,
|
336 |
+
anchor_proj_type=anchor_proj_type,
|
337 |
+
anchor_one_stage=anchor_one_stage,
|
338 |
+
anchor_window_down_factor=anchor_window_down_factor,
|
339 |
+
drop=drop_rate,
|
340 |
+
attn_drop=attn_drop_rate,
|
341 |
+
drop_path=dpr[
|
342 |
+
sum(depths[:i]) : sum(depths[: i + 1])
|
343 |
+
], # no impact on SR results
|
344 |
+
norm_layer=norm_layer,
|
345 |
+
pretrained_window_size=pretrained_window_size,
|
346 |
+
pretrained_stripe_size=pretrained_stripe_size,
|
347 |
+
conv_type=conv_type,
|
348 |
+
init_method=init_method,
|
349 |
+
fairscale_checkpoint=fairscale_checkpoint,
|
350 |
+
offload_to_cpu=offload_to_cpu,
|
351 |
+
args=args,
|
352 |
+
)
|
353 |
+
self.layers.append(layer)
|
354 |
+
self.norm_end = norm_layer(embed_dim)
|
355 |
+
|
356 |
+
# Tail of the network
|
357 |
+
self.conv_after_body = build_last_conv(conv_type, embed_dim)
|
358 |
+
|
359 |
+
#####################################################################################################
|
360 |
+
################################ 3, high quality image reconstruction ################################
|
361 |
+
if self.upsampler == "pixelshuffle":
|
362 |
+
# for classical SR
|
363 |
+
self.conv_before_upsample = nn.Sequential(
|
364 |
+
nn.Conv2d(embed_dim, num_out_feats, 3, 1, 1), nn.LeakyReLU(inplace=True)
|
365 |
+
)
|
366 |
+
self.upsample = Upsample(upscale, num_out_feats)
|
367 |
+
self.conv_last = nn.Conv2d(num_out_feats, out_channels, 3, 1, 1)
|
368 |
+
elif self.upsampler == "pixelshuffledirect":
|
369 |
+
# for lightweight SR (to save parameters)
|
370 |
+
self.upsample = UpsampleOneStep(
|
371 |
+
upscale,
|
372 |
+
embed_dim,
|
373 |
+
out_channels,
|
374 |
+
)
|
375 |
+
elif self.upsampler == "nearest+conv":
|
376 |
+
# for real-world SR (less artifacts)
|
377 |
+
assert self.upscale == 4, "only support x4 now."
|
378 |
+
self.conv_before_upsample = nn.Sequential(
|
379 |
+
nn.Conv2d(embed_dim, num_out_feats, 3, 1, 1), nn.LeakyReLU(inplace=True)
|
380 |
+
)
|
381 |
+
self.conv_up1 = nn.Conv2d(num_out_feats, num_out_feats, 3, 1, 1)
|
382 |
+
self.conv_up2 = nn.Conv2d(num_out_feats, num_out_feats, 3, 1, 1)
|
383 |
+
self.conv_hr = nn.Conv2d(num_out_feats, num_out_feats, 3, 1, 1)
|
384 |
+
self.conv_last = nn.Conv2d(num_out_feats, out_channels, 3, 1, 1)
|
385 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
386 |
+
else:
|
387 |
+
# for image denoising and JPEG compression artifact reduction
|
388 |
+
self.conv_last = nn.Conv2d(embed_dim, out_channels, 3, 1, 1)
|
389 |
+
|
390 |
+
self.apply(self._init_weights)
|
391 |
+
if init_method in ["l", "w"] or init_method.find("t") >= 0:
|
392 |
+
for layer in self.layers:
|
393 |
+
layer._init_weights()
|
394 |
+
|
395 |
+
def set_table_index_mask(self, x_size):
|
396 |
+
"""
|
397 |
+
Two used cases:
|
398 |
+
1) At initialization: set the shared buffers.
|
399 |
+
2) During forward pass: get the new buffers if the resolution of the input changes
|
400 |
+
"""
|
401 |
+
# ss - stripe_size, sss - stripe_shift_size
|
402 |
+
ss, sss = _get_stripe_info(self.stripe_size, self.stripe_groups, True, x_size)
|
403 |
+
df = self.anchor_window_down_factor
|
404 |
+
|
405 |
+
table_w = get_relative_coords_table_all(
|
406 |
+
self.window_size, self.pretrained_window_size
|
407 |
+
)
|
408 |
+
table_sh = get_relative_coords_table_all(ss, self.pretrained_stripe_size, df)
|
409 |
+
table_sv = get_relative_coords_table_all(
|
410 |
+
ss[::-1], self.pretrained_stripe_size, df
|
411 |
+
)
|
412 |
+
|
413 |
+
index_w = get_relative_position_index_simple(self.window_size)
|
414 |
+
index_sh_a2w = get_relative_position_index_simple(ss, df, False)
|
415 |
+
index_sh_w2a = get_relative_position_index_simple(ss, df, True)
|
416 |
+
index_sv_a2w = get_relative_position_index_simple(ss[::-1], df, False)
|
417 |
+
index_sv_w2a = get_relative_position_index_simple(ss[::-1], df, True)
|
418 |
+
|
419 |
+
mask_w = calculate_mask(x_size, self.window_size, self.shift_size)
|
420 |
+
mask_sh_a2w = calculate_mask_all(x_size, ss, sss, df, False)
|
421 |
+
mask_sh_w2a = calculate_mask_all(x_size, ss, sss, df, True)
|
422 |
+
mask_sv_a2w = calculate_mask_all(x_size, ss[::-1], sss[::-1], df, False)
|
423 |
+
mask_sv_w2a = calculate_mask_all(x_size, ss[::-1], sss[::-1], df, True)
|
424 |
+
return {
|
425 |
+
"table_w": table_w,
|
426 |
+
"table_sh": table_sh,
|
427 |
+
"table_sv": table_sv,
|
428 |
+
"index_w": index_w,
|
429 |
+
"index_sh_a2w": index_sh_a2w,
|
430 |
+
"index_sh_w2a": index_sh_w2a,
|
431 |
+
"index_sv_a2w": index_sv_a2w,
|
432 |
+
"index_sv_w2a": index_sv_w2a,
|
433 |
+
"mask_w": mask_w,
|
434 |
+
"mask_sh_a2w": mask_sh_a2w,
|
435 |
+
"mask_sh_w2a": mask_sh_w2a,
|
436 |
+
"mask_sv_a2w": mask_sv_a2w,
|
437 |
+
"mask_sv_w2a": mask_sv_w2a,
|
438 |
+
}
|
439 |
+
|
440 |
+
def get_table_index_mask(self, device=None, input_resolution=None):
|
441 |
+
# Used during forward pass
|
442 |
+
if input_resolution == self.input_resolution:
|
443 |
+
return {
|
444 |
+
"table_w": self.table_w,
|
445 |
+
"table_sh": self.table_sh,
|
446 |
+
"table_sv": self.table_sv,
|
447 |
+
"index_w": self.index_w,
|
448 |
+
"index_sh_a2w": self.index_sh_a2w,
|
449 |
+
"index_sh_w2a": self.index_sh_w2a,
|
450 |
+
"index_sv_a2w": self.index_sv_a2w,
|
451 |
+
"index_sv_w2a": self.index_sv_w2a,
|
452 |
+
"mask_w": self.mask_w,
|
453 |
+
"mask_sh_a2w": self.mask_sh_a2w,
|
454 |
+
"mask_sh_w2a": self.mask_sh_w2a,
|
455 |
+
"mask_sv_a2w": self.mask_sv_a2w,
|
456 |
+
"mask_sv_w2a": self.mask_sv_w2a,
|
457 |
+
}
|
458 |
+
else:
|
459 |
+
table_index_mask = self.set_table_index_mask(input_resolution)
|
460 |
+
for k, v in table_index_mask.items():
|
461 |
+
table_index_mask[k] = v.to(device)
|
462 |
+
return table_index_mask
|
463 |
+
|
464 |
+
def _init_weights(self, m):
|
465 |
+
if isinstance(m, nn.Linear):
|
466 |
+
# Only used to initialize linear layers
|
467 |
+
# weight_shape = m.weight.shape
|
468 |
+
# if weight_shape[0] > 256 and weight_shape[1] > 256:
|
469 |
+
# std = 0.004
|
470 |
+
# else:
|
471 |
+
# std = 0.02
|
472 |
+
# print(f"Standard deviation during initialization {std}.")
|
473 |
+
trunc_normal_(m.weight, std=0.02)
|
474 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
475 |
+
nn.init.constant_(m.bias, 0)
|
476 |
+
elif isinstance(m, nn.LayerNorm):
|
477 |
+
nn.init.constant_(m.bias, 0)
|
478 |
+
nn.init.constant_(m.weight, 1.0)
|
479 |
+
|
480 |
+
@torch.jit.ignore
|
481 |
+
def no_weight_decay(self):
|
482 |
+
return {"absolute_pos_embed"}
|
483 |
+
|
484 |
+
@torch.jit.ignore
|
485 |
+
def no_weight_decay_keywords(self):
|
486 |
+
return {"relative_position_bias_table"}
|
487 |
+
|
488 |
+
def check_image_size(self, x):
|
489 |
+
_, _, h, w = x.size()
|
490 |
+
mod_pad_h = (self.pad_size - h % self.pad_size) % self.pad_size
|
491 |
+
mod_pad_w = (self.pad_size - w % self.pad_size) % self.pad_size
|
492 |
+
# print("padding size", h, w, self.pad_size, mod_pad_h, mod_pad_w)
|
493 |
+
|
494 |
+
try:
|
495 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
|
496 |
+
except BaseException:
|
497 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant")
|
498 |
+
return x
|
499 |
+
|
500 |
+
def forward_features(self, x):
|
501 |
+
x_size = (x.shape[2], x.shape[3])
|
502 |
+
x = bchw_to_blc(x)
|
503 |
+
x = self.norm_start(x)
|
504 |
+
x = self.pos_drop(x)
|
505 |
+
|
506 |
+
table_index_mask = self.get_table_index_mask(x.device, x_size)
|
507 |
+
for layer in self.layers:
|
508 |
+
x = layer(x, x_size, table_index_mask)
|
509 |
+
|
510 |
+
x = self.norm_end(x) # B L C
|
511 |
+
x = blc_to_bchw(x, x_size)
|
512 |
+
|
513 |
+
return x
|
514 |
+
|
515 |
+
def forward(self, x):
|
516 |
+
H, W = x.shape[2:]
|
517 |
+
x = self.check_image_size(x)
|
518 |
+
|
519 |
+
self.mean = self.mean.type_as(x)
|
520 |
+
x = (x - self.mean) * self.img_range
|
521 |
+
|
522 |
+
if self.upsampler == "pixelshuffle":
|
523 |
+
# for classical SR
|
524 |
+
x = self.conv_first(x)
|
525 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
526 |
+
x = self.conv_before_upsample(x)
|
527 |
+
x = self.conv_last(self.upsample(x))
|
528 |
+
elif self.upsampler == "pixelshuffledirect":
|
529 |
+
# for lightweight SR
|
530 |
+
x = self.conv_first(x)
|
531 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
532 |
+
x = self.upsample(x)
|
533 |
+
elif self.upsampler == "nearest+conv":
|
534 |
+
# for real-world SR (claimed to have fewer artifacts)
|
535 |
+
x = self.conv_first(x)
|
536 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
537 |
+
x = self.conv_before_upsample(x)
|
538 |
+
x = self.lrelu(
|
539 |
+
self.conv_up1(
|
540 |
+
torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
|
541 |
+
)
|
542 |
+
)
|
543 |
+
x = self.lrelu(
|
544 |
+
self.conv_up2(
|
545 |
+
torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
|
546 |
+
)
|
547 |
+
)
|
548 |
+
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
549 |
+
else:
|
550 |
+
# for image denoising and JPEG compression artifact reduction
|
551 |
+
x_first = self.conv_first(x)
|
552 |
+
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
553 |
+
if self.in_channels == self.out_channels:
|
554 |
+
x = x + self.conv_last(res)
|
555 |
+
else:
|
556 |
+
x = self.conv_last(res)
|
557 |
+
|
558 |
+
x = x / self.img_range + self.mean
|
559 |
+
|
560 |
+
return x[:, :, : H * self.upscale, : W * self.upscale]
|
561 |
+
|
562 |
+
def flops(self):
|
563 |
+
pass
|
564 |
+
|
565 |
+
def convert_checkpoint(self, state_dict):
|
566 |
+
for k in list(state_dict.keys()):
|
567 |
+
if (
|
568 |
+
k.find("relative_coords_table") >= 0
|
569 |
+
or k.find("relative_position_index") >= 0
|
570 |
+
or k.find("attn_mask") >= 0
|
571 |
+
or k.find("model.table_") >= 0
|
572 |
+
or k.find("model.index_") >= 0
|
573 |
+
or k.find("model.mask_") >= 0
|
574 |
+
# or k.find(".upsample.") >= 0
|
575 |
+
):
|
576 |
+
state_dict.pop(k)
|
577 |
+
print(k)
|
578 |
+
return state_dict
|
579 |
+
|
580 |
+
|
581 |
+
if __name__ == "__main__":
|
582 |
+
# The version of GRL we use
|
583 |
+
model = GRL(
|
584 |
+
upscale = 4,
|
585 |
+
img_size = 64,
|
586 |
+
window_size = 8,
|
587 |
+
depths = [4, 4, 4, 4],
|
588 |
+
embed_dim = 64,
|
589 |
+
num_heads_window = [2, 2, 2, 2],
|
590 |
+
num_heads_stripe = [2, 2, 2, 2],
|
591 |
+
mlp_ratio = 2,
|
592 |
+
qkv_proj_type = "linear",
|
593 |
+
anchor_proj_type = "avgpool",
|
594 |
+
anchor_window_down_factor = 2,
|
595 |
+
out_proj_type = "linear",
|
596 |
+
conv_type = "1conv",
|
597 |
+
upsampler = "nearest+conv", # Change
|
598 |
+
).cuda()
|
599 |
+
|
600 |
+
# Parameter analysis
|
601 |
+
num_params = 0
|
602 |
+
for p in model.parameters():
|
603 |
+
if p.requires_grad:
|
604 |
+
num_params += p.numel()
|
605 |
+
print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")
|
606 |
+
|
607 |
+
# Print param
|
608 |
+
for name, param in model.named_parameters():
|
609 |
+
print(name, param.dtype)
|
610 |
+
|
611 |
+
|
612 |
+
# Count the number of FLOPs to double check
|
613 |
+
x = torch.randn((1, 3, 180, 180)).cuda() # Don't use an input size that is too large (there is no @torch.no_grad here, so activations are kept for backward)
|
614 |
+
x = model(x)
|
615 |
+
print("output size is ", x.shape)
|
616 |
+
|
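A minimal sketch of how convert_checkpoint above could be used when restoring pretrained weights into the model built in the __main__ block; the checkpoint path and the "params" key are assumptions and depend on how the weights were saved.

import torch

ckpt = torch.load("pretrained/grl_4x.pth", map_location="cpu")  # hypothetical path
state_dict = ckpt.get("params", ckpt) if isinstance(ckpt, dict) else ckpt
# Drop the cached relative-position tables, indices and attention masks;
# GRL regenerates these buffers for the actual input resolution at run time.
state_dict = model.convert_checkpoint(state_dict)
# strict=False because the popped buffers are no longer present in the state dict.
missing, unexpected = model.load_state_dict(state_dict, strict=False)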
architecture/grl_common/__init__.py
ADDED
@@ -0,0 +1,8 @@
1 |
+
from architecture.grl_common.resblock import ResBlock
|
2 |
+
from architecture.grl_common.upsample import (
|
3 |
+
Upsample,
|
4 |
+
UpsampleOneStep,
|
5 |
+
)
|
6 |
+
|
7 |
+
|
8 |
+
__all__ = ["Upsample", "UpsampleOneStep", "ResBlock"]
|
architecture/grl_common/common_edsr.py
ADDED
@@ -0,0 +1,227 @@
1 |
+
"""
|
2 |
+
EDSR common.py
|
3 |
+
Since a lot of models are developed on top of EDSR, here we include some common functions from EDSR.
|
4 |
+
In this repository, these common functions are used by edsr_esa.py and ipt.py.
|
5 |
+
"""
|
6 |
+
|
7 |
+
|
8 |
+
import math
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
|
15 |
+
def default_conv(in_channels, out_channels, kernel_size, bias=True):
|
16 |
+
return nn.Conv2d(
|
17 |
+
in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
class MeanShift(nn.Conv2d):
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
rgb_range,
|
25 |
+
rgb_mean=(0.4488, 0.4371, 0.4040),
|
26 |
+
rgb_std=(1.0, 1.0, 1.0),
|
27 |
+
sign=-1,
|
28 |
+
):
|
29 |
+
|
30 |
+
super(MeanShift, self).__init__(3, 3, kernel_size=1)
|
31 |
+
std = torch.Tensor(rgb_std)
|
32 |
+
self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
|
33 |
+
self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
|
34 |
+
for p in self.parameters():
|
35 |
+
p.requires_grad = False
|
36 |
+
|
37 |
+
|
38 |
+
class BasicBlock(nn.Sequential):
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
conv,
|
42 |
+
in_channels,
|
43 |
+
out_channels,
|
44 |
+
kernel_size,
|
45 |
+
stride=1,
|
46 |
+
bias=False,
|
47 |
+
bn=True,
|
48 |
+
act=nn.ReLU(True),
|
49 |
+
):
|
50 |
+
|
51 |
+
m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
|
52 |
+
if bn:
|
53 |
+
m.append(nn.BatchNorm2d(out_channels))
|
54 |
+
if act is not None:
|
55 |
+
m.append(act)
|
56 |
+
|
57 |
+
super(BasicBlock, self).__init__(*m)
|
58 |
+
|
59 |
+
|
60 |
+
class ESA(nn.Module):
|
61 |
+
def __init__(self, esa_channels, n_feats):
|
62 |
+
super(ESA, self).__init__()
|
63 |
+
f = esa_channels
|
64 |
+
self.conv1 = nn.Conv2d(n_feats, f, kernel_size=1)
|
65 |
+
self.conv_f = nn.Conv2d(f, f, kernel_size=1)
|
66 |
+
# self.conv_max = conv(f, f, kernel_size=3, padding=1)
|
67 |
+
self.conv2 = nn.Conv2d(f, f, kernel_size=3, stride=2, padding=0)
|
68 |
+
self.conv3 = nn.Conv2d(f, f, kernel_size=3, padding=1)
|
69 |
+
# self.conv3_ = conv(f, f, kernel_size=3, padding=1)
|
70 |
+
self.conv4 = nn.Conv2d(f, n_feats, kernel_size=1)
|
71 |
+
self.sigmoid = nn.Sigmoid()
|
72 |
+
# self.relu = nn.ReLU(inplace=True)
|
73 |
+
|
74 |
+
def forward(self, x):
|
75 |
+
c1_ = self.conv1(x)
|
76 |
+
c1 = self.conv2(c1_)
|
77 |
+
v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
|
78 |
+
c3 = self.conv3(v_max)
|
79 |
+
# v_range = self.relu(self.conv_max(v_max))
|
80 |
+
# c3 = self.relu(self.conv3(v_range))
|
81 |
+
# c3 = self.conv3_(c3)
|
82 |
+
c3 = F.interpolate(
|
83 |
+
c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False
|
84 |
+
)
|
85 |
+
cf = self.conv_f(c1_)
|
86 |
+
c4 = self.conv4(c3 + cf)
|
87 |
+
m = self.sigmoid(c4)
|
88 |
+
|
89 |
+
return x * m
|
90 |
+
|
91 |
+
|
92 |
+
# class ESA(nn.Module):
|
93 |
+
# def __init__(self, esa_channels, n_feats, conv=nn.Conv2d):
|
94 |
+
# super(ESA, self).__init__()
|
95 |
+
# f = n_feats // 4
|
96 |
+
# self.conv1 = conv(n_feats, f, kernel_size=1)
|
97 |
+
# self.conv_f = conv(f, f, kernel_size=1)
|
98 |
+
# self.conv_max = conv(f, f, kernel_size=3, padding=1)
|
99 |
+
# self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
|
100 |
+
# self.conv3 = conv(f, f, kernel_size=3, padding=1)
|
101 |
+
# self.conv3_ = conv(f, f, kernel_size=3, padding=1)
|
102 |
+
# self.conv4 = conv(f, n_feats, kernel_size=1)
|
103 |
+
# self.sigmoid = nn.Sigmoid()
|
104 |
+
# self.relu = nn.ReLU(inplace=True)
|
105 |
+
#
|
106 |
+
# def forward(self, x):
|
107 |
+
# c1_ = (self.conv1(x))
|
108 |
+
# c1 = self.conv2(c1_)
|
109 |
+
# v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
|
110 |
+
# v_range = self.relu(self.conv_max(v_max))
|
111 |
+
# c3 = self.relu(self.conv3(v_range))
|
112 |
+
# c3 = self.conv3_(c3)
|
113 |
+
# c3 = F.interpolate(c3, (x.size(2), x.size(3)), mode='bilinear', align_corners=False)
|
114 |
+
# cf = self.conv_f(c1_)
|
115 |
+
# c4 = self.conv4(c3 + cf)
|
116 |
+
# m = self.sigmoid(c4)
|
117 |
+
#
|
118 |
+
# return x * m
|
119 |
+
|
120 |
+
|
121 |
+
class ResBlock(nn.Module):
|
122 |
+
def __init__(
|
123 |
+
self,
|
124 |
+
conv,
|
125 |
+
n_feats,
|
126 |
+
kernel_size,
|
127 |
+
bias=True,
|
128 |
+
bn=False,
|
129 |
+
act=nn.ReLU(True),
|
130 |
+
res_scale=1,
|
131 |
+
esa_block=True,
|
132 |
+
depth_wise_kernel=7,
|
133 |
+
):
|
134 |
+
|
135 |
+
super(ResBlock, self).__init__()
|
136 |
+
m = []
|
137 |
+
for i in range(2):
|
138 |
+
m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
|
139 |
+
if bn:
|
140 |
+
m.append(nn.BatchNorm2d(n_feats))
|
141 |
+
if i == 0:
|
142 |
+
m.append(act)
|
143 |
+
|
144 |
+
self.body = nn.Sequential(*m)
|
145 |
+
self.esa_block = esa_block
|
146 |
+
if self.esa_block:
|
147 |
+
esa_channels = 16
|
148 |
+
self.c5 = nn.Conv2d(
|
149 |
+
n_feats,
|
150 |
+
n_feats,
|
151 |
+
depth_wise_kernel,
|
152 |
+
padding=depth_wise_kernel // 2,
|
153 |
+
groups=n_feats,
|
154 |
+
bias=True,
|
155 |
+
)
|
156 |
+
self.esa = ESA(esa_channels, n_feats)
|
157 |
+
self.res_scale = res_scale
|
158 |
+
|
159 |
+
def forward(self, x):
|
160 |
+
res = self.body(x).mul(self.res_scale)
|
161 |
+
res += x
|
162 |
+
if self.esa_block:
|
163 |
+
res = self.esa(self.c5(res))
|
164 |
+
|
165 |
+
return res
|
166 |
+
|
167 |
+
|
168 |
+
class Upsampler(nn.Sequential):
|
169 |
+
def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
|
170 |
+
|
171 |
+
m = []
|
172 |
+
if (scale & (scale - 1)) == 0: # Is scale = 2^n?
|
173 |
+
for _ in range(int(math.log(scale, 2))):
|
174 |
+
m.append(conv(n_feats, 4 * n_feats, 3, bias))
|
175 |
+
m.append(nn.PixelShuffle(2))
|
176 |
+
if bn:
|
177 |
+
m.append(nn.BatchNorm2d(n_feats))
|
178 |
+
if act == "relu":
|
179 |
+
m.append(nn.ReLU(True))
|
180 |
+
elif act == "prelu":
|
181 |
+
m.append(nn.PReLU(n_feats))
|
182 |
+
|
183 |
+
elif scale == 3:
|
184 |
+
m.append(conv(n_feats, 9 * n_feats, 3, bias))
|
185 |
+
m.append(nn.PixelShuffle(3))
|
186 |
+
if bn:
|
187 |
+
m.append(nn.BatchNorm2d(n_feats))
|
188 |
+
if act == "relu":
|
189 |
+
m.append(nn.ReLU(True))
|
190 |
+
elif act == "prelu":
|
191 |
+
m.append(nn.PReLU(n_feats))
|
192 |
+
else:
|
193 |
+
raise NotImplementedError
|
194 |
+
|
195 |
+
super(Upsampler, self).__init__(*m)
|
196 |
+
|
197 |
+
|
198 |
+
class LiteUpsampler(nn.Sequential):
|
199 |
+
def __init__(self, conv, scale, n_feats, n_out=3, bn=False, act=False, bias=True):
|
200 |
+
|
201 |
+
m = []
|
202 |
+
m.append(conv(n_feats, n_out * (scale**2), 3, bias))
|
203 |
+
m.append(nn.PixelShuffle(scale))
|
204 |
+
# if (scale & (scale - 1)) == 0: # Is scale = 2^n?
|
205 |
+
# for _ in range(int(math.log(scale, 2))):
|
206 |
+
# m.append(conv(n_feats, 4 * n_out, 3, bias))
|
207 |
+
# m.append(nn.PixelShuffle(2))
|
208 |
+
# if bn:
|
209 |
+
# m.append(nn.BatchNorm2d(n_out))
|
210 |
+
# if act == 'relu':
|
211 |
+
# m.append(nn.ReLU(True))
|
212 |
+
# elif act == 'prelu':
|
213 |
+
# m.append(nn.PReLU(n_out))
|
214 |
+
|
215 |
+
# elif scale == 3:
|
216 |
+
# m.append(conv(n_feats, 9 * n_out, 3, bias))
|
217 |
+
# m.append(nn.PixelShuffle(3))
|
218 |
+
# if bn:
|
219 |
+
# m.append(nn.BatchNorm2d(n_out))
|
220 |
+
# if act == 'relu':
|
221 |
+
# m.append(nn.ReLU(True))
|
222 |
+
# elif act == 'prelu':
|
223 |
+
# m.append(nn.PReLU(n_out))
|
224 |
+
# else:
|
225 |
+
# raise NotImplementedError
|
226 |
+
|
227 |
+
super(LiteUpsampler, self).__init__(*m)
|
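A quick shape check (a sketch, not part of the file, assuming the repository root is on the Python path) showing how Upsampler and LiteUpsampler above differ: Upsampler keeps n_feats channels and upsamples in x2 pixel-shuffle stages, while LiteUpsampler projects directly to n_out channels and pixel-shuffles once.

import torch
from architecture.grl_common.common_edsr import Upsampler, LiteUpsampler, default_conv

x = torch.randn(1, 64, 48, 48)
up = Upsampler(default_conv, scale=4, n_feats=64)        # conv(64 -> 256) + PixelShuffle(2), applied twice
lite = LiteUpsampler(default_conv, scale=4, n_feats=64)  # conv(64 -> 3 * 16) + PixelShuffle(4)
print(up(x).shape)    # torch.Size([1, 64, 192, 192])
print(lite(x).shape)  # torch.Size([1, 3, 192, 192])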
architecture/grl_common/mixed_attn_block.py
ADDED
@@ -0,0 +1,1126 @@
1 |
+
import math
|
2 |
+
from abc import ABC
|
3 |
+
from math import prod
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from architecture.grl_common.ops import (
|
9 |
+
bchw_to_bhwc,
|
10 |
+
bchw_to_blc,
|
11 |
+
blc_to_bchw,
|
12 |
+
blc_to_bhwc,
|
13 |
+
calculate_mask,
|
14 |
+
calculate_mask_all,
|
15 |
+
get_relative_coords_table_all,
|
16 |
+
get_relative_position_index_simple,
|
17 |
+
window_partition,
|
18 |
+
window_reverse,
|
19 |
+
)
|
20 |
+
from architecture.grl_common.swin_v1_block import Mlp
|
21 |
+
from timm.models.layers import DropPath
|
22 |
+
|
23 |
+
|
24 |
+
class CPB_MLP(nn.Sequential):
|
25 |
+
def __init__(self, in_channels, out_channels, channels=512):
|
26 |
+
m = [
|
27 |
+
nn.Linear(in_channels, channels, bias=True),
|
28 |
+
nn.ReLU(inplace=True),
|
29 |
+
nn.Linear(channels, out_channels, bias=False),
|
30 |
+
]
|
31 |
+
super(CPB_MLP, self).__init__(*m)
|
32 |
+
|
33 |
+
|
34 |
+
class AffineTransformWindow(nn.Module):
|
35 |
+
r"""Affine transformation of the attention map.
|
36 |
+
The window is a square window.
|
37 |
+
Supports attention between different window sizes
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
num_heads,
|
43 |
+
input_resolution,
|
44 |
+
window_size,
|
45 |
+
pretrained_window_size=[0, 0],
|
46 |
+
shift_size=0,
|
47 |
+
anchor_window_down_factor=1,
|
48 |
+
args=None,
|
49 |
+
):
|
50 |
+
super(AffineTransformWindow, self).__init__()
|
51 |
+
# print("AffineTransformWindow", args)
|
52 |
+
self.num_heads = num_heads
|
53 |
+
self.input_resolution = input_resolution
|
54 |
+
self.window_size = window_size
|
55 |
+
self.pretrained_window_size = pretrained_window_size
|
56 |
+
self.shift_size = shift_size
|
57 |
+
self.anchor_window_down_factor = anchor_window_down_factor
|
58 |
+
self.use_buffer = args.use_buffer
|
59 |
+
|
60 |
+
logit_scale = torch.log(10 * torch.ones((num_heads, 1, 1)))
|
61 |
+
self.logit_scale = nn.Parameter(logit_scale, requires_grad=True)
|
62 |
+
|
63 |
+
# mlp to generate continuous relative position bias
|
64 |
+
self.cpb_mlp = CPB_MLP(2, num_heads)
|
65 |
+
if self.use_buffer:
|
66 |
+
table = get_relative_coords_table_all(
|
67 |
+
window_size, pretrained_window_size, anchor_window_down_factor
|
68 |
+
)
|
69 |
+
index = get_relative_position_index_simple(
|
70 |
+
window_size, anchor_window_down_factor
|
71 |
+
)
|
72 |
+
self.register_buffer("relative_coords_table", table)
|
73 |
+
self.register_buffer("relative_position_index", index)
|
74 |
+
|
75 |
+
if self.shift_size > 0:
|
76 |
+
attn_mask = calculate_mask(
|
77 |
+
input_resolution, self.window_size, self.shift_size
|
78 |
+
)
|
79 |
+
else:
|
80 |
+
attn_mask = None
|
81 |
+
self.register_buffer("attn_mask", attn_mask)
|
82 |
+
|
83 |
+
def forward(self, attn, x_size):
|
84 |
+
B_, H, N, _ = attn.shape
|
85 |
+
device = attn.device
|
86 |
+
# logit scale
|
87 |
+
attn = attn * torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
|
88 |
+
|
89 |
+
# relative position bias
|
90 |
+
if self.use_buffer:
|
91 |
+
table = self.relative_coords_table
|
92 |
+
index = self.relative_position_index
|
93 |
+
else:
|
94 |
+
table = get_relative_coords_table_all(
|
95 |
+
self.window_size,
|
96 |
+
self.pretrained_window_size,
|
97 |
+
self.anchor_window_down_factor,
|
98 |
+
).to(device)
|
99 |
+
index = get_relative_position_index_simple(
|
100 |
+
self.window_size, self.anchor_window_down_factor
|
101 |
+
).to(device)
|
102 |
+
|
103 |
+
bias_table = self.cpb_mlp(table) # 2*Wh-1, 2*Ww-1, num_heads
|
104 |
+
bias_table = bias_table.view(-1, self.num_heads)
|
105 |
+
|
106 |
+
win_dim = prod(self.window_size)
|
107 |
+
bias = bias_table[index.view(-1)]
|
108 |
+
bias = bias.view(win_dim, win_dim, -1).permute(2, 0, 1).contiguous()
|
109 |
+
# nH, Wh*Ww, Wh*Ww
|
110 |
+
bias = 16 * torch.sigmoid(bias)
|
111 |
+
attn = attn + bias.unsqueeze(0)
|
112 |
+
|
113 |
+
# W-MSA/SW-MSA
|
114 |
+
if self.use_buffer:
|
115 |
+
mask = self.attn_mask
|
116 |
+
# during test and window shift, recalculate the mask
|
117 |
+
if self.input_resolution != x_size and self.shift_size > 0:
|
118 |
+
mask = calculate_mask(x_size, self.window_size, self.shift_size)
|
119 |
+
mask = mask.to(attn.device)
|
120 |
+
else:
|
121 |
+
if self.shift_size > 0:
|
122 |
+
mask = calculate_mask(x_size, self.window_size, self.shift_size)
|
123 |
+
mask = mask.to(attn.device)
|
124 |
+
else:
|
125 |
+
mask = None
|
126 |
+
|
127 |
+
# shift attention mask
|
128 |
+
if mask is not None:
|
129 |
+
nW = mask.shape[0]
|
130 |
+
mask = mask.unsqueeze(1).unsqueeze(0)
|
131 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask
|
132 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
133 |
+
|
134 |
+
return attn
|
135 |
+
|
136 |
+
|
137 |
+
class AffineTransformStripe(nn.Module):
|
138 |
+
r"""Affine transformation of the attention map.
|
139 |
+
The window is a stripe window. Supports attention between different window sizes
|
140 |
+
"""
|
141 |
+
|
142 |
+
def __init__(
|
143 |
+
self,
|
144 |
+
num_heads,
|
145 |
+
input_resolution,
|
146 |
+
stripe_size,
|
147 |
+
stripe_groups,
|
148 |
+
stripe_shift,
|
149 |
+
pretrained_stripe_size=[0, 0],
|
150 |
+
anchor_window_down_factor=1,
|
151 |
+
window_to_anchor=True,
|
152 |
+
args=None,
|
153 |
+
):
|
154 |
+
super(AffineTransformStripe, self).__init__()
|
155 |
+
self.num_heads = num_heads
|
156 |
+
self.input_resolution = input_resolution
|
157 |
+
self.stripe_size = stripe_size
|
158 |
+
self.stripe_groups = stripe_groups
|
159 |
+
self.pretrained_stripe_size = pretrained_stripe_size
|
160 |
+
# TODO: be careful when determining the pretrained_stripe_size
|
161 |
+
self.stripe_shift = stripe_shift
|
162 |
+
stripe_size, shift_size = self._get_stripe_info(input_resolution)
|
163 |
+
self.anchor_window_down_factor = anchor_window_down_factor
|
164 |
+
self.window_to_anchor = window_to_anchor
|
165 |
+
self.use_buffer = args.use_buffer
|
166 |
+
|
167 |
+
logit_scale = torch.log(10 * torch.ones((num_heads, 1, 1)))
|
168 |
+
self.logit_scale = nn.Parameter(logit_scale, requires_grad=True)
|
169 |
+
|
170 |
+
# mlp to generate continuous relative position bias
|
171 |
+
self.cpb_mlp = CPB_MLP(2, num_heads)
|
172 |
+
if self.use_buffer:
|
173 |
+
table = get_relative_coords_table_all(
|
174 |
+
stripe_size, pretrained_stripe_size, anchor_window_down_factor
|
175 |
+
)
|
176 |
+
index = get_relative_position_index_simple(
|
177 |
+
stripe_size, anchor_window_down_factor, window_to_anchor
|
178 |
+
)
|
179 |
+
self.register_buffer("relative_coords_table", table)
|
180 |
+
self.register_buffer("relative_position_index", index)
|
181 |
+
|
182 |
+
if self.stripe_shift:
|
183 |
+
attn_mask = calculate_mask_all(
|
184 |
+
input_resolution,
|
185 |
+
stripe_size,
|
186 |
+
shift_size,
|
187 |
+
anchor_window_down_factor,
|
188 |
+
window_to_anchor,
|
189 |
+
)
|
190 |
+
else:
|
191 |
+
attn_mask = None
|
192 |
+
self.register_buffer("attn_mask", attn_mask)
|
193 |
+
|
194 |
+
def forward(self, attn, x_size):
|
195 |
+
B_, H, N1, N2 = attn.shape
|
196 |
+
device = attn.device
|
197 |
+
# logit scale
|
198 |
+
attn = attn * torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
|
199 |
+
|
200 |
+
# relative position bias
|
201 |
+
stripe_size, shift_size = self._get_stripe_info(x_size)
|
202 |
+
fixed_stripe_size = (
|
203 |
+
self.stripe_groups[0] is None and self.stripe_groups[1] is None
|
204 |
+
)
|
205 |
+
if not self.use_buffer or (
|
206 |
+
self.use_buffer
|
207 |
+
and self.input_resolution != x_size
|
208 |
+
and not fixed_stripe_size
|
209 |
+
):
|
210 |
+
# during test and stripe size is not fixed.
|
211 |
+
pretrained_stripe_size = (
|
212 |
+
self.pretrained_stripe_size
|
213 |
+
) # or stripe_size; Needs further pondering
|
214 |
+
table = get_relative_coords_table_all(
|
215 |
+
stripe_size, pretrained_stripe_size, self.anchor_window_down_factor
|
216 |
+
)
|
217 |
+
table = table.to(device)
|
218 |
+
index = get_relative_position_index_simple(
|
219 |
+
stripe_size, self.anchor_window_down_factor, self.window_to_anchor
|
220 |
+
).to(device)
|
221 |
+
else:
|
222 |
+
table = self.relative_coords_table
|
223 |
+
index = self.relative_position_index
|
224 |
+
# The same table size-> 1, Wh+AWh-1, Ww+AWw-1, 2
|
225 |
+
# But different index size -> # Wh*Ww, AWh*AWw
|
226 |
+
# if N1 < N2:
|
227 |
+
# index = index.transpose(0, 1)
|
228 |
+
|
229 |
+
bias_table = self.cpb_mlp(table).view(-1, self.num_heads)
|
230 |
+
# if not self.training:
|
231 |
+
# print(bias_table.shape, index.max(), index.min())
|
232 |
+
bias = bias_table[index.view(-1)]
|
233 |
+
bias = bias.view(N1, N2, -1).permute(2, 0, 1).contiguous()
|
234 |
+
# nH, Wh*Ww, Wh*Ww
|
235 |
+
bias = 16 * torch.sigmoid(bias)
|
236 |
+
# print(N1, N2, attn.shape, bias.unsqueeze(0).shape)
|
237 |
+
attn = attn + bias.unsqueeze(0)
|
238 |
+
|
239 |
+
# W-MSA/SW-MSA
|
240 |
+
if self.use_buffer:
|
241 |
+
mask = self.attn_mask
|
242 |
+
# during test and window shift, recalculate the mask
|
243 |
+
if self.input_resolution != x_size and self.stripe_shift > 0:
|
244 |
+
mask = calculate_mask_all(
|
245 |
+
x_size,
|
246 |
+
stripe_size,
|
247 |
+
shift_size,
|
248 |
+
self.anchor_window_down_factor,
|
249 |
+
self.window_to_anchor,
|
250 |
+
)
|
251 |
+
mask = mask.to(device)
|
252 |
+
else:
|
253 |
+
if self.stripe_shift > 0:
|
254 |
+
mask = calculate_mask_all(
|
255 |
+
x_size,
|
256 |
+
stripe_size,
|
257 |
+
shift_size,
|
258 |
+
self.anchor_window_down_factor,
|
259 |
+
self.window_to_anchor,
|
260 |
+
)
|
261 |
+
mask = mask.to(attn.device)
|
262 |
+
else:
|
263 |
+
mask = None
|
264 |
+
|
265 |
+
# shift attention mask
|
266 |
+
if mask is not None:
|
267 |
+
nW = mask.shape[0]
|
268 |
+
mask = mask.unsqueeze(1).unsqueeze(0)
|
269 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N1, N2) + mask
|
270 |
+
attn = attn.view(-1, self.num_heads, N1, N2)
|
271 |
+
|
272 |
+
return attn
|
273 |
+
|
274 |
+
def _get_stripe_info(self, input_resolution):
|
275 |
+
stripe_size, shift_size = [], []
|
276 |
+
for s, g, d in zip(self.stripe_size, self.stripe_groups, input_resolution):
|
277 |
+
if g is None:
|
278 |
+
stripe_size.append(s)
|
279 |
+
shift_size.append(s // 2 if self.stripe_shift else 0)
|
280 |
+
else:
|
281 |
+
stripe_size.append(d // g)
|
282 |
+
shift_size.append(0 if g == 1 else d // (g * 2))
|
283 |
+
return stripe_size, shift_size
|
284 |
+
|
285 |
+
|
286 |
+
class Attention(ABC, nn.Module):
|
287 |
+
def __init__(self):
|
288 |
+
super(Attention, self).__init__()
|
289 |
+
|
290 |
+
def attn(self, q, k, v, attn_transform, x_size, reshape=True):
|
291 |
+
# cosine attention map
|
292 |
+
B_, _, H, head_dim = q.shape
|
293 |
+
if self.euclidean_dist:
|
294 |
+
attn = torch.norm(q.unsqueeze(-2) - k.unsqueeze(-3), dim=-1)
|
295 |
+
else:
|
296 |
+
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
|
297 |
+
attn = attn_transform(attn, x_size)
|
298 |
+
# attention
|
299 |
+
attn = self.softmax(attn)
|
300 |
+
attn = self.attn_drop(attn)
|
301 |
+
x = attn @ v # B_, H, N1, head_dim
|
302 |
+
if reshape:
|
303 |
+
x = x.transpose(1, 2).reshape(B_, -1, H * head_dim)
|
304 |
+
# B_, N, C
|
305 |
+
return x
|
306 |
+
|
307 |
+
|
308 |
+
class WindowAttention(Attention):
|
309 |
+
r"""Window attention. QKV is the input to the forward method.
|
310 |
+
Args:
|
311 |
+
num_heads (int): Number of attention heads.
|
312 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
313 |
+
pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
|
314 |
+
"""
|
315 |
+
|
316 |
+
def __init__(
|
317 |
+
self,
|
318 |
+
input_resolution,
|
319 |
+
window_size,
|
320 |
+
num_heads,
|
321 |
+
window_shift=False,
|
322 |
+
attn_drop=0.0,
|
323 |
+
pretrained_window_size=[0, 0],
|
324 |
+
args=None,
|
325 |
+
):
|
326 |
+
|
327 |
+
super(WindowAttention, self).__init__()
|
328 |
+
self.input_resolution = input_resolution
|
329 |
+
self.window_size = window_size
|
330 |
+
self.pretrained_window_size = pretrained_window_size
|
331 |
+
self.num_heads = num_heads
|
332 |
+
self.shift_size = window_size[0] // 2 if window_shift else 0
|
333 |
+
self.euclidean_dist = args.euclidean_dist
|
334 |
+
|
335 |
+
self.attn_transform = AffineTransformWindow(
|
336 |
+
num_heads,
|
337 |
+
input_resolution,
|
338 |
+
window_size,
|
339 |
+
pretrained_window_size,
|
340 |
+
self.shift_size,
|
341 |
+
args=args,
|
342 |
+
)
|
343 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
344 |
+
self.softmax = nn.Softmax(dim=-1)
|
345 |
+
|
346 |
+
def forward(self, qkv, x_size):
|
347 |
+
"""
|
348 |
+
Args:
|
349 |
+
qkv: input QKV features with shape of (B, L, 3C)
|
350 |
+
x_size: use x_size to determine whether the relative positional bias table and index
|
351 |
+
need to be regenerated.
|
352 |
+
"""
|
353 |
+
H, W = x_size
|
354 |
+
B, L, C = qkv.shape
|
355 |
+
qkv = qkv.view(B, H, W, C)
|
356 |
+
|
357 |
+
# cyclic shift
|
358 |
+
if self.shift_size > 0:
|
359 |
+
qkv = torch.roll(
|
360 |
+
qkv, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
|
361 |
+
)
|
362 |
+
|
363 |
+
# partition windows
|
364 |
+
qkv = window_partition(qkv, self.window_size) # nW*B, wh, ww, C
|
365 |
+
qkv = qkv.view(-1, prod(self.window_size), C) # nW*B, wh*ww, C
|
366 |
+
|
367 |
+
B_, N, _ = qkv.shape
|
368 |
+
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
369 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
370 |
+
|
371 |
+
# attention
|
372 |
+
x = self.attn(q, k, v, self.attn_transform, x_size)
|
373 |
+
|
374 |
+
# merge windows
|
375 |
+
x = x.view(-1, *self.window_size, C // 3)
|
376 |
+
x = window_reverse(x, self.window_size, x_size) # B, H, W, C/3
|
377 |
+
|
378 |
+
# reverse cyclic shift
|
379 |
+
if self.shift_size > 0:
|
380 |
+
x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
381 |
+
x = x.view(B, L, C // 3)
|
382 |
+
|
383 |
+
return x
|
384 |
+
|
385 |
+
def extra_repr(self) -> str:
|
386 |
+
return (
|
387 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, "
|
388 |
+
f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}"
|
389 |
+
)
|
390 |
+
|
391 |
+
def flops(self, N):
|
392 |
+
# calculate flops for 1 window with token length of N
|
393 |
+
flops = 0
|
394 |
+
# qkv = self.qkv(x)
|
395 |
+
flops += N * self.dim * 3 * self.dim
|
396 |
+
# attn = (q @ k.transpose(-2, -1))
|
397 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
398 |
+
# x = (attn @ v)
|
399 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
400 |
+
# x = self.proj(x)
|
401 |
+
flops += N * self.dim * self.dim
|
402 |
+
return flops
|
403 |
+
|
404 |
+
|
405 |
+
class StripeAttention(Attention):
|
406 |
+
r"""Stripe attention
|
407 |
+
Args:
|
408 |
+
stripe_size (tuple[int]): The height and width of the stripe.
|
409 |
+
num_heads (int): Number of attention heads.
|
410 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
411 |
+
pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training.
|
412 |
+
"""
|
413 |
+
|
414 |
+
def __init__(
|
415 |
+
self,
|
416 |
+
input_resolution,
|
417 |
+
stripe_size,
|
418 |
+
stripe_groups,
|
419 |
+
stripe_shift,
|
420 |
+
num_heads,
|
421 |
+
attn_drop=0.0,
|
422 |
+
pretrained_stripe_size=[0, 0],
|
423 |
+
args=None,
|
424 |
+
):
|
425 |
+
|
426 |
+
super(StripeAttention, self).__init__()
|
427 |
+
self.input_resolution = input_resolution
|
428 |
+
self.stripe_size = stripe_size # Wh, Ww
|
429 |
+
self.stripe_groups = stripe_groups
|
430 |
+
self.stripe_shift = stripe_shift
|
431 |
+
self.num_heads = num_heads
|
432 |
+
self.pretrained_stripe_size = pretrained_stripe_size
|
433 |
+
self.euclidean_dist = args.euclidean_dist
|
434 |
+
|
435 |
+
self.attn_transform = AffineTransformStripe(
|
436 |
+
num_heads,
|
437 |
+
input_resolution,
|
438 |
+
stripe_size,
|
439 |
+
stripe_groups,
|
440 |
+
stripe_shift,
|
441 |
+
pretrained_stripe_size,
|
442 |
+
anchor_window_down_factor=1,
|
443 |
+
args=args,
|
444 |
+
)
|
445 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
446 |
+
self.softmax = nn.Softmax(dim=-1)
|
447 |
+
|
448 |
+
def forward(self, qkv, x_size):
|
449 |
+
"""
|
450 |
+
Args:
|
451 |
+
qkv: input QKV features with shape of (B, L, 3C)
|
452 |
+
x_size: use x_size to determine whether the relative positional bias table and index
|
453 |
+
need to be regenerated.
|
454 |
+
"""
|
455 |
+
H, W = x_size
|
456 |
+
B, L, C = qkv.shape
|
457 |
+
qkv = qkv.view(B, H, W, C)
|
458 |
+
|
459 |
+
running_stripe_size, running_shift_size = self.attn_transform._get_stripe_info(
|
460 |
+
x_size
|
461 |
+
)
|
462 |
+
# cyclic shift
|
463 |
+
if self.stripe_shift:
|
464 |
+
qkv = torch.roll(
|
465 |
+
qkv,
|
466 |
+
shifts=(-running_shift_size[0], -running_shift_size[1]),
|
467 |
+
dims=(1, 2),
|
468 |
+
)
|
469 |
+
|
470 |
+
# partition windows
|
471 |
+
qkv = window_partition(qkv, running_stripe_size) # nW*B, wh, ww, C
|
472 |
+
qkv = qkv.view(-1, prod(running_stripe_size), C) # nW*B, wh*ww, C
|
473 |
+
|
474 |
+
B_, N, _ = qkv.shape
|
475 |
+
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
476 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
477 |
+
|
478 |
+
# attention
|
479 |
+
x = self.attn(q, k, v, self.attn_transform, x_size)
|
480 |
+
|
481 |
+
# merge windows
|
482 |
+
x = x.view(-1, *running_stripe_size, C // 3)
|
483 |
+
x = window_reverse(x, running_stripe_size, x_size) # B H W C/3
|
484 |
+
|
485 |
+
# reverse the shift
|
486 |
+
if self.stripe_shift:
|
487 |
+
x = torch.roll(x, shifts=running_shift_size, dims=(1, 2))
|
488 |
+
|
489 |
+
x = x.view(B, L, C // 3)
|
490 |
+
return x
|
491 |
+
|
492 |
+
def extra_repr(self) -> str:
|
493 |
+
return (
|
494 |
+
f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, "
|
495 |
+
f"pretrained_stripe_size={self.pretrained_stripe_size}, num_heads={self.num_heads}"
|
496 |
+
)
|
497 |
+
|
498 |
+
def flops(self, N):
|
499 |
+
# calculate flops for 1 window with token length of N
|
500 |
+
flops = 0
|
501 |
+
# qkv = self.qkv(x)
|
502 |
+
flops += N * self.dim * 3 * self.dim
|
503 |
+
# attn = (q @ k.transpose(-2, -1))
|
504 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
505 |
+
# x = (attn @ v)
|
506 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
507 |
+
# x = self.proj(x)
|
508 |
+
flops += N * self.dim * self.dim
|
509 |
+
return flops
|
510 |
+
|
511 |
+
|
512 |
+
class AnchorStripeAttention(Attention):
|
513 |
+
r"""Stripe attention
|
514 |
+
Args:
|
515 |
+
stripe_size (tuple[int]): The height and width of the stripe.
|
516 |
+
num_heads (int): Number of attention heads.
|
517 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
518 |
+
pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training.
|
519 |
+
"""
|
520 |
+
|
521 |
+
def __init__(
|
522 |
+
self,
|
523 |
+
input_resolution,
|
524 |
+
stripe_size,
|
525 |
+
stripe_groups,
|
526 |
+
stripe_shift,
|
527 |
+
num_heads,
|
528 |
+
attn_drop=0.0,
|
529 |
+
pretrained_stripe_size=[0, 0],
|
530 |
+
anchor_window_down_factor=1,
|
531 |
+
args=None,
|
532 |
+
):
|
533 |
+
|
534 |
+
super(AnchorStripeAttention, self).__init__()
|
535 |
+
self.input_resolution = input_resolution
|
536 |
+
self.stripe_size = stripe_size # Wh, Ww
|
537 |
+
self.stripe_groups = stripe_groups
|
538 |
+
self.stripe_shift = stripe_shift
|
539 |
+
self.num_heads = num_heads
|
540 |
+
self.pretrained_stripe_size = pretrained_stripe_size
|
541 |
+
self.anchor_window_down_factor = anchor_window_down_factor
|
542 |
+
self.euclidean_dist = args.euclidean_dist
|
543 |
+
|
544 |
+
self.attn_transform1 = AffineTransformStripe(
|
545 |
+
num_heads,
|
546 |
+
input_resolution,
|
547 |
+
stripe_size,
|
548 |
+
stripe_groups,
|
549 |
+
stripe_shift,
|
550 |
+
pretrained_stripe_size,
|
551 |
+
anchor_window_down_factor,
|
552 |
+
window_to_anchor=False,
|
553 |
+
args=args,
|
554 |
+
)
|
555 |
+
|
556 |
+
self.attn_transform2 = AffineTransformStripe(
|
557 |
+
num_heads,
|
558 |
+
input_resolution,
|
559 |
+
stripe_size,
|
560 |
+
stripe_groups,
|
561 |
+
stripe_shift,
|
562 |
+
pretrained_stripe_size,
|
563 |
+
anchor_window_down_factor,
|
564 |
+
window_to_anchor=True,
|
565 |
+
args=args,
|
566 |
+
)
|
567 |
+
|
568 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
569 |
+
self.softmax = nn.Softmax(dim=-1)
|
570 |
+
|
571 |
+
def forward(self, qkv, anchor, x_size):
|
572 |
+
"""
|
573 |
+
Args:
|
574 |
+
qkv: input features with shape of (B, L, C)
|
575 |
+
anchor:
|
576 |
+
x_size: use x_size to determine whether the relative positional bias table and index
|
577 |
+
need to be regenerated.
|
578 |
+
"""
|
579 |
+
H, W = x_size
|
580 |
+
B, L, C = qkv.shape
|
581 |
+
qkv = qkv.view(B, H, W, C)
|
582 |
+
|
583 |
+
stripe_size, shift_size = self.attn_transform1._get_stripe_info(x_size)
|
584 |
+
anchor_stripe_size = [s // self.anchor_window_down_factor for s in stripe_size]
|
585 |
+
anchor_shift_size = [s // self.anchor_window_down_factor for s in shift_size]
|
586 |
+
# cyclic shift
|
587 |
+
if self.stripe_shift:
|
588 |
+
qkv = torch.roll(qkv, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
|
589 |
+
anchor = torch.roll(
|
590 |
+
anchor,
|
591 |
+
shifts=(-anchor_shift_size[0], -anchor_shift_size[1]),
|
592 |
+
dims=(1, 2),
|
593 |
+
)
|
594 |
+
|
595 |
+
# partition windows
|
596 |
+
qkv = window_partition(qkv, stripe_size) # nW*B, wh, ww, C
|
597 |
+
qkv = qkv.view(-1, prod(stripe_size), C) # nW*B, wh*ww, C
|
598 |
+
anchor = window_partition(anchor, anchor_stripe_size)
|
599 |
+
anchor = anchor.view(-1, prod(anchor_stripe_size), C // 3)
|
600 |
+
|
601 |
+
B_, N1, _ = qkv.shape
|
602 |
+
N2 = anchor.shape[1]
|
603 |
+
qkv = qkv.reshape(B_, N1, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
604 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
605 |
+
anchor = anchor.reshape(B_, N2, self.num_heads, -1).permute(0, 2, 1, 3)
|
606 |
+
|
607 |
+
# attention
|
608 |
+
x = self.attn(anchor, k, v, self.attn_transform1, x_size, False)
|
609 |
+
x = self.attn(q, anchor, x, self.attn_transform2, x_size)
|
610 |
+
|
611 |
+
# merge windows
|
612 |
+
x = x.view(B_, *stripe_size, C // 3)
|
613 |
+
x = window_reverse(x, stripe_size, x_size) # B H' W' C
|
614 |
+
|
615 |
+
# reverse the shift
|
616 |
+
if self.stripe_shift:
|
617 |
+
x = torch.roll(x, shifts=shift_size, dims=(1, 2))
|
618 |
+
|
619 |
+
x = x.view(B, H * W, C // 3)
|
620 |
+
return x
|
621 |
+
|
622 |
+
def extra_repr(self) -> str:
|
623 |
+
return (
|
624 |
+
f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, "
|
625 |
+
f"pretrained_stripe_size={self.pretrained_stripe_size}, num_heads={self.num_heads}, anchor_window_down_factor={self.anchor_window_down_factor}"
|
626 |
+
)
|
627 |
+
|
628 |
+
def flops(self, N):
|
629 |
+
# calculate flops for 1 window with token length of N
|
630 |
+
flops = 0
|
631 |
+
# qkv = self.qkv(x)
|
632 |
+
flops += N * self.dim * 3 * self.dim
|
633 |
+
# attn = (q @ k.transpose(-2, -1))
|
634 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
635 |
+
# x = (attn @ v)
|
636 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
637 |
+
# x = self.proj(x)
|
638 |
+
flops += N * self.dim * self.dim
|
639 |
+
return flops
|
640 |
+
|
641 |
+
|
642 |
+
class SeparableConv(nn.Sequential):
|
643 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride, bias, args):
|
644 |
+
m = [
|
645 |
+
nn.Conv2d(
|
646 |
+
in_channels,
|
647 |
+
in_channels,
|
648 |
+
kernel_size,
|
649 |
+
stride,
|
650 |
+
kernel_size // 2,
|
651 |
+
groups=in_channels,
|
652 |
+
bias=bias,
|
653 |
+
)
|
654 |
+
]
|
655 |
+
if args.separable_conv_act:
|
656 |
+
m.append(nn.GELU())
|
657 |
+
m.append(nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=bias))
|
658 |
+
super(SeparableConv, self).__init__(*m)
|
659 |
+
|
660 |
+
|
661 |
+
class QKVProjection(nn.Module):
|
662 |
+
def __init__(self, dim, qkv_bias, proj_type, args):
|
663 |
+
super(QKVProjection, self).__init__()
|
664 |
+
self.proj_type = proj_type
|
665 |
+
if proj_type == "linear":
|
666 |
+
self.body = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
667 |
+
else:
|
668 |
+
self.body = SeparableConv(dim, dim * 3, 3, 1, qkv_bias, args)
|
669 |
+
|
670 |
+
def forward(self, x, x_size):
|
671 |
+
if self.proj_type == "separable_conv":
|
672 |
+
x = blc_to_bchw(x, x_size)
|
673 |
+
x = self.body(x)
|
674 |
+
if self.proj_type == "separable_conv":
|
675 |
+
x = bchw_to_blc(x)
|
676 |
+
return x
|
677 |
+
|
678 |
+
|
679 |
+
class PatchMerging(nn.Module):
|
680 |
+
r"""Patch Merging Layer.
|
681 |
+
Args:
|
682 |
+
dim (int): Number of input channels.
|
683 |
+
"""
|
684 |
+
|
685 |
+
def __init__(self, in_dim, out_dim):
|
686 |
+
super().__init__()
|
687 |
+
self.in_dim = in_dim
|
688 |
+
self.out_dim = out_dim
|
689 |
+
self.reduction = nn.Linear(4 * in_dim, out_dim, bias=False)
|
690 |
+
|
691 |
+
def forward(self, x, x_size):
|
692 |
+
"""
|
693 |
+
x: B, H*W, C
|
694 |
+
"""
|
695 |
+
H, W = x_size
|
696 |
+
B, L, C = x.shape
|
697 |
+
assert L == H * W, "input feature has wrong size"
|
698 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
699 |
+
|
700 |
+
x = x.view(B, H, W, C)
|
701 |
+
|
702 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
703 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
704 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
705 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
706 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
707 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
708 |
+
|
709 |
+
x = self.reduction(x)
|
710 |
+
|
711 |
+
return x
|
712 |
+
|
713 |
+
|
714 |
+
class AnchorLinear(nn.Module):
|
715 |
+
r"""Linear anchor projection layer
|
716 |
+
Args:
|
717 |
+
dim (int): Number of input channels.
|
718 |
+
"""
|
719 |
+
|
720 |
+
def __init__(self, in_channels, out_channels, down_factor, pooling_mode, bias):
|
721 |
+
super().__init__()
|
722 |
+
self.down_factor = down_factor
|
723 |
+
if pooling_mode == "maxpool":
|
724 |
+
self.pooling = nn.MaxPool2d(down_factor, down_factor)
|
725 |
+
elif pooling_mode == "avgpool":
|
726 |
+
self.pooling = nn.AvgPool2d(down_factor, down_factor)
|
727 |
+
self.reduction = nn.Linear(in_channels, out_channels, bias=bias)
|
728 |
+
|
729 |
+
def forward(self, x, x_size):
|
730 |
+
"""
|
731 |
+
x: B, H*W, C
|
732 |
+
"""
|
733 |
+
x = blc_to_bchw(x, x_size)
|
734 |
+
x = bchw_to_blc(self.pooling(x))
|
735 |
+
x = blc_to_bhwc(self.reduction(x), [s // self.down_factor for s in x_size])
|
736 |
+
return x
|
737 |
+
|
738 |
+
|
739 |
+
class AnchorProjection(nn.Module):
|
740 |
+
def __init__(self, dim, proj_type, one_stage, anchor_window_down_factor, args):
|
741 |
+
super(AnchorProjection, self).__init__()
|
742 |
+
self.proj_type = proj_type
|
743 |
+
self.body = nn.ModuleList([])
|
744 |
+
if one_stage:
|
745 |
+
if proj_type == "patchmerging":
|
746 |
+
m = PatchMerging(dim, dim // 2)
|
747 |
+
elif proj_type == "conv2d":
|
748 |
+
kernel_size = anchor_window_down_factor + 1
|
749 |
+
stride = anchor_window_down_factor
|
750 |
+
padding = kernel_size // 2
|
751 |
+
m = nn.Conv2d(dim, dim // 2, kernel_size, stride, padding)
|
752 |
+
elif proj_type == "separable_conv":
|
753 |
+
kernel_size = anchor_window_down_factor + 1
|
754 |
+
stride = anchor_window_down_factor
|
755 |
+
m = SeparableConv(dim, dim // 2, kernel_size, stride, True, args)
|
756 |
+
elif proj_type.find("pool") >= 0:
|
757 |
+
m = AnchorLinear(
|
758 |
+
dim, dim // 2, anchor_window_down_factor, proj_type, True
|
759 |
+
)
|
760 |
+
self.body.append(m)
|
761 |
+
else:
|
762 |
+
for i in range(int(math.log2(anchor_window_down_factor))):
|
763 |
+
cin = dim if i == 0 else dim // 2
|
764 |
+
if proj_type == "patchmerging":
|
765 |
+
m = PatchMerging(cin, dim // 2)
|
766 |
+
elif proj_type == "conv2d":
|
767 |
+
m = nn.Conv2d(cin, dim // 2, 3, 2, 1)
|
768 |
+
elif proj_type == "separable_conv":
|
769 |
+
m = SeparableConv(cin, dim // 2, 3, 2, True, args)
|
770 |
+
self.body.append(m)
|
771 |
+
|
772 |
+
def forward(self, x, x_size):
|
773 |
+
if self.proj_type.find("conv") >= 0:
|
774 |
+
x = blc_to_bchw(x, x_size)
|
775 |
+
for m in self.body:
|
776 |
+
x = m(x)
|
777 |
+
x = bchw_to_bhwc(x)
|
778 |
+
elif self.proj_type.find("pool") >= 0:
|
779 |
+
for m in self.body:
|
780 |
+
x = m(x, x_size)
|
781 |
+
else:
|
782 |
+
for i, m in enumerate(self.body):
|
783 |
+
x = m(x, [s // 2**i for s in x_size])
|
784 |
+
x = blc_to_bhwc(x, [s // 2 ** (i + 1) for s in x_size])
|
785 |
+
return x
|
786 |
+
|
787 |
+
|
788 |
+
class MixedAttention(nn.Module):
|
789 |
+
r"""Mixed window attention and stripe attention
|
790 |
+
Args:
|
791 |
+
dim (int): Number of input channels.
|
792 |
+
stripe_size (tuple[int]): The height and width of the stripe.
|
793 |
+
num_heads (int): Number of attention heads.
|
794 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
795 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
796 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
797 |
+
pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training.
|
798 |
+
"""
|
799 |
+
|
800 |
+
def __init__(
|
801 |
+
self,
|
802 |
+
dim,
|
803 |
+
input_resolution,
|
804 |
+
num_heads_w,
|
805 |
+
num_heads_s,
|
806 |
+
window_size,
|
807 |
+
window_shift,
|
808 |
+
stripe_size,
|
809 |
+
stripe_groups,
|
810 |
+
stripe_shift,
|
811 |
+
qkv_bias=True,
|
812 |
+
qkv_proj_type="linear",
|
813 |
+
anchor_proj_type="separable_conv",
|
814 |
+
anchor_one_stage=True,
|
815 |
+
anchor_window_down_factor=1,
|
816 |
+
attn_drop=0.0,
|
817 |
+
proj_drop=0.0,
|
818 |
+
pretrained_window_size=[0, 0],
|
819 |
+
pretrained_stripe_size=[0, 0],
|
820 |
+
args=None,
|
821 |
+
):
|
822 |
+
|
823 |
+
super(MixedAttention, self).__init__()
|
824 |
+
self.dim = dim
|
825 |
+
self.input_resolution = input_resolution
|
826 |
+
self.use_anchor = anchor_window_down_factor > 1
|
827 |
+
self.args = args
|
828 |
+
# print(args)
|
829 |
+
self.qkv = QKVProjection(dim, qkv_bias, qkv_proj_type, args)
|
830 |
+
if self.use_anchor:
|
831 |
+
# anchor is only used for stripe attention
|
832 |
+
self.anchor = AnchorProjection(
|
833 |
+
dim, anchor_proj_type, anchor_one_stage, anchor_window_down_factor, args
|
834 |
+
)
|
835 |
+
|
836 |
+
self.window_attn = WindowAttention(
|
837 |
+
input_resolution,
|
838 |
+
window_size,
|
839 |
+
num_heads_w,
|
840 |
+
window_shift,
|
841 |
+
attn_drop,
|
842 |
+
pretrained_window_size,
|
843 |
+
args,
|
844 |
+
)
|
845 |
+
|
846 |
+
if self.args.double_window:
|
847 |
+
self.stripe_attn = WindowAttention(
|
848 |
+
input_resolution,
|
849 |
+
window_size,
|
850 |
+
num_heads_w,
|
851 |
+
window_shift,
|
852 |
+
attn_drop,
|
853 |
+
pretrained_window_size,
|
854 |
+
args,
|
855 |
+
)
|
856 |
+
else:
|
857 |
+
if self.use_anchor:
|
858 |
+
self.stripe_attn = AnchorStripeAttention(
|
859 |
+
input_resolution,
|
860 |
+
stripe_size,
|
861 |
+
stripe_groups,
|
862 |
+
stripe_shift,
|
863 |
+
num_heads_s,
|
864 |
+
attn_drop,
|
865 |
+
pretrained_stripe_size,
|
866 |
+
anchor_window_down_factor,
|
867 |
+
args,
|
868 |
+
)
|
869 |
+
else:
|
870 |
+
if self.args.stripe_square:
|
871 |
+
self.stripe_attn = StripeAttention(
|
872 |
+
input_resolution,
|
873 |
+
window_size,
|
874 |
+
[None, None],
|
875 |
+
window_shift,
|
876 |
+
num_heads_s,
|
877 |
+
attn_drop,
|
878 |
+
pretrained_stripe_size,
|
879 |
+
args,
|
880 |
+
)
|
881 |
+
else:
|
882 |
+
self.stripe_attn = StripeAttention(
|
883 |
+
input_resolution,
|
884 |
+
stripe_size,
|
885 |
+
stripe_groups,
|
886 |
+
stripe_shift,
|
887 |
+
num_heads_s,
|
888 |
+
attn_drop,
|
889 |
+
pretrained_stripe_size,
|
890 |
+
args,
|
891 |
+
)
|
892 |
+
if self.args.out_proj_type == "linear":
|
893 |
+
self.proj = nn.Linear(dim, dim)
|
894 |
+
else:
|
895 |
+
self.proj = nn.Conv2d(dim, dim, 3, 1, 1)
|
896 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
897 |
+
|
898 |
+
def forward(self, x, x_size):
|
899 |
+
"""
|
900 |
+
Args:
|
901 |
+
x: input features with shape of (B, L, C)
|
902 |
+
x_size: use x_size to determine whether the relative positional bias table and index
|
903 |
+
need to be regenerated.
|
904 |
+
"""
|
905 |
+
B, L, C = x.shape
|
906 |
+
|
907 |
+
# qkv projection
|
908 |
+
qkv = self.qkv(x, x_size)
|
909 |
+
qkv_window, qkv_stripe = torch.split(qkv, C * 3 // 2, dim=-1)
|
910 |
+
# anchor projection
|
911 |
+
if self.use_anchor:
|
912 |
+
anchor = self.anchor(x, x_size)
|
913 |
+
|
914 |
+
# attention
|
915 |
+
x_window = self.window_attn(qkv_window, x_size)
|
916 |
+
if self.use_anchor:
|
917 |
+
x_stripe = self.stripe_attn(qkv_stripe, anchor, x_size)
|
918 |
+
else:
|
919 |
+
x_stripe = self.stripe_attn(qkv_stripe, x_size)
|
920 |
+
x = torch.cat([x_window, x_stripe], dim=-1)
|
921 |
+
|
922 |
+
# output projection
|
923 |
+
if self.args.out_proj_type == "linear":
|
924 |
+
x = self.proj(x)
|
925 |
+
else:
|
926 |
+
x = blc_to_bchw(x, x_size)
|
927 |
+
x = bchw_to_blc(self.proj(x))
|
928 |
+
x = self.proj_drop(x)
|
929 |
+
return x
|
930 |
+
|
931 |
+
def extra_repr(self) -> str:
|
932 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}"
|
933 |
+
|
934 |
+
def flops(self, N):
|
935 |
+
# calculate flops for 1 window with token length of N
|
936 |
+
flops = 0
|
937 |
+
# qkv = self.qkv(x)
|
938 |
+
flops += N * self.dim * 3 * self.dim
|
939 |
+
# attn = (q @ k.transpose(-2, -1))
|
940 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
941 |
+
# x = (attn @ v)
|
942 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
943 |
+
# x = self.proj(x)
|
944 |
+
flops += N * self.dim * self.dim
|
945 |
+
return flops
|
946 |
+
|
947 |
+
|
948 |
+
class ChannelAttention(nn.Module):
|
949 |
+
"""Channel attention used in RCAN.
|
950 |
+
Args:
|
951 |
+
num_feat (int): Channel number of intermediate features.
|
952 |
+
reduction (int): Channel reduction factor. Default: 16.
|
953 |
+
"""
|
954 |
+
|
955 |
+
def __init__(self, num_feat, reduction=16):
|
956 |
+
super(ChannelAttention, self).__init__()
|
957 |
+
self.attention = nn.Sequential(
|
958 |
+
nn.AdaptiveAvgPool2d(1),
|
959 |
+
nn.Conv2d(num_feat, num_feat // reduction, 1, padding=0),
|
960 |
+
nn.ReLU(inplace=True),
|
961 |
+
nn.Conv2d(num_feat // reduction, num_feat, 1, padding=0),
|
962 |
+
nn.Sigmoid(),
|
963 |
+
)
|
964 |
+
|
965 |
+
def forward(self, x):
|
966 |
+
y = self.attention(x)
|
967 |
+
return x * y
|
968 |
+
|
969 |
+
|
970 |
+
class CAB(nn.Module):
|
971 |
+
def __init__(self, num_feat, compress_ratio=4, reduction=18):
|
972 |
+
super(CAB, self).__init__()
|
973 |
+
|
974 |
+
self.cab = nn.Sequential(
|
975 |
+
nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
|
976 |
+
nn.GELU(),
|
977 |
+
nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
|
978 |
+
ChannelAttention(num_feat, reduction),
|
979 |
+
)
|
980 |
+
|
981 |
+
def forward(self, x, x_size):
|
982 |
+
x = self.cab(blc_to_bchw(x, x_size).contiguous())
|
983 |
+
return bchw_to_blc(x)
|
984 |
+
|
985 |
+
|
986 |
+
class MixAttnTransformerBlock(nn.Module):
|
987 |
+
r"""Mix attention transformer block with shared QKV projection and output projection for mixed attention modules.
|
988 |
+
Args:
|
989 |
+
dim (int): Number of input channels.
|
990 |
+
input_resolution (tuple[int]): Input resolution.
|
991 |
+
num_heads (int): Number of attention heads.
|
992 |
+
window_size (int): Window size.
|
993 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
994 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
995 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
996 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
997 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
998 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
999 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
1000 |
+
pretrained_stripe_size (int): Stripe size in pre-training.
|
1001 |
+
attn_type (str, optional): Attention type. Default: cwhv.
|
1002 |
+
c: residual blocks
|
1003 |
+
w: window attention
|
1004 |
+
h: horizontal stripe attention
|
1005 |
+
v: vertical stripe attention
|
1006 |
+
"""
|
1007 |
+
|
1008 |
+
def __init__(
|
1009 |
+
self,
|
1010 |
+
dim,
|
1011 |
+
input_resolution,
|
1012 |
+
num_heads_w,
|
1013 |
+
num_heads_s,
|
1014 |
+
window_size=7,
|
1015 |
+
window_shift=False,
|
1016 |
+
stripe_size=[8, 8],
|
1017 |
+
stripe_groups=[None, None],
|
1018 |
+
stripe_shift=False,
|
1019 |
+
stripe_type="H",
|
1020 |
+
mlp_ratio=4.0,
|
1021 |
+
qkv_bias=True,
|
1022 |
+
qkv_proj_type="linear",
|
1023 |
+
anchor_proj_type="separable_conv",
|
1024 |
+
anchor_one_stage=True,
|
1025 |
+
anchor_window_down_factor=1,
|
1026 |
+
drop=0.0,
|
1027 |
+
attn_drop=0.0,
|
1028 |
+
drop_path=0.0,
|
1029 |
+
act_layer=nn.GELU,
|
1030 |
+
norm_layer=nn.LayerNorm,
|
1031 |
+
pretrained_window_size=[0, 0],
|
1032 |
+
pretrained_stripe_size=[0, 0],
|
1033 |
+
res_scale=1.0,
|
1034 |
+
args=None,
|
1035 |
+
):
|
1036 |
+
super().__init__()
|
1037 |
+
self.dim = dim
|
1038 |
+
self.input_resolution = input_resolution
|
1039 |
+
self.num_heads_w = num_heads_w
|
1040 |
+
self.num_heads_s = num_heads_s
|
1041 |
+
self.window_size = window_size
|
1042 |
+
self.window_shift = window_shift
|
1043 |
+
self.stripe_shift = stripe_shift
|
1044 |
+
self.stripe_type = stripe_type
|
1045 |
+
self.args = args
|
1046 |
+
if self.stripe_type == "W":
|
1047 |
+
self.stripe_size = stripe_size[::-1]
|
1048 |
+
self.stripe_groups = stripe_groups[::-1]
|
1049 |
+
else:
|
1050 |
+
self.stripe_size = stripe_size
|
1051 |
+
self.stripe_groups = stripe_groups
|
1052 |
+
self.mlp_ratio = mlp_ratio
|
1053 |
+
self.res_scale = res_scale
|
1054 |
+
|
1055 |
+
self.attn = MixedAttention(
|
1056 |
+
dim,
|
1057 |
+
input_resolution,
|
1058 |
+
num_heads_w,
|
1059 |
+
num_heads_s,
|
1060 |
+
window_size,
|
1061 |
+
window_shift,
|
1062 |
+
self.stripe_size,
|
1063 |
+
self.stripe_groups,
|
1064 |
+
stripe_shift,
|
1065 |
+
qkv_bias,
|
1066 |
+
qkv_proj_type,
|
1067 |
+
anchor_proj_type,
|
1068 |
+
anchor_one_stage,
|
1069 |
+
anchor_window_down_factor,
|
1070 |
+
attn_drop,
|
1071 |
+
drop,
|
1072 |
+
pretrained_window_size,
|
1073 |
+
pretrained_stripe_size,
|
1074 |
+
args,
|
1075 |
+
)
|
1076 |
+
self.norm1 = norm_layer(dim)
|
1077 |
+
if self.args.local_connection:
|
1078 |
+
self.conv = CAB(dim)
|
1079 |
+
|
1080 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
1081 |
+
|
1082 |
+
self.mlp = Mlp(
|
1083 |
+
in_features=dim,
|
1084 |
+
hidden_features=int(dim * mlp_ratio),
|
1085 |
+
act_layer=act_layer,
|
1086 |
+
drop=drop,
|
1087 |
+
)
|
1088 |
+
self.norm2 = norm_layer(dim)
|
1089 |
+
|
1090 |
+
def forward(self, x, x_size):
|
1091 |
+
# Mixed attention
|
1092 |
+
if self.args.local_connection:
|
1093 |
+
x = (
|
1094 |
+
x
|
1095 |
+
+ self.res_scale * self.drop_path(self.norm1(self.attn(x, x_size)))
|
1096 |
+
+ self.conv(x, x_size)
|
1097 |
+
)
|
1098 |
+
else:
|
1099 |
+
x = x + self.res_scale * self.drop_path(self.norm1(self.attn(x, x_size)))
|
1100 |
+
# FFN
|
1101 |
+
x = x + self.res_scale * self.drop_path(self.norm2(self.mlp(x)))
|
1102 |
+
|
1103 |
+
# return x
|
1104 |
+
|
1105 |
+
def extra_repr(self) -> str:
|
1106 |
+
return (
|
1107 |
+
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads=({self.num_heads_w}, {self.num_heads_s}), "
|
1108 |
+
f"window_size={self.window_size}, window_shift={self.window_shift}, "
|
1109 |
+
f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, self.stripe_type={self.stripe_type}, "
|
1110 |
+
f"mlp_ratio={self.mlp_ratio}, res_scale={self.res_scale}"
|
1111 |
+
)
|
1112 |
+
|
1113 |
+
|
1114 |
+
# def flops(self):
|
1115 |
+
# flops = 0
|
1116 |
+
# H, W = self.input_resolution
|
1117 |
+
# # norm1
|
1118 |
+
# flops += self.dim * H * W
|
1119 |
+
# # W-MSA/SW-MSA
|
1120 |
+
# nW = H * W / self.stripe_size[0] / self.stripe_size[1]
|
1121 |
+
# flops += nW * self.attn.flops(self.stripe_size[0] * self.stripe_size[1])
|
1122 |
+
# # mlp
|
1123 |
+
# flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
1124 |
+
# # norm2
|
1125 |
+
# flops += self.dim * H * W
|
1126 |
+
# return flops
|
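The residual update in the forward pass above follows the usual res-scaled, pre-output-norm transformer pattern: y = x + res_scale * drop_path(norm(attn(x))), plus an optional local convolution branch. The following is a minimal sketch of that update only, with an nn.Linear standing in for the attention module and nn.Identity for DropPath; both stand-ins and the dummy shapes are assumptions for illustration, not the module as committed.

import torch
import torch.nn as nn

dim, res_scale = 64, 0.1
norm = nn.LayerNorm(dim)
attn = nn.Linear(dim, dim)       # stand-in for the mixed attention module
drop_path = nn.Identity()        # DropPath with rate 0.0

x = torch.randn(1, 48 * 48, dim)                  # tokens: B, L, C
y = x + res_scale * drop_path(norm(attn(x)))      # res-scaled residual branch
print(y.shape)                                    # torch.Size([1, 2304, 64])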
architecture/grl_common/mixed_attn_block_efficient.py
ADDED
@@ -0,0 +1,568 @@
import math
from abc import ABC
from math import prod

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath


from architecture.grl_common.mixed_attn_block import (
    AnchorProjection,
    CAB,
    CPB_MLP,
    QKVProjection,
)
from architecture.grl_common.ops import (
    window_partition,
    window_reverse,
)
from architecture.grl_common.swin_v1_block import Mlp


class AffineTransform(nn.Module):
    r"""Affine transformation of the attention map.
    The window could be a square window or a stripe window. Supports attention between different window sizes.
    """

    def __init__(self, num_heads):
        super(AffineTransform, self).__init__()
        logit_scale = torch.log(10 * torch.ones((num_heads, 1, 1)))
        self.logit_scale = nn.Parameter(logit_scale, requires_grad=True)

        # mlp to generate continuous relative position bias
        self.cpb_mlp = CPB_MLP(2, num_heads)

    def forward(self, attn, relative_coords_table, relative_position_index, mask):
        B_, H, N1, N2 = attn.shape
        # logit scale
        attn = attn * torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()

        bias_table = self.cpb_mlp(relative_coords_table)  # 2*Wh-1, 2*Ww-1, num_heads
        bias_table = bias_table.view(-1, H)

        bias = bias_table[relative_position_index.view(-1)]
        bias = bias.view(N1, N2, -1).permute(2, 0, 1).contiguous()
        # nH, Wh*Ww, Wh*Ww
        bias = 16 * torch.sigmoid(bias)
        attn = attn + bias.unsqueeze(0)

        # W-MSA/SW-MSA
        # shift attention mask
        if mask is not None:
            nW = mask.shape[0]
            mask = mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(B_ // nW, nW, H, N1, N2) + mask
            attn = attn.view(-1, H, N1, N2)

        return attn


def _get_stripe_info(stripe_size_in, stripe_groups_in, stripe_shift, input_resolution):
    stripe_size, shift_size = [], []
    for s, g, d in zip(stripe_size_in, stripe_groups_in, input_resolution):
        if g is None:
            stripe_size.append(s)
            shift_size.append(s // 2 if stripe_shift else 0)
        else:
            stripe_size.append(d // g)
            shift_size.append(0 if g == 1 else d // (g * 2))
    return stripe_size, shift_size


class Attention(ABC, nn.Module):
    def __init__(self):
        super(Attention, self).__init__()

    def attn(self, q, k, v, attn_transform, table, index, mask, reshape=True):
        # q, k, v: # nW*B, H, wh*ww, dim
        # cosine attention map
        B_, _, H, head_dim = q.shape
        if self.euclidean_dist:
            # print("use euclidean distance")
            attn = torch.norm(q.unsqueeze(-2) - k.unsqueeze(-3), dim=-1)
        else:
            attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        attn = attn_transform(attn, table, index, mask)
        # attention
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = attn @ v  # B_, H, N1, head_dim
        if reshape:
            x = x.transpose(1, 2).reshape(B_, -1, H * head_dim)
        # B_, N, C
        return x


class WindowAttention(Attention):
    r"""Window attention. QKV is the input to the forward method.
    Args:
        num_heads (int): Number of attention heads.
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
    """

    def __init__(
        self,
        input_resolution,
        window_size,
        num_heads,
        window_shift=False,
        attn_drop=0.0,
        pretrained_window_size=[0, 0],
        args=None,
    ):

        super(WindowAttention, self).__init__()
        self.input_resolution = input_resolution
        self.window_size = window_size
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads
        self.shift_size = window_size[0] // 2 if window_shift else 0
        self.euclidean_dist = args.euclidean_dist

        self.attn_transform = AffineTransform(num_heads)
        self.attn_drop = nn.Dropout(attn_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, qkv, x_size, table, index, mask):
        """
        Args:
            qkv: input QKV features with shape of (B, L, 3C)
            x_size: use x_size to determine whether the relative positional bias table and index
                need to be regenerated.
        """
        H, W = x_size
        B, L, C = qkv.shape
        qkv = qkv.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            qkv = torch.roll(
                qkv, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
            )

        # partition windows
        qkv = window_partition(qkv, self.window_size)  # nW*B, wh, ww, C
        qkv = qkv.view(-1, prod(self.window_size), C)  # nW*B, wh*ww, C

        B_, N, _ = qkv.shape
        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # nW*B, H, wh*ww, dim

        # attention
        x = self.attn(q, k, v, self.attn_transform, table, index, mask)

        # merge windows
        x = x.view(-1, *self.window_size, C // 3)
        x = window_reverse(x, self.window_size, x_size)  # B, H, W, C/3

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        x = x.view(B, L, C // 3)

        return x

    def extra_repr(self) -> str:
        return (
            f"window_size={self.window_size}, shift_size={self.shift_size}, "
            f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}"
        )

    def flops(self, N):
        pass


class AnchorStripeAttention(Attention):
    r"""Stripe attention
    Args:
        stripe_size (tuple[int]): The height and width of the stripe.
        num_heads (int): Number of attention heads.
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training.
    """

    def __init__(
        self,
        input_resolution,
        stripe_size,
        stripe_groups,
        stripe_shift,
        num_heads,
        attn_drop=0.0,
        pretrained_stripe_size=[0, 0],
        anchor_window_down_factor=1,
        args=None,
    ):

        super(AnchorStripeAttention, self).__init__()
        self.input_resolution = input_resolution
        self.stripe_size = stripe_size  # Wh, Ww
        self.stripe_groups = stripe_groups
        self.stripe_shift = stripe_shift
        self.num_heads = num_heads
        self.pretrained_stripe_size = pretrained_stripe_size
        self.anchor_window_down_factor = anchor_window_down_factor
        self.euclidean_dist = args.euclidean_dist

        self.attn_transform1 = AffineTransform(num_heads)
        self.attn_transform2 = AffineTransform(num_heads)

        self.attn_drop = nn.Dropout(attn_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(
        self, qkv, anchor, x_size, table, index_a2w, index_w2a, mask_a2w, mask_w2a
    ):
        """
        Args:
            qkv: input features with shape of (B, L, C)
            anchor:
            x_size: use stripe_size to determine whether the relative positional bias table and index
                need to be regenerated.
        """
        H, W = x_size
        B, L, C = qkv.shape
        qkv = qkv.view(B, H, W, C)

        stripe_size, shift_size = _get_stripe_info(
            self.stripe_size, self.stripe_groups, self.stripe_shift, x_size
        )
        anchor_stripe_size = [s // self.anchor_window_down_factor for s in stripe_size]
        anchor_shift_size = [s // self.anchor_window_down_factor for s in shift_size]
        # cyclic shift
        if self.stripe_shift:
            qkv = torch.roll(qkv, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
            anchor = torch.roll(
                anchor,
                shifts=(-anchor_shift_size[0], -anchor_shift_size[1]),
                dims=(1, 2),
            )

        # partition windows
        qkv = window_partition(qkv, stripe_size)  # nW*B, wh, ww, C
        qkv = qkv.view(-1, prod(stripe_size), C)  # nW*B, wh*ww, C
        anchor = window_partition(anchor, anchor_stripe_size)
        anchor = anchor.view(-1, prod(anchor_stripe_size), C // 3)

        B_, N1, _ = qkv.shape
        N2 = anchor.shape[1]
        qkv = qkv.reshape(B_, N1, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        anchor = anchor.reshape(B_, N2, self.num_heads, -1).permute(0, 2, 1, 3)

        # attention
        x = self.attn(
            anchor, k, v, self.attn_transform1, table, index_a2w, mask_a2w, False
        )
        x = self.attn(q, anchor, x, self.attn_transform2, table, index_w2a, mask_w2a)

        # merge windows
        x = x.view(B_, *stripe_size, C // 3)
        x = window_reverse(x, stripe_size, x_size)  # B H' W' C

        # reverse the shift
        if self.stripe_shift:
            x = torch.roll(x, shifts=shift_size, dims=(1, 2))

        x = x.view(B, H * W, C // 3)
        return x

    def extra_repr(self) -> str:
        return (
            f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, "
            f"pretrained_stripe_size={self.pretrained_stripe_size}, num_heads={self.num_heads}, anchor_window_down_factor={self.anchor_window_down_factor}"
        )

    def flops(self, N):
        pass


class MixedAttention(nn.Module):
    r"""Mixed window attention and stripe attention
    Args:
        dim (int): Number of input channels.
        stripe_size (tuple[int]): The height and width of the stripe.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training.
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads_w,
        num_heads_s,
        window_size,
        window_shift,
        stripe_size,
        stripe_groups,
        stripe_shift,
        qkv_bias=True,
        qkv_proj_type="linear",
        anchor_proj_type="separable_conv",
        anchor_one_stage=True,
        anchor_window_down_factor=1,
        attn_drop=0.0,
        proj_drop=0.0,
        pretrained_window_size=[0, 0],
        pretrained_stripe_size=[0, 0],
        args=None,
    ):

        super(MixedAttention, self).__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.args = args
        # print(args)
        self.qkv = QKVProjection(dim, qkv_bias, qkv_proj_type, args)
        # anchor is only used for stripe attention
        self.anchor = AnchorProjection(
            dim, anchor_proj_type, anchor_one_stage, anchor_window_down_factor, args
        )

        self.window_attn = WindowAttention(
            input_resolution,
            window_size,
            num_heads_w,
            window_shift,
            attn_drop,
            pretrained_window_size,
            args,
        )
        self.stripe_attn = AnchorStripeAttention(
            input_resolution,
            stripe_size,
            stripe_groups,
            stripe_shift,
            num_heads_s,
            attn_drop,
            pretrained_stripe_size,
            anchor_window_down_factor,
            args,
        )
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, x_size, table_index_mask):
        """
        Args:
            x: input features with shape of (B, L, C)
            stripe_size: use stripe_size to determine whether the relative positional bias table and index
                need to be regenerated.
        """
        B, L, C = x.shape

        # qkv projection
        qkv = self.qkv(x, x_size)
        qkv_window, qkv_stripe = torch.split(qkv, C * 3 // 2, dim=-1)
        # anchor projection
        anchor = self.anchor(x, x_size)

        # attention
        x_window = self.window_attn(
            qkv_window, x_size, *self._get_table_index_mask(table_index_mask, True)
        )
        x_stripe = self.stripe_attn(
            qkv_stripe,
            anchor,
            x_size,
            *self._get_table_index_mask(table_index_mask, False),
        )
        x = torch.cat([x_window, x_stripe], dim=-1)

        # output projection
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _get_table_index_mask(self, table_index_mask, window_attn=True):
        if window_attn:
            return (
                table_index_mask["table_w"],
                table_index_mask["index_w"],
                table_index_mask["mask_w"],
            )
        else:
            return (
                table_index_mask["table_s"],
                table_index_mask["index_a2w"],
                table_index_mask["index_w2a"],
                table_index_mask["mask_a2w"],
                table_index_mask["mask_w2a"],
            )

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}"

    def flops(self, N):
        pass


class EfficientMixAttnTransformerBlock(nn.Module):
    r"""Mix attention transformer block with shared QKV projection and output projection for mixed attention modules.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        pretrained_stripe_size (int): Window size in pre-training.
        attn_type (str, optional): Attention type. Default: cwhv.
            c: residual blocks
            w: window attention
            h: horizontal stripe attention
            v: vertical stripe attention
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads_w,
        num_heads_s,
        window_size=7,
        window_shift=False,
        stripe_size=[8, 8],
        stripe_groups=[None, None],
        stripe_shift=False,
        stripe_type="H",
        mlp_ratio=4.0,
        qkv_bias=True,
        qkv_proj_type="linear",
        anchor_proj_type="separable_conv",
        anchor_one_stage=True,
        anchor_window_down_factor=1,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        pretrained_window_size=[0, 0],
        pretrained_stripe_size=[0, 0],
        res_scale=1.0,
        args=None,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads_w = num_heads_w
        self.num_heads_s = num_heads_s
        self.window_size = window_size
        self.window_shift = window_shift
        self.stripe_shift = stripe_shift
        self.stripe_type = stripe_type
        self.args = args
        if self.stripe_type == "W":
            self.stripe_size = stripe_size[::-1]
            self.stripe_groups = stripe_groups[::-1]
        else:
            self.stripe_size = stripe_size
            self.stripe_groups = stripe_groups
        self.mlp_ratio = mlp_ratio
        self.res_scale = res_scale

        self.attn = MixedAttention(
            dim,
            input_resolution,
            num_heads_w,
            num_heads_s,
            window_size,
            window_shift,
            self.stripe_size,
            self.stripe_groups,
            stripe_shift,
            qkv_bias,
            qkv_proj_type,
            anchor_proj_type,
            anchor_one_stage,
            anchor_window_down_factor,
            attn_drop,
            drop,
            pretrained_window_size,
            pretrained_stripe_size,
            args,
        )
        self.norm1 = norm_layer(dim)
        if self.args.local_connection:
            self.conv = CAB(dim)

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.norm2 = norm_layer(dim)

    def _get_table_index_mask(self, all_table_index_mask):
        table_index_mask = {
            "table_w": all_table_index_mask["table_w"],
            "index_w": all_table_index_mask["index_w"],
        }
        if self.stripe_type == "W":
            table_index_mask["table_s"] = all_table_index_mask["table_sv"]
            table_index_mask["index_a2w"] = all_table_index_mask["index_sv_a2w"]
            table_index_mask["index_w2a"] = all_table_index_mask["index_sv_w2a"]
        else:
            table_index_mask["table_s"] = all_table_index_mask["table_sh"]
            table_index_mask["index_a2w"] = all_table_index_mask["index_sh_a2w"]
            table_index_mask["index_w2a"] = all_table_index_mask["index_sh_w2a"]
        if self.window_shift:
            table_index_mask["mask_w"] = all_table_index_mask["mask_w"]
        else:
            table_index_mask["mask_w"] = None
        if self.stripe_shift:
            if self.stripe_type == "W":
                table_index_mask["mask_a2w"] = all_table_index_mask["mask_sv_a2w"]
                table_index_mask["mask_w2a"] = all_table_index_mask["mask_sv_w2a"]
            else:
                table_index_mask["mask_a2w"] = all_table_index_mask["mask_sh_a2w"]
                table_index_mask["mask_w2a"] = all_table_index_mask["mask_sh_w2a"]
        else:
            table_index_mask["mask_a2w"] = None
            table_index_mask["mask_w2a"] = None
        return table_index_mask

    def forward(self, x, x_size, all_table_index_mask):
        # Mixed attention
        table_index_mask = self._get_table_index_mask(all_table_index_mask)
        if self.args.local_connection:
            x = (
                x
                + self.res_scale
                * self.drop_path(self.norm1(self.attn(x, x_size, table_index_mask)))
                + self.conv(x, x_size)
            )
        else:
            x = x + self.res_scale * self.drop_path(
                self.norm1(self.attn(x, x_size, table_index_mask))
            )
        # FFN
        x = x + self.res_scale * self.drop_path(self.norm2(self.mlp(x)))

        return x

    def extra_repr(self) -> str:
        return (
            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads=({self.num_heads_w}, {self.num_heads_s}), "
            f"window_size={self.window_size}, window_shift={self.window_shift}, "
            f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, self.stripe_type={self.stripe_type}, "
            f"mlp_ratio={self.mlp_ratio}, res_scale={self.res_scale}"
        )

    def flops(self):
        pass
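The stripe attention above resolves its window geometry through _get_stripe_info: a fixed stripe size is used when the group entry is None, otherwise the stripe size is derived from the input resolution and the number of groups. A minimal sketch of that behavior, using made-up stripe settings and a made-up resolution (both are assumptions for illustration):

from architecture.grl_common.mixed_attn_block_efficient import _get_stripe_info

# one fixed-size axis (group None) and one group-derived axis
stripe_size, shift_size = _get_stripe_info(
    stripe_size_in=[8, 8],
    stripe_groups_in=[None, 4],
    stripe_shift=True,
    input_resolution=(64, 96),
)
print(stripe_size, shift_size)  # [8, 24] [4, 12]: 96 // 4 = 24, shifts are half a stripe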
architecture/grl_common/ops.py
ADDED
@@ -0,0 +1,551 @@
from math import prod
from typing import Tuple

import numpy as np
import torch
from timm.models.layers import to_2tuple


def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor:
    """Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C)."""
    return x.permute(0, 2, 3, 1)


def bhwc_to_bchw(x: torch.Tensor) -> torch.Tensor:
    """Permutes a tensor from the shape (B, H, W, C) to (B, C, H, W)."""
    return x.permute(0, 3, 1, 2)


def bchw_to_blc(x: torch.Tensor) -> torch.Tensor:
    """Rearrange a tensor from the shape (B, C, H, W) to (B, L, C)."""
    return x.flatten(2).transpose(1, 2)


def blc_to_bchw(x: torch.Tensor, x_size: Tuple) -> torch.Tensor:
    """Rearrange a tensor from the shape (B, L, C) to (B, C, H, W)."""
    B, L, C = x.shape
    return x.transpose(1, 2).view(B, C, *x_size)


def blc_to_bhwc(x: torch.Tensor, x_size: Tuple) -> torch.Tensor:
    """Rearrange a tensor from the shape (B, L, C) to (B, H, W, C)."""
    B, L, C = x.shape
    return x.view(B, *x_size, C)


def window_partition(x, window_size: Tuple[int, int]):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(
        B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C
    )
    windows = (
        x.permute(0, 1, 3, 2, 4, 5)
        .contiguous()
        .view(-1, window_size[0], window_size[1], C)
    )
    return windows


def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
    """
    Args:
        windows: (num_windows * B, window_size[0], window_size[1], C)
        window_size (Tuple[int, int]): Window size
        img_size (Tuple[int, int]): Image size

    Returns:
        x: (B, H, W, C)
    """
    H, W = img_size
    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
    x = windows.view(
        B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


def _fill_window(input_resolution, window_size, shift_size=None):
    if shift_size is None:
        shift_size = [s // 2 for s in window_size]

    img_mask = torch.zeros((1, *input_resolution, 1))  # 1 H W 1
    h_slices = (
        slice(0, -window_size[0]),
        slice(-window_size[0], -shift_size[0]),
        slice(-shift_size[0], None),
    )
    w_slices = (
        slice(0, -window_size[1]),
        slice(-window_size[1], -shift_size[1]),
        slice(-shift_size[1], None),
    )
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = window_partition(img_mask, window_size)
    # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, prod(window_size))
    return mask_windows


#####################################
# Different versions of the functions
# 1) Swin Transformer, SwinIR, Square window attention in GRL;
# 2) Early development of the decomposition-based efficient attention mechanism (efficient_win_attn.py);
# 3) GRL. Window-anchor attention mechanism.
# 1) & 3) are still useful
#####################################


def calculate_mask(input_resolution, window_size, shift_size):
    """
    Use case: 1)
    """
    # calculate attention mask for SW-MSA
    if isinstance(shift_size, int):
        shift_size = to_2tuple(shift_size)
    mask_windows = _fill_window(input_resolution, window_size, shift_size)

    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )  # nW, window_size**2, window_size**2

    return attn_mask


def calculate_mask_all(
    input_resolution,
    window_size,
    shift_size,
    anchor_window_down_factor=1,
    window_to_anchor=True,
):
    """
    Use case: 3)
    """
    # calculate attention mask for SW-MSA
    anchor_resolution = [s // anchor_window_down_factor for s in input_resolution]
    aws = [s // anchor_window_down_factor for s in window_size]
    anchor_shift = [s // anchor_window_down_factor for s in shift_size]

    # mask of window1: nW, Wh**Ww
    mask_windows = _fill_window(input_resolution, window_size, shift_size)
    # mask of window2: nW, AWh*AWw
    mask_anchor = _fill_window(anchor_resolution, aws, anchor_shift)

    if window_to_anchor:
        attn_mask = mask_windows.unsqueeze(2) - mask_anchor.unsqueeze(1)
    else:
        attn_mask = mask_anchor.unsqueeze(2) - mask_windows.unsqueeze(1)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )  # nW, Wh**Ww, AWh*AWw

    return attn_mask


def calculate_win_mask(
    input_resolution1, input_resolution2, window_size1, window_size2
):
    """
    Use case: 2)
    """
    # calculate attention mask for SW-MSA

    # mask of window1: nW, Wh**Ww
    mask_windows1 = _fill_window(input_resolution1, window_size1)
    # mask of window2: nW, AWh*AWw
    mask_windows2 = _fill_window(input_resolution2, window_size2)

    attn_mask = mask_windows1.unsqueeze(2) - mask_windows2.unsqueeze(1)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )  # nW, Wh**Ww, AWh*AWw

    return attn_mask


def _get_meshgrid_coords(start_coords, end_coords):
    coord_h = torch.arange(start_coords[0], end_coords[0])
    coord_w = torch.arange(start_coords[1], end_coords[1])
    coords = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij"))  # 2, Wh, Ww
    coords = torch.flatten(coords, 1)  # 2, Wh*Ww
    return coords


def get_relative_coords_table(
    window_size, pretrained_window_size=[0, 0], anchor_window_down_factor=1
):
    """
    Use case: 1)
    """
    # get relative_coords_table
    ws = window_size
    aws = [w // anchor_window_down_factor for w in window_size]
    pws = pretrained_window_size
    paws = [w // anchor_window_down_factor for w in pretrained_window_size]

    ts = [(w1 + w2) // 2 for w1, w2 in zip(ws, aws)]
    pts = [(w1 + w2) // 2 for w1, w2 in zip(pws, paws)]

    # TODO: pretrained window size and pretrained anchor window size is only used here.
    # TODO: Investigate whether it is really important to use this setting when finetuning large window size
    # TODO: based on pretrained weights with small window size.

    coord_h = torch.arange(-(ts[0] - 1), ts[0], dtype=torch.float32)
    coord_w = torch.arange(-(ts[1] - 1), ts[1], dtype=torch.float32)
    table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute(
        1, 2, 0
    )
    table = table.contiguous().unsqueeze(0)  # 1, Wh+AWh-1, Ww+AWw-1, 2
    if pts[0] > 0:
        table[:, :, :, 0] /= pts[0] - 1
        table[:, :, :, 1] /= pts[1] - 1
    else:
        table[:, :, :, 0] /= ts[0] - 1
        table[:, :, :, 1] /= ts[1] - 1
    table *= 8  # normalize to -8, 8
    table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8)
    return table


def get_relative_coords_table_all(
    window_size, pretrained_window_size=[0, 0], anchor_window_down_factor=1
):
    """
    Use case: 3)

    Support all window shapes.
    Args:
        window_size:
        pretrained_window_size:
        anchor_window_down_factor:

    Returns:

    """
    # get relative_coords_table
    ws = window_size
    aws = [w // anchor_window_down_factor for w in window_size]
    pws = pretrained_window_size
    paws = [w // anchor_window_down_factor for w in pretrained_window_size]

    # positive table size: (Ww - 1) - (Ww - AWw) // 2
    ts_p = [w1 - 1 - (w1 - w2) // 2 for w1, w2 in zip(ws, aws)]
    # negative table size: -(AWw - 1) - (Ww - AWw) // 2
    ts_n = [-(w2 - 1) - (w1 - w2) // 2 for w1, w2 in zip(ws, aws)]
    pts = [w1 - 1 - (w1 - w2) // 2 for w1, w2 in zip(pws, paws)]

    # TODO: pretrained window size and pretrained anchor window size is only used here.
    # TODO: Investigate whether it is really important to use this setting when finetuning large window size
    # TODO: based on pretrained weights with small window size.

    coord_h = torch.arange(ts_n[0], ts_p[0] + 1, dtype=torch.float32)
    coord_w = torch.arange(ts_n[1], ts_p[1] + 1, dtype=torch.float32)
    table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute(
        1, 2, 0
    )
    table = table.contiguous().unsqueeze(0)  # 1, Wh+AWh-1, Ww+AWw-1, 2
    if pts[0] > 0:
        table[:, :, :, 0] /= pts[0]
        table[:, :, :, 1] /= pts[1]
    else:
        table[:, :, :, 0] /= ts_p[0]
        table[:, :, :, 1] /= ts_p[1]
    table *= 8  # normalize to -8, 8
    table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8)
    # 1, Wh+AWh-1, Ww+AWw-1, 2
    return table


def coords_diff(coords1, coords2, max_diff):
    # The coordinates starts from (-start_coord[0], -start_coord[1])
    coords = coords1[:, :, None] - coords2[:, None, :]  # 2, Wh*Ww, AWh*AWw
    coords = coords.permute(1, 2, 0).contiguous()  # Wh*Ww, AWh*AWw, 2
    coords[:, :, 0] += max_diff[0] - 1  # shift to start from 0
    coords[:, :, 1] += max_diff[1] - 1
    coords[:, :, 0] *= 2 * max_diff[1] - 1
    idx = coords.sum(-1)  # Wh*Ww, AWh*AWw
    return idx


def get_relative_position_index(
    window_size, anchor_window_down_factor=1, window_to_anchor=True
):
    """
    Use case: 1)
    """
    # get pair-wise relative position index for each token inside the window
    ws = window_size
    aws = [w // anchor_window_down_factor for w in window_size]
    coords_anchor_end = [(w1 + w2) // 2 for w1, w2 in zip(ws, aws)]
    coords_anchor_start = [(w1 - w2) // 2 for w1, w2 in zip(ws, aws)]

    coords = _get_meshgrid_coords((0, 0), window_size)  # 2, Wh*Ww
    coords_anchor = _get_meshgrid_coords(coords_anchor_start, coords_anchor_end)
    # 2, AWh*AWw

    if window_to_anchor:
        idx = coords_diff(coords, coords_anchor, max_diff=coords_anchor_end)
    else:
        idx = coords_diff(coords_anchor, coords, max_diff=coords_anchor_end)
    return idx  # Wh*Ww, AWh*AWw or AWh*AWw, Wh*Ww


def coords_diff_odd(coords1, coords2, start_coord, max_diff):
    # The coordinates starts from (-start_coord[0], -start_coord[1])
    coords = coords1[:, :, None] - coords2[:, None, :]  # 2, Wh*Ww, AWh*AWw
    coords = coords.permute(1, 2, 0).contiguous()  # Wh*Ww, AWh*AWw, 2
    coords[:, :, 0] += start_coord[0]  # shift to start from 0
    coords[:, :, 1] += start_coord[1]
    coords[:, :, 0] *= max_diff
    idx = coords.sum(-1)  # Wh*Ww, AWh*AWw
    return idx


def get_relative_position_index_all(
    window_size, anchor_window_down_factor=1, window_to_anchor=True
):
    """
    Use case: 3)
    Support all window shapes:
        square window - square window
        rectangular window - rectangular window
        window - anchor
        anchor - window
        [8, 8] - [8, 8]
        [4, 86] - [2, 43]
    """
    # get pair-wise relative position index for each token inside the window
    ws = window_size
    aws = [w // anchor_window_down_factor for w in window_size]
    coords_anchor_start = [(w1 - w2) // 2 for w1, w2 in zip(ws, aws)]
    coords_anchor_end = [s + w2 for s, w2 in zip(coords_anchor_start, aws)]

    coords = _get_meshgrid_coords((0, 0), window_size)  # 2, Wh*Ww
    coords_anchor = _get_meshgrid_coords(coords_anchor_start, coords_anchor_end)
    # 2, AWh*AWw

    max_horizontal_diff = aws[1] + ws[1] - 1
    if window_to_anchor:
        offset = [w2 + s - 1 for s, w2 in zip(coords_anchor_start, aws)]
        idx = coords_diff_odd(coords, coords_anchor, offset, max_horizontal_diff)
    else:
        offset = [w1 - s - 1 for s, w1 in zip(coords_anchor_start, ws)]
        idx = coords_diff_odd(coords_anchor, coords, offset, max_horizontal_diff)
    return idx  # Wh*Ww, AWh*AWw or AWh*AWw, Wh*Ww


def get_relative_position_index_simple(
    window_size, anchor_window_down_factor=1, window_to_anchor=True
):
    """
    Use case: 3)
    This is a simplified version of get_relative_position_index_all
    The start coordinate of anchor window is also (0, 0)
    get pair-wise relative position index for each token inside the window
    """
    ws = window_size
    aws = [w // anchor_window_down_factor for w in window_size]

    coords = _get_meshgrid_coords((0, 0), window_size)  # 2, Wh*Ww
    coords_anchor = _get_meshgrid_coords((0, 0), aws)
    # 2, AWh*AWw

    max_horizontal_diff = aws[1] + ws[1] - 1
    if window_to_anchor:
        offset = [w2 - 1 for w2 in aws]
        idx = coords_diff_odd(coords, coords_anchor, offset, max_horizontal_diff)
    else:
        offset = [w1 - 1 for w1 in ws]
        idx = coords_diff_odd(coords_anchor, coords, offset, max_horizontal_diff)
    return idx  # Wh*Ww, AWh*AWw or AWh*AWw, Wh*Ww


# def get_relative_position_index(window_size):
#     # This is a very early version
#     # get pair-wise relative position index for each token inside the window
#     coords = _get_meshgrid_coords(start_coords=(0, 0), end_coords=window_size)

#     coords = coords[:, :, None] - coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
#     coords = coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
#     coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
#     coords[:, :, 1] += window_size[1] - 1
#     coords[:, :, 0] *= 2 * window_size[1] - 1
#     idx = coords.sum(-1)  # Wh*Ww, Wh*Ww
#     return idx


def get_relative_win_position_index(window_size, anchor_window_size):
    """
    Use case: 2)
    """
    # get pair-wise relative position index for each token inside the window
    ws = window_size
    aws = anchor_window_size
    coords_anchor_end = [(w1 + w2) // 2 for w1, w2 in zip(ws, aws)]
    coords_anchor_start = [(w1 - w2) // 2 for w1, w2 in zip(ws, aws)]

    coords = _get_meshgrid_coords((0, 0), window_size)  # 2, Wh*Ww
    coords_anchor = _get_meshgrid_coords(coords_anchor_start, coords_anchor_end)
    # 2, AWh*AWw
    coords = coords[:, :, None] - coords_anchor[:, None, :]  # 2, Wh*Ww, AWh*AWw
    coords = coords.permute(1, 2, 0).contiguous()  # Wh*Ww, AWh*AWw, 2
    coords[:, :, 0] += coords_anchor_end[0] - 1  # shift to start from 0
    coords[:, :, 1] += coords_anchor_end[1] - 1
    coords[:, :, 0] *= 2 * coords_anchor_end[1] - 1
    idx = coords.sum(-1)  # Wh*Ww, AWh*AWw
    return idx


# def get_relative_coords_table(window_size, pretrained_window_size):
#     # This is a very early version
#     # get relative_coords_table
#     ws = window_size
#     pws = pretrained_window_size
#     coord_h = torch.arange(-(ws[0] - 1), ws[0], dtype=torch.float32)
#     coord_w = torch.arange(-(ws[1] - 1), ws[1], dtype=torch.float32)
#     table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing='ij')).permute(1, 2, 0)
#     table = table.contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
#     if pws[0] > 0:
#         table[:, :, :, 0] /= pws[0] - 1
#         table[:, :, :, 1] /= pws[1] - 1
#     else:
#         table[:, :, :, 0] /= ws[0] - 1
#         table[:, :, :, 1] /= ws[1] - 1
#     table *= 8  # normalize to -8, 8
#     table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8)
#     return table


def get_relative_win_coords_table(
    window_size,
    anchor_window_size,
    pretrained_window_size=[0, 0],
    pretrained_anchor_window_size=[0, 0],
):
    """
    Use case: 2)
    """
    # get relative_coords_table
    ws = window_size
    aws = anchor_window_size
    pws = pretrained_window_size
    paws = pretrained_anchor_window_size

    # TODO: pretrained window size and pretrained anchor window size is only used here.
    # TODO: Investigate whether it is really important to use this setting when finetuning large window size
    # TODO: based on pretrained weights with small window size.

    table_size = [(wsi + awsi) // 2 for wsi, awsi in zip(ws, aws)]
    table_size_pretrained = [(pwsi + pawsi) // 2 for pwsi, pawsi in zip(pws, paws)]
    coord_h = torch.arange(-(table_size[0] - 1), table_size[0], dtype=torch.float32)
    coord_w = torch.arange(-(table_size[1] - 1), table_size[1], dtype=torch.float32)
    table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute(
        1, 2, 0
    )
    table = table.contiguous().unsqueeze(0)  # 1, Wh+AWh-1, Ww+AWw-1, 2
    if table_size_pretrained[0] > 0:
        table[:, :, :, 0] /= table_size_pretrained[0] - 1
        table[:, :, :, 1] /= table_size_pretrained[1] - 1
    else:
        table[:, :, :, 0] /= table_size[0] - 1
        table[:, :, :, 1] /= table_size[1] - 1
    table *= 8  # normalize to -8, 8
    table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8)
    return table


if __name__ == "__main__":
    table = get_relative_coords_table_all((4, 86), anchor_window_down_factor=2)
    table = table.view(-1, 2)
    index1 = get_relative_position_index_all((4, 86), 2, False)
    index2 = get_relative_position_index_simple((4, 86), 2, False)
    print(index2)
    index3 = get_relative_position_index_all((4, 86), 2)
    index4 = get_relative_position_index_simple((4, 86), 2)
    print(index4)
    print(
        table.shape,
        index2.shape,
        index2.max(),
        index2.min(),
        index4.shape,
        index4.max(),
        index4.min(),
        torch.allclose(index1, index2),
        torch.allclose(index3, index4),
    )

    table = get_relative_coords_table_all((4, 86), anchor_window_down_factor=1)
    table = table.view(-1, 2)
    index1 = get_relative_position_index_all((4, 86), 1, False)
    index2 = get_relative_position_index_simple((4, 86), 1, False)
    # print(index1)
    index3 = get_relative_position_index_all((4, 86), 1)
    index4 = get_relative_position_index_simple((4, 86), 1)
    # print(index2)
    print(
        table.shape,
        index2.shape,
        index2.max(),
        index2.min(),
        index4.shape,
        index4.max(),
        index4.min(),
        torch.allclose(index1, index2),
        torch.allclose(index3, index4),
    )

    table = get_relative_coords_table_all((8, 8), anchor_window_down_factor=2)
    table = table.view(-1, 2)
    index1 = get_relative_position_index_all((8, 8), 2, False)
    index2 = get_relative_position_index_simple((8, 8), 2, False)
    # print(index1)
    index3 = get_relative_position_index_all((8, 8), 2)
    index4 = get_relative_position_index_simple((8, 8), 2)
    # print(index2)
    print(
        table.shape,
        index2.shape,
        index2.max(),
        index2.min(),
        index4.shape,
        index4.max(),
        index4.min(),
        torch.allclose(index1, index2),
        torch.allclose(index3, index4),
    )

    table = get_relative_coords_table_all((8, 8), anchor_window_down_factor=1)
    table = table.view(-1, 2)
    index1 = get_relative_position_index_all((8, 8), 1, False)
    index2 = get_relative_position_index_simple((8, 8), 1, False)
    # print(index1)
    index3 = get_relative_position_index_all((8, 8), 1)
    index4 = get_relative_position_index_simple((8, 8), 1)
    # print(index2)
    print(
        table.shape,
        index2.shape,
        index2.max(),
        index2.min(),
        index4.shape,
        index4.max(),
        index4.min(),
        torch.allclose(index1, index2),
        torch.allclose(index3, index4),
    )
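A quick illustration of the window helpers above: window_partition and window_reverse are exact inverses of each other, and calculate_mask produces one additive mask per window for shifted-window attention. The input tensor and sizes below are made-up values chosen only so the arithmetic divides evenly.

import torch
from architecture.grl_common.ops import window_partition, window_reverse, calculate_mask

x = torch.randn(2, 16, 16, 32)                 # B, H, W, C
wins = window_partition(x, (8, 8))             # (2 * 4, 8, 8, 32): four 8x8 windows per image
y = window_reverse(wins, (8, 8), (16, 16))     # back to (2, 16, 16, 32)
assert torch.equal(x, y)                       # exact round trip

mask = calculate_mask((16, 16), (8, 8), 4)     # SW-MSA mask, shape (nW, 64, 64)
print(wins.shape, y.shape, mask.shape)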
architecture/grl_common/resblock.py
ADDED
@@ -0,0 +1,61 @@
import torch.nn as nn


class ResBlock(nn.Module):
    """Residual block without BN.

    It has a style of:

    ::

        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feats (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Used to scale the residual before addition.
            Default: 1.0.
    """

    def __init__(self, num_feats=64, res_scale=1.0, bias=True, shortcut=True):
        super().__init__()
        self.res_scale = res_scale
        self.shortcut = shortcut
        self.conv1 = nn.Conv2d(num_feats, num_feats, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(num_feats, num_feats, 3, 1, 1, bias=bias)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """

        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        if self.shortcut:
            return identity + out * self.res_scale
        else:
            return out * self.res_scale


class ResBlockWrapper(ResBlock):
    "Used for transformers"

    def __init__(self, num_feats, bias=True, shortcut=True):
        super(ResBlockWrapper, self).__init__(
            num_feats=num_feats, bias=bias, shortcut=shortcut
        )

    def forward(self, x, x_size):
        H, W = x_size
        B, L, C = x.shape
        x = x.view(B, H, W, C).permute(0, 3, 1, 2)
        x = super(ResBlockWrapper, self).forward(x)
        x = x.flatten(2).permute(0, 2, 1)
        return x
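As a minimal usage sketch: ResBlock operates on B, C, H, W feature maps and preserves their shape, while ResBlockWrapper accepts the transformer token layout B, L, C plus the spatial size and returns tokens of the same shape. The tensor sizes below are arbitrary illustrative values.

import torch
from architecture.grl_common.resblock import ResBlock, ResBlockWrapper

feat = torch.randn(1, 64, 48, 48)                        # B, C, H, W
out = ResBlock(num_feats=64)(feat)                       # shape preserved: (1, 64, 48, 48)

tokens = torch.randn(1, 48 * 48, 64)                     # B, L, C token layout
out_tokens = ResBlockWrapper(num_feats=64)(tokens, (48, 48))
print(out.shape, out_tokens.shape)                       # (1, 64, 48, 48) (1, 2304, 64)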
architecture/grl_common/swin_v1_block.py
ADDED
@@ -0,0 +1,602 @@
1 |
+
from math import prod
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from architecture.grl_common.ops import (
|
6 |
+
bchw_to_blc,
|
7 |
+
blc_to_bchw,
|
8 |
+
calculate_mask,
|
9 |
+
window_partition,
|
10 |
+
window_reverse,
|
11 |
+
)
|
12 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
13 |
+
|
14 |
+
|
15 |
+
class Mlp(nn.Module):
|
16 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
17 |
+
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
in_features,
|
21 |
+
hidden_features=None,
|
22 |
+
out_features=None,
|
23 |
+
act_layer=nn.GELU,
|
24 |
+
drop=0.0,
|
25 |
+
):
|
26 |
+
super().__init__()
|
27 |
+
out_features = out_features or in_features
|
28 |
+
hidden_features = hidden_features or in_features
|
29 |
+
drop_probs = to_2tuple(drop)
|
30 |
+
|
31 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
32 |
+
self.act = act_layer()
|
33 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
34 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
35 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
x = self.fc1(x)
|
39 |
+
x = self.act(x)
|
40 |
+
x = self.drop1(x)
|
41 |
+
x = self.fc2(x)
|
42 |
+
x = self.drop2(x)
|
43 |
+
return x
|
44 |
+
|
45 |
+
|
46 |
+
class WindowAttentionV1(nn.Module):
|
47 |
+
r"""Window based multi-head self attention (W-MSA) module with relative position bias.
|
48 |
+
It supports both of shifted and non-shifted window.
|
49 |
+
Args:
|
50 |
+
dim (int): Number of input channels.
|
51 |
+
window_size (tuple[int]): The height and width of the window.
|
52 |
+
num_heads (int): Number of attention heads.
|
53 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
54 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
55 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
56 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
dim,
|
62 |
+
window_size,
|
63 |
+
num_heads,
|
64 |
+
qkv_bias=True,
|
65 |
+
qk_scale=None,
|
66 |
+
attn_drop=0.0,
|
67 |
+
proj_drop=0.0,
|
68 |
+
use_pe=True,
|
69 |
+
):
|
70 |
+
|
71 |
+
super().__init__()
|
72 |
+
self.dim = dim
|
73 |
+
self.window_size = window_size # Wh, Ww
|
74 |
+
self.num_heads = num_heads
|
75 |
+
head_dim = dim // num_heads
|
76 |
+
self.scale = qk_scale or head_dim**-0.5
|
77 |
+
self.use_pe = use_pe
|
78 |
+
|
79 |
+
if self.use_pe:
|
80 |
+
# define a parameter table of relative position bias
|
81 |
+
ws = self.window_size
|
82 |
+
table = torch.zeros((2 * ws[0] - 1) * (2 * ws[1] - 1), num_heads)
|
83 |
+
self.relative_position_bias_table = nn.Parameter(table)
|
84 |
+
# 2*Wh-1 * 2*Ww-1, nH
|
85 |
+
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
86 |
+
|
87 |
+
self.get_relative_position_index(self.window_size)
|
88 |
+
|
89 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
90 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
91 |
+
self.proj = nn.Linear(dim, dim)
|
92 |
+
|
93 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
94 |
+
|
95 |
+
self.softmax = nn.Softmax(dim=-1)
|
96 |
+
|
97 |
+
def get_relative_position_index(self, window_size):
|
98 |
+
# get pair-wise relative position index for each token inside the window
|
99 |
+
coord_h = torch.arange(window_size[0])
|
100 |
+
coord_w = torch.arange(window_size[1])
|
101 |
+
coords = torch.stack(torch.meshgrid([coord_h, coord_w])) # 2, Wh, Ww
|
102 |
+
coords = torch.flatten(coords, 1) # 2, Wh*Ww
|
103 |
+
coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
|
104 |
+
coords = coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
105 |
+
coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
106 |
+
coords[:, :, 1] += window_size[1] - 1
|
107 |
+
coords[:, :, 0] *= 2 * window_size[1] - 1
|
108 |
+
relative_position_index = coords.sum(-1) # Wh*Ww, Wh*Ww
|
109 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
110 |
+
|
111 |
+
def forward(self, x, mask=None):
|
112 |
+
"""
|
113 |
+
Args:
|
114 |
+
x: input features with shape of (num_windows*B, N, C)
|
115 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
116 |
+
"""
|
117 |
+
B_, N, C = x.shape
|
118 |
+
|
119 |
+
# qkv projection
|
120 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
121 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
122 |
+
|
123 |
+
# attention map
|
124 |
+
q = q * self.scale
|
125 |
+
attn = q @ k.transpose(-2, -1)
|
126 |
+
|
127 |
+
# positional encoding
|
128 |
+
if self.use_pe:
|
129 |
+
win_dim = prod(self.window_size)
|
130 |
+
bias = self.relative_position_bias_table[
|
131 |
+
self.relative_position_index.view(-1)
|
132 |
+
]
|
133 |
+
bias = bias.view(win_dim, win_dim, -1).permute(2, 0, 1).contiguous()
|
134 |
+
# nH, Wh*Ww, Wh*Ww
|
135 |
+
attn = attn + bias.unsqueeze(0)
|
136 |
+
|
137 |
+
# shift attention mask
|
138 |
+
if mask is not None:
|
139 |
+
nW = mask.shape[0]
|
140 |
+
mask = mask.unsqueeze(1).unsqueeze(0)
|
141 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask
|
142 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
143 |
+
|
144 |
+
# attention
|
145 |
+
attn = self.softmax(attn)
|
146 |
+
attn = self.attn_drop(attn)
|
147 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
148 |
+
|
149 |
+
# output projection
|
150 |
+
x = self.proj(x)
|
151 |
+
x = self.proj_drop(x)
|
152 |
+
return x
|
153 |
+
|
154 |
+
def extra_repr(self) -> str:
|
155 |
+
return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
|
156 |
+
|
157 |
+
def flops(self, N):
|
158 |
+
# calculate flops for 1 window with token length of N
|
159 |
+
flops = 0
|
160 |
+
# qkv = self.qkv(x)
|
161 |
+
flops += N * self.dim * 3 * self.dim
|
162 |
+
# attn = (q @ k.transpose(-2, -1))
|
163 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
164 |
+
# x = (attn @ v)
|
165 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
166 |
+
# x = self.proj(x)
|
167 |
+
flops += N * self.dim * self.dim
|
168 |
+
return flops
|
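A standalone sketch (illustrative only, not part of the commit) of how WindowAttentionV1 builds its relative position index: for a toy 2x2 window, the index below selects rows of the (2*Wh-1)*(2*Ww-1) x num_heads bias table, yielding one (Wh*Ww, Wh*Ww) bias map per head. The logic mirrors get_relative_position_index() above.

import torch

window_size = (2, 2)
coord_h = torch.arange(window_size[0])
coord_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coord_h, coord_w]))  # 2, Wh, Ww
coords = torch.flatten(coords, 1)                         # 2, Wh*Ww
coords = coords[:, :, None] - coords[:, None, :]          # pairwise offsets
coords = coords.permute(1, 2, 0).contiguous()
coords[:, :, 0] += window_size[0] - 1                     # shift offsets to start at 0
coords[:, :, 1] += window_size[1] - 1
coords[:, :, 0] *= 2 * window_size[1] - 1                 # row-major index into the bias table
relative_position_index = coords.sum(-1)                  # Wh*Ww, Wh*Ww

print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])
# The diagonal is 4, the center of the 3*3 = 9-entry bias table for a 2x2 window.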
169 |
+
|
170 |
+
|
171 |
+
class WindowAttentionWrapperV1(WindowAttentionV1):
|
172 |
+
def __init__(self, shift_size, input_resolution, **kwargs):
|
173 |
+
super(WindowAttentionWrapperV1, self).__init__(**kwargs)
|
174 |
+
self.shift_size = shift_size
|
175 |
+
self.input_resolution = input_resolution
|
176 |
+
|
177 |
+
if self.shift_size > 0:
|
178 |
+
attn_mask = calculate_mask(input_resolution, self.window_size, shift_size)
|
179 |
+
else:
|
180 |
+
attn_mask = None
|
181 |
+
self.register_buffer("attn_mask", attn_mask)
|
182 |
+
|
183 |
+
def forward(self, x, x_size):
|
184 |
+
H, W = x_size
|
185 |
+
B, L, C = x.shape
|
186 |
+
x = x.view(B, H, W, C)
|
187 |
+
|
188 |
+
# cyclic shift
|
189 |
+
if self.shift_size > 0:
|
190 |
+
x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
191 |
+
|
192 |
+
# partition windows
|
193 |
+
x = window_partition(x, self.window_size) # nW*B, wh, ww, C
|
194 |
+
x = x.view(-1, prod(self.window_size), C) # nW*B, wh*ww, C
|
195 |
+
|
196 |
+
# W-MSA/SW-MSA (recompute the mask when the test image size differs from the training input resolution)
|
197 |
+
if self.input_resolution == x_size:
|
198 |
+
attn_mask = self.attn_mask
|
199 |
+
else:
|
200 |
+
attn_mask = calculate_mask(x_size, self.window_size, self.shift_size)
|
201 |
+
attn_mask = attn_mask.to(x.device)
|
202 |
+
|
203 |
+
# attention
|
204 |
+
x = super(WindowAttentionWrapperV1, self).forward(x, mask=attn_mask)
|
205 |
+
# nW*B, wh*ww, C
|
206 |
+
|
207 |
+
# merge windows
|
208 |
+
x = x.view(-1, *self.window_size, C)
|
209 |
+
x = window_reverse(x, self.window_size, x_size) # B, H, W, C
|
210 |
+
|
211 |
+
# reverse cyclic shift
|
212 |
+
if self.shift_size > 0:
|
213 |
+
x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
214 |
+
x = x.view(B, H * W, C)
|
215 |
+
|
216 |
+
return x
|
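The wrapper above depends on window partitioning and the cyclic shift being exactly reversible, so tokens end up back in their original positions after shifted-window attention. A standalone round-trip sketch; window_partition/window_reverse are restated here with the same view/permute logic that appears verbatim in architecture/swinir.py below (the versions in architecture.grl_common.ops are assumed to behave the same way).

import torch

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

x = torch.randn(2, 8, 8, 16)                      # B, H, W, C
windows = window_partition(x, 4)                  # (8, 4, 4, 16): nW*B windows of 4x4 tokens
assert torch.equal(window_reverse(windows, 4, 8, 8), x)

shifted = torch.roll(x, shifts=(-2, -2), dims=(1, 2))                    # cyclic shift
assert torch.equal(torch.roll(shifted, shifts=(2, 2), dims=(1, 2)), x)   # exactly undone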
217 |
+
|
218 |
+
|
219 |
+
class SwinTransformerBlockV1(nn.Module):
|
220 |
+
r"""Swin Transformer Block.
|
221 |
+
Args:
|
222 |
+
dim (int): Number of input channels.
|
223 |
+
input_resolution (tuple[int]): Input resolution.
|
224 |
+
num_heads (int): Number of attention heads.
|
225 |
+
window_size (int): Window size.
|
226 |
+
shift_size (int): Shift size for SW-MSA.
|
227 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
228 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
229 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
230 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
231 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
232 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
233 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
234 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
235 |
+
"""
|
236 |
+
|
237 |
+
def __init__(
|
238 |
+
self,
|
239 |
+
dim,
|
240 |
+
input_resolution,
|
241 |
+
num_heads,
|
242 |
+
window_size=7,
|
243 |
+
shift_size=0,
|
244 |
+
mlp_ratio=4.0,
|
245 |
+
qkv_bias=True,
|
246 |
+
qk_scale=None,
|
247 |
+
drop=0.0,
|
248 |
+
attn_drop=0.0,
|
249 |
+
drop_path=0.0,
|
250 |
+
act_layer=nn.GELU,
|
251 |
+
norm_layer=nn.LayerNorm,
|
252 |
+
use_pe=True,
|
253 |
+
res_scale=1.0,
|
254 |
+
):
|
255 |
+
super().__init__()
|
256 |
+
self.dim = dim
|
257 |
+
self.input_resolution = input_resolution
|
258 |
+
self.num_heads = num_heads
|
259 |
+
self.window_size = window_size
|
260 |
+
self.shift_size = shift_size
|
261 |
+
self.mlp_ratio = mlp_ratio
|
262 |
+
if min(self.input_resolution) <= self.window_size:
|
263 |
+
# if window size is larger than input resolution, we don't partition windows
|
264 |
+
self.shift_size = 0
|
265 |
+
self.window_size = min(self.input_resolution)
|
266 |
+
assert (
|
267 |
+
0 <= self.shift_size < self.window_size
|
268 |
+
), "shift_size must in 0-window_size"
|
269 |
+
self.res_scale = res_scale
|
270 |
+
|
271 |
+
self.norm1 = norm_layer(dim)
|
272 |
+
self.attn = WindowAttentionWrapperV1(
|
273 |
+
shift_size=self.shift_size,
|
274 |
+
input_resolution=self.input_resolution,
|
275 |
+
dim=dim,
|
276 |
+
window_size=to_2tuple(self.window_size),
|
277 |
+
num_heads=num_heads,
|
278 |
+
qkv_bias=qkv_bias,
|
279 |
+
qk_scale=qk_scale,
|
280 |
+
attn_drop=attn_drop,
|
281 |
+
proj_drop=drop,
|
282 |
+
use_pe=use_pe,
|
283 |
+
)
|
284 |
+
|
285 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
286 |
+
|
287 |
+
self.norm2 = norm_layer(dim)
|
288 |
+
self.mlp = Mlp(
|
289 |
+
in_features=dim,
|
290 |
+
hidden_features=int(dim * mlp_ratio),
|
291 |
+
act_layer=act_layer,
|
292 |
+
drop=drop,
|
293 |
+
)
|
294 |
+
|
295 |
+
def forward(self, x, x_size):
|
296 |
+
# Window attention
|
297 |
+
x = x + self.res_scale * self.drop_path(self.attn(self.norm1(x), x_size))
|
298 |
+
# FFN
|
299 |
+
x = x + self.res_scale * self.drop_path(self.mlp(self.norm2(x)))
|
300 |
+
|
301 |
+
return x
|
302 |
+
|
303 |
+
def extra_repr(self) -> str:
|
304 |
+
return (
|
305 |
+
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
|
306 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}, res_scale={self.res_scale}"
|
307 |
+
)
|
308 |
+
|
309 |
+
def flops(self):
|
310 |
+
flops = 0
|
311 |
+
H, W = self.input_resolution
|
312 |
+
# norm1
|
313 |
+
flops += self.dim * H * W
|
314 |
+
# W-MSA/SW-MSA
|
315 |
+
nW = H * W / self.window_size / self.window_size
|
316 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
317 |
+
# mlp
|
318 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
319 |
+
# norm2
|
320 |
+
flops += self.dim * H * W
|
321 |
+
return flops
|
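A usage sketch for SwinTransformerBlockV1 (an assumption: this repository is on PYTHONPATH and the helpers in architecture.grl_common.ops behave as used above; the dimensions are illustrative). The block consumes flattened tokens of shape (B, H*W, C) together with the spatial size and returns tokens of the same shape.

import torch
from architecture.grl_common.swin_v1_block import SwinTransformerBlockV1

block = SwinTransformerBlockV1(
    dim=96, input_resolution=(64, 64), num_heads=6, window_size=8, shift_size=4
)
x = torch.randn(1, 64 * 64, 96)   # B, H*W, C
out = block(x, x_size=(64, 64))   # shifted-window attention + MLP, both residual
print(out.shape)                  # torch.Size([1, 4096, 96])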
322 |
+
|
323 |
+
|
324 |
+
class PatchMerging(nn.Module):
|
325 |
+
r"""Patch Merging Layer.
|
326 |
+
Args:
|
327 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
328 |
+
dim (int): Number of input channels.
|
329 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
330 |
+
"""
|
331 |
+
|
332 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
333 |
+
super().__init__()
|
334 |
+
self.input_resolution = input_resolution
|
335 |
+
self.dim = dim
|
336 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
337 |
+
self.norm = norm_layer(4 * dim)
|
338 |
+
|
339 |
+
def forward(self, x):
|
340 |
+
"""
|
341 |
+
x: B, H*W, C
|
342 |
+
"""
|
343 |
+
H, W = self.input_resolution
|
344 |
+
B, L, C = x.shape
|
345 |
+
assert L == H * W, "input feature has wrong size"
|
346 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
|
347 |
+
|
348 |
+
x = x.view(B, H, W, C)
|
349 |
+
|
350 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
351 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
352 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
353 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
354 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
355 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
356 |
+
|
357 |
+
x = self.norm(x)
|
358 |
+
x = self.reduction(x)
|
359 |
+
|
360 |
+
return x
|
361 |
+
|
362 |
+
def extra_repr(self) -> str:
|
363 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
364 |
+
|
365 |
+
def flops(self):
|
366 |
+
H, W = self.input_resolution
|
367 |
+
flops = H * W * self.dim
|
368 |
+
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
369 |
+
return flops
|
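A standalone sketch of the shape bookkeeping inside PatchMerging.forward: the four strided sub-grids halve H and W and stack channel-wise (C -> 4C) before the linear reduction to 2C.

import torch

B, H, W, C = 1, 4, 4, 8
x = torch.randn(B, H, W, C)
x0 = x[:, 0::2, 0::2, :]                  # even rows, even cols
x1 = x[:, 1::2, 0::2, :]                  # odd rows, even cols
x2 = x[:, 0::2, 1::2, :]                  # even rows, odd cols
x3 = x[:, 1::2, 1::2, :]                  # odd rows, odd cols
merged = torch.cat([x0, x1, x2, x3], -1)  # B, H/2, W/2, 4C
print(merged.shape)                       # torch.Size([1, 2, 2, 32])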
370 |
+
|
371 |
+
|
372 |
+
class PatchEmbed(nn.Module):
|
373 |
+
r"""Image to Patch Embedding
|
374 |
+
Args:
|
375 |
+
img_size (int): Image size. Default: 224.
|
376 |
+
patch_size (int): Patch token size. Default: 4.
|
377 |
+
in_chans (int): Number of input image channels. Default: 3.
|
378 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
379 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
380 |
+
"""
|
381 |
+
|
382 |
+
def __init__(
|
383 |
+
self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
|
384 |
+
):
|
385 |
+
super().__init__()
|
386 |
+
img_size = to_2tuple(img_size)
|
387 |
+
patch_size = to_2tuple(patch_size)
|
388 |
+
patches_resolution = [
|
389 |
+
img_size[0] // patch_size[0],
|
390 |
+
img_size[1] // patch_size[1],
|
391 |
+
]
|
392 |
+
self.img_size = img_size
|
393 |
+
self.patch_size = patch_size
|
394 |
+
self.patches_resolution = patches_resolution
|
395 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
396 |
+
|
397 |
+
self.in_chans = in_chans
|
398 |
+
self.embed_dim = embed_dim
|
399 |
+
|
400 |
+
if norm_layer is not None:
|
401 |
+
self.norm = norm_layer(embed_dim)
|
402 |
+
else:
|
403 |
+
self.norm = None
|
404 |
+
|
405 |
+
def forward(self, x):
|
406 |
+
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
|
407 |
+
if self.norm is not None:
|
408 |
+
x = self.norm(x)
|
409 |
+
return x
|
410 |
+
|
411 |
+
def flops(self):
|
412 |
+
flops = 0
|
413 |
+
H, W = self.img_size
|
414 |
+
if self.norm is not None:
|
415 |
+
flops += H * W * self.embed_dim
|
416 |
+
return flops
|
417 |
+
|
418 |
+
|
419 |
+
class PatchUnEmbed(nn.Module):
|
420 |
+
r"""Image to Patch Unembedding
|
421 |
+
Args:
|
422 |
+
img_size (int): Image size. Default: 224.
|
423 |
+
patch_size (int): Patch token size. Default: 4.
|
424 |
+
in_chans (int): Number of input image channels. Default: 3.
|
425 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
426 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
427 |
+
"""
|
428 |
+
|
429 |
+
def __init__(
|
430 |
+
self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
|
431 |
+
):
|
432 |
+
super().__init__()
|
433 |
+
img_size = to_2tuple(img_size)
|
434 |
+
patch_size = to_2tuple(patch_size)
|
435 |
+
patches_resolution = [
|
436 |
+
img_size[0] // patch_size[0],
|
437 |
+
img_size[1] // patch_size[1],
|
438 |
+
]
|
439 |
+
self.img_size = img_size
|
440 |
+
self.patch_size = patch_size
|
441 |
+
self.patches_resolution = patches_resolution
|
442 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
443 |
+
|
444 |
+
self.in_chans = in_chans
|
445 |
+
self.embed_dim = embed_dim
|
446 |
+
|
447 |
+
def forward(self, x, x_size):
|
448 |
+
B, HW, C = x.shape
|
449 |
+
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B, C, Ph, Pw
|
450 |
+
return x
|
451 |
+
|
452 |
+
def flops(self):
|
453 |
+
flops = 0
|
454 |
+
return flops
|
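PatchEmbed and PatchUnEmbed are a lossless pair of layout changes between image tensors (B, C, H, W) and token sequences (B, H*W, C). A standalone round-trip sketch of the two forward() methods:

import torch

x = torch.randn(2, 96, 16, 16)                          # B, C, H, W
tokens = x.flatten(2).transpose(1, 2)                   # B, H*W, C   (PatchEmbed.forward)
restored = tokens.transpose(1, 2).view(2, 96, 16, 16)   # PatchUnEmbed.forward
assert torch.equal(x, restored)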
455 |
+
|
456 |
+
|
457 |
+
class Linear(nn.Linear):
|
458 |
+
def __init__(self, in_features, out_features, bias=True):
|
459 |
+
super(Linear, self).__init__(in_features, out_features, bias)
|
460 |
+
|
461 |
+
def forward(self, x):
|
462 |
+
B, C, H, W = x.shape
|
463 |
+
x = bchw_to_blc(x)
|
464 |
+
x = super(Linear, self).forward(x)
|
465 |
+
x = blc_to_bchw(x, (H, W))
|
466 |
+
return x
|
467 |
+
|
468 |
+
|
469 |
+
def build_last_conv(conv_type, dim):
|
470 |
+
if conv_type == "1conv":
|
471 |
+
block = nn.Conv2d(dim, dim, 3, 1, 1)
|
472 |
+
elif conv_type == "3conv":
|
473 |
+
# to save parameters and memory
|
474 |
+
block = nn.Sequential(
|
475 |
+
nn.Conv2d(dim, dim // 4, 3, 1, 1),
|
476 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
477 |
+
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
|
478 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
479 |
+
nn.Conv2d(dim // 4, dim, 3, 1, 1),
|
480 |
+
)
|
481 |
+
elif conv_type == "1conv1x1":
|
482 |
+
block = nn.Conv2d(dim, dim, 1, 1, 0)
|
483 |
+
elif conv_type == "linear":
|
484 |
+
block = Linear(dim, dim)
|
485 |
+
return block
|
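A standalone sketch of what the "to save parameters and memory" comment in build_last_conv means in numbers; dim=180 is only an illustrative width. The bottlenecked "3conv" variant costs roughly half the parameters of a single full 3x3 convolution.

import torch.nn as nn

dim = 180
one_conv = nn.Conv2d(dim, dim, 3, 1, 1)
three_conv = nn.Sequential(
    nn.Conv2d(dim, dim // 4, 3, 1, 1),
    nn.LeakyReLU(negative_slope=0.2, inplace=True),
    nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
    nn.LeakyReLU(negative_slope=0.2, inplace=True),
    nn.Conv2d(dim // 4, dim, 3, 1, 1),
)

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print(n_params(one_conv), n_params(three_conv))   # 291780 vs 148095 parameters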
486 |
+
|
487 |
+
|
488 |
+
# class BasicLayer(nn.Module):
|
489 |
+
# """A basic Swin Transformer layer for one stage.
|
490 |
+
# Args:
|
491 |
+
# dim (int): Number of input channels.
|
492 |
+
# input_resolution (tuple[int]): Input resolution.
|
493 |
+
# depth (int): Number of blocks.
|
494 |
+
# num_heads (int): Number of attention heads.
|
495 |
+
# window_size (int): Local window size.
|
496 |
+
# mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
497 |
+
# qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
498 |
+
# qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
499 |
+
# drop (float, optional): Dropout rate. Default: 0.0
|
500 |
+
# attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
501 |
+
# drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
502 |
+
# norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
503 |
+
# downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
504 |
+
# args: Additional arguments
|
505 |
+
# """
|
506 |
+
|
507 |
+
# def __init__(
|
508 |
+
# self,
|
509 |
+
# dim,
|
510 |
+
# input_resolution,
|
511 |
+
# depth,
|
512 |
+
# num_heads,
|
513 |
+
# window_size,
|
514 |
+
# mlp_ratio=4.0,
|
515 |
+
# qkv_bias=True,
|
516 |
+
# qk_scale=None,
|
517 |
+
# drop=0.0,
|
518 |
+
# attn_drop=0.0,
|
519 |
+
# drop_path=0.0,
|
520 |
+
# norm_layer=nn.LayerNorm,
|
521 |
+
# downsample=None,
|
522 |
+
# args=None,
|
523 |
+
# ):
|
524 |
+
|
525 |
+
# super().__init__()
|
526 |
+
# self.dim = dim
|
527 |
+
# self.input_resolution = input_resolution
|
528 |
+
# self.depth = depth
|
529 |
+
|
530 |
+
# # build blocks
|
531 |
+
# self.blocks = nn.ModuleList(
|
532 |
+
# [
|
533 |
+
# _parse_block(
|
534 |
+
# dim=dim,
|
535 |
+
# input_resolution=input_resolution,
|
536 |
+
# num_heads=num_heads,
|
537 |
+
# window_size=window_size,
|
538 |
+
# shift_size=0
|
539 |
+
# if args.no_shift
|
540 |
+
# else (0 if (i % 2 == 0) else window_size // 2),
|
541 |
+
# mlp_ratio=mlp_ratio,
|
542 |
+
# qkv_bias=qkv_bias,
|
543 |
+
# qk_scale=qk_scale,
|
544 |
+
# drop=drop,
|
545 |
+
# attn_drop=attn_drop,
|
546 |
+
# drop_path=drop_path[i]
|
547 |
+
# if isinstance(drop_path, list)
|
548 |
+
# else drop_path,
|
549 |
+
# norm_layer=norm_layer,
|
550 |
+
# stripe_type="H" if (i % 2 == 0) else "W",
|
551 |
+
# args=args,
|
552 |
+
# )
|
553 |
+
# for i in range(depth)
|
554 |
+
# ]
|
555 |
+
# )
|
556 |
+
# # self.blocks = nn.ModuleList(
|
557 |
+
# # [
|
558 |
+
# # STV1Block(
|
559 |
+
# # dim=dim,
|
560 |
+
# # input_resolution=input_resolution,
|
561 |
+
# # num_heads=num_heads,
|
562 |
+
# # window_size=window_size,
|
563 |
+
# # shift_size=0 if (i % 2 == 0) else window_size // 2,
|
564 |
+
# # mlp_ratio=mlp_ratio,
|
565 |
+
# # qkv_bias=qkv_bias,
|
566 |
+
# # qk_scale=qk_scale,
|
567 |
+
# # drop=drop,
|
568 |
+
# # attn_drop=attn_drop,
|
569 |
+
# # drop_path=drop_path[i]
|
570 |
+
# # if isinstance(drop_path, list)
|
571 |
+
# # else drop_path,
|
572 |
+
# # norm_layer=norm_layer,
|
573 |
+
# # )
|
574 |
+
# # for i in range(depth)
|
575 |
+
# # ]
|
576 |
+
# # )
|
577 |
+
|
578 |
+
# # patch merging layer
|
579 |
+
# if downsample is not None:
|
580 |
+
# self.downsample = downsample(
|
581 |
+
# input_resolution, dim=dim, norm_layer=norm_layer
|
582 |
+
# )
|
583 |
+
# else:
|
584 |
+
# self.downsample = None
|
585 |
+
|
586 |
+
# def forward(self, x, x_size):
|
587 |
+
# for blk in self.blocks:
|
588 |
+
# x = blk(x, x_size)
|
589 |
+
# if self.downsample is not None:
|
590 |
+
# x = self.downsample(x)
|
591 |
+
# return x
|
592 |
+
|
593 |
+
# def extra_repr(self) -> str:
|
594 |
+
# return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
595 |
+
|
596 |
+
# def flops(self):
|
597 |
+
# flops = 0
|
598 |
+
# for blk in self.blocks:
|
599 |
+
# flops += blk.flops()
|
600 |
+
# if self.downsample is not None:
|
601 |
+
# flops += self.downsample.flops()
|
602 |
+
# return flops
|
architecture/grl_common/swin_v2_block.py
ADDED
@@ -0,0 +1,306 @@
1 |
+
import math
|
2 |
+
from math import prod
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from architecture.grl_common.ops import (
|
8 |
+
calculate_mask,
|
9 |
+
get_relative_coords_table,
|
10 |
+
get_relative_position_index,
|
11 |
+
window_partition,
|
12 |
+
window_reverse,
|
13 |
+
)
|
14 |
+
from architecture.grl_common.swin_v1_block import Mlp
|
15 |
+
from timm.models.layers import DropPath, to_2tuple
|
16 |
+
|
17 |
+
|
18 |
+
class WindowAttentionV2(nn.Module):
|
19 |
+
r"""Window based multi-head self attention (W-MSA) module with relative position bias.
|
20 |
+
It supports both shifted and non-shifted windows.
|
21 |
+
Args:
|
22 |
+
dim (int): Number of input channels.
|
23 |
+
window_size (tuple[int]): The height and width of the window.
|
24 |
+
num_heads (int): Number of attention heads.
|
25 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
26 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
27 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
28 |
+
pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
dim,
|
34 |
+
window_size,
|
35 |
+
num_heads,
|
36 |
+
qkv_bias=True,
|
37 |
+
attn_drop=0.0,
|
38 |
+
proj_drop=0.0,
|
39 |
+
pretrained_window_size=[0, 0],
|
40 |
+
use_pe=True,
|
41 |
+
):
|
42 |
+
|
43 |
+
super().__init__()
|
44 |
+
self.dim = dim
|
45 |
+
self.window_size = window_size # Wh, Ww
|
46 |
+
self.pretrained_window_size = pretrained_window_size
|
47 |
+
self.num_heads = num_heads
|
48 |
+
self.use_pe = use_pe
|
49 |
+
|
50 |
+
self.logit_scale = nn.Parameter(
|
51 |
+
torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
|
52 |
+
)
|
53 |
+
|
54 |
+
if self.use_pe:
|
55 |
+
# mlp to generate continuous relative position bias
|
56 |
+
self.cpb_mlp = nn.Sequential(
|
57 |
+
nn.Linear(2, 512, bias=True),
|
58 |
+
nn.ReLU(inplace=True),
|
59 |
+
nn.Linear(512, num_heads, bias=False),
|
60 |
+
)
|
61 |
+
table = get_relative_coords_table(window_size, pretrained_window_size)
|
62 |
+
index = get_relative_position_index(window_size)
|
63 |
+
self.register_buffer("relative_coords_table", table)
|
64 |
+
self.register_buffer("relative_position_index", index)
|
65 |
+
|
66 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
67 |
+
# self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
68 |
+
# if qkv_bias:
|
69 |
+
# self.q_bias = nn.Parameter(torch.zeros(dim))
|
70 |
+
# self.v_bias = nn.Parameter(torch.zeros(dim))
|
71 |
+
# else:
|
72 |
+
# self.q_bias = None
|
73 |
+
# self.v_bias = None
|
74 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
75 |
+
self.proj = nn.Linear(dim, dim)
|
76 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
77 |
+
self.softmax = nn.Softmax(dim=-1)
|
78 |
+
|
79 |
+
def forward(self, x, mask=None):
|
80 |
+
"""
|
81 |
+
Args:
|
82 |
+
x: input features with shape of (num_windows*B, N, C)
|
83 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
84 |
+
"""
|
85 |
+
B_, N, C = x.shape
|
86 |
+
|
87 |
+
# qkv projection
|
88 |
+
# qkv_bias = None
|
89 |
+
# if self.q_bias is not None:
|
90 |
+
# qkv_bias = torch.cat(
|
91 |
+
# (
|
92 |
+
# self.q_bias,
|
93 |
+
# torch.zeros_like(self.v_bias, requires_grad=False),
|
94 |
+
# self.v_bias,
|
95 |
+
# )
|
96 |
+
# )
|
97 |
+
# qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
98 |
+
qkv = self.qkv(x)
|
99 |
+
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
100 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
101 |
+
|
102 |
+
# cosine attention map
|
103 |
+
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
|
104 |
+
logit_scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
|
105 |
+
attn = attn * logit_scale
|
106 |
+
|
107 |
+
# positional encoding
|
108 |
+
if self.use_pe:
|
109 |
+
bias_table = self.cpb_mlp(self.relative_coords_table)
|
110 |
+
bias_table = bias_table.view(-1, self.num_heads)
|
111 |
+
|
112 |
+
win_dim = prod(self.window_size)
|
113 |
+
bias = bias_table[self.relative_position_index.view(-1)]
|
114 |
+
bias = bias.view(win_dim, win_dim, -1).permute(2, 0, 1).contiguous()
|
115 |
+
# nH, Wh*Ww, Wh*Ww
|
116 |
+
bias = 16 * torch.sigmoid(bias)
|
117 |
+
attn = attn + bias.unsqueeze(0)
|
118 |
+
|
119 |
+
# shift attention mask
|
120 |
+
if mask is not None:
|
121 |
+
nW = mask.shape[0]
|
122 |
+
mask = mask.unsqueeze(1).unsqueeze(0)
|
123 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask
|
124 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
125 |
+
|
126 |
+
# attention
|
127 |
+
attn = self.softmax(attn)
|
128 |
+
attn = self.attn_drop(attn)
|
129 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
130 |
+
|
131 |
+
# output projection
|
132 |
+
x = self.proj(x)
|
133 |
+
x = self.proj_drop(x)
|
134 |
+
return x
|
135 |
+
|
136 |
+
def extra_repr(self) -> str:
|
137 |
+
return (
|
138 |
+
f"dim={self.dim}, window_size={self.window_size}, "
|
139 |
+
f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}"
|
140 |
+
)
|
141 |
+
|
142 |
+
def flops(self, N):
|
143 |
+
# calculate flops for 1 window with token length of N
|
144 |
+
flops = 0
|
145 |
+
# qkv = self.qkv(x)
|
146 |
+
flops += N * self.dim * 3 * self.dim
|
147 |
+
# attn = (q @ k.transpose(-2, -1))
|
148 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
149 |
+
# x = (attn @ v)
|
150 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
151 |
+
# x = self.proj(x)
|
152 |
+
flops += N * self.dim * self.dim
|
153 |
+
return flops
|
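A standalone sketch of the cosine attention used above: L2-normalized queries and keys bound the raw similarities to [-1, 1], and a learned logit_scale, clamped so it never exceeds 1/0.01 = 100 after exponentiation, restores the dynamic range.

import math
import torch
import torch.nn.functional as F

num_heads, N, head_dim = 3, 16, 8
q = torch.randn(1, num_heads, N, head_dim)
k = torch.randn(1, num_heads, N, head_dim)

attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)     # in [-1, 1]
logit_scale = torch.clamp(
    torch.log(10 * torch.ones(num_heads, 1, 1)), max=math.log(1.0 / 0.01)
).exp()
attn = attn * logit_scale
print(attn.shape, float(attn.abs().max()) <= 100.0)   # torch.Size([1, 3, 16, 16]) True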
154 |
+
|
155 |
+
|
156 |
+
class WindowAttentionWrapperV2(WindowAttentionV2):
|
157 |
+
def __init__(self, shift_size, input_resolution, **kwargs):
|
158 |
+
super(WindowAttentionWrapperV2, self).__init__(**kwargs)
|
159 |
+
self.shift_size = shift_size
|
160 |
+
self.input_resolution = input_resolution
|
161 |
+
|
162 |
+
if self.shift_size > 0:
|
163 |
+
attn_mask = calculate_mask(input_resolution, self.window_size, shift_size)
|
164 |
+
else:
|
165 |
+
attn_mask = None
|
166 |
+
self.register_buffer("attn_mask", attn_mask)
|
167 |
+
|
168 |
+
def forward(self, x, x_size):
|
169 |
+
H, W = x_size
|
170 |
+
B, L, C = x.shape
|
171 |
+
x = x.view(B, H, W, C)
|
172 |
+
|
173 |
+
# cyclic shift
|
174 |
+
if self.shift_size > 0:
|
175 |
+
x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
176 |
+
|
177 |
+
# partition windows
|
178 |
+
x = window_partition(x, self.window_size) # nW*B, wh, ww, C
|
179 |
+
x = x.view(-1, prod(self.window_size), C) # nW*B, wh*ww, C
|
180 |
+
|
181 |
+
# W-MSA/SW-MSA
|
182 |
+
if self.input_resolution == x_size:
|
183 |
+
attn_mask = self.attn_mask
|
184 |
+
else:
|
185 |
+
attn_mask = calculate_mask(x_size, self.window_size, self.shift_size)
|
186 |
+
attn_mask = attn_mask.to(x.device)
|
187 |
+
|
188 |
+
# attention
|
189 |
+
x = super(WindowAttentionWrapperV2, self).forward(x, mask=attn_mask)
|
190 |
+
# nW*B, wh*ww, C
|
191 |
+
|
192 |
+
# merge windows
|
193 |
+
x = x.view(-1, *self.window_size, C)
|
194 |
+
x = window_reverse(x, self.window_size, x_size) # B, H, W, C
|
195 |
+
|
196 |
+
# reverse cyclic shift
|
197 |
+
if self.shift_size > 0:
|
198 |
+
x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
199 |
+
x = x.view(B, H * W, C)
|
200 |
+
|
201 |
+
return x
|
202 |
+
|
203 |
+
|
204 |
+
class SwinTransformerBlockV2(nn.Module):
|
205 |
+
r"""Swin Transformer Block.
|
206 |
+
Args:
|
207 |
+
dim (int): Number of input channels.
|
208 |
+
input_resolution (tuple[int]): Input resolution.
|
209 |
+
num_heads (int): Number of attention heads.
|
210 |
+
window_size (int): Window size.
|
211 |
+
shift_size (int): Shift size for SW-MSA.
|
212 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
213 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
214 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
215 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
216 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
217 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
218 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
219 |
+
pretrained_window_size (int): Window size in pre-training.
|
220 |
+
"""
|
221 |
+
|
222 |
+
def __init__(
|
223 |
+
self,
|
224 |
+
dim,
|
225 |
+
input_resolution,
|
226 |
+
num_heads,
|
227 |
+
window_size=7,
|
228 |
+
shift_size=0,
|
229 |
+
mlp_ratio=4.0,
|
230 |
+
qkv_bias=True,
|
231 |
+
drop=0.0,
|
232 |
+
attn_drop=0.0,
|
233 |
+
drop_path=0.0,
|
234 |
+
act_layer=nn.GELU,
|
235 |
+
norm_layer=nn.LayerNorm,
|
236 |
+
pretrained_window_size=0,
|
237 |
+
use_pe=True,
|
238 |
+
res_scale=1.0,
|
239 |
+
):
|
240 |
+
super().__init__()
|
241 |
+
self.dim = dim
|
242 |
+
self.input_resolution = input_resolution
|
243 |
+
self.num_heads = num_heads
|
244 |
+
self.window_size = window_size
|
245 |
+
self.shift_size = shift_size
|
246 |
+
self.mlp_ratio = mlp_ratio
|
247 |
+
if min(self.input_resolution) <= self.window_size:
|
248 |
+
# if window size is larger than input resolution, we don't partition windows
|
249 |
+
self.shift_size = 0
|
250 |
+
self.window_size = min(self.input_resolution)
|
251 |
+
assert (
|
252 |
+
0 <= self.shift_size < self.window_size
|
253 |
+
), "shift_size must in 0-window_size"
|
254 |
+
self.res_scale = res_scale
|
255 |
+
|
256 |
+
self.attn = WindowAttentionWrapperV2(
|
257 |
+
shift_size=self.shift_size,
|
258 |
+
input_resolution=self.input_resolution,
|
259 |
+
dim=dim,
|
260 |
+
window_size=to_2tuple(self.window_size),
|
261 |
+
num_heads=num_heads,
|
262 |
+
qkv_bias=qkv_bias,
|
263 |
+
attn_drop=attn_drop,
|
264 |
+
proj_drop=drop,
|
265 |
+
pretrained_window_size=to_2tuple(pretrained_window_size),
|
266 |
+
use_pe=use_pe,
|
267 |
+
)
|
268 |
+
self.norm1 = norm_layer(dim)
|
269 |
+
|
270 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
271 |
+
|
272 |
+
self.mlp = Mlp(
|
273 |
+
in_features=dim,
|
274 |
+
hidden_features=int(dim * mlp_ratio),
|
275 |
+
act_layer=act_layer,
|
276 |
+
drop=drop,
|
277 |
+
)
|
278 |
+
self.norm2 = norm_layer(dim)
|
279 |
+
|
280 |
+
def forward(self, x, x_size):
|
281 |
+
# Window attention
|
282 |
+
x = x + self.res_scale * self.drop_path(self.norm1(self.attn(x, x_size)))
|
283 |
+
# FFN
|
284 |
+
x = x + self.res_scale * self.drop_path(self.norm2(self.mlp(x)))
|
285 |
+
|
286 |
+
return x
|
287 |
+
|
288 |
+
def extra_repr(self) -> str:
|
289 |
+
return (
|
290 |
+
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
|
291 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}, res_scale={self.res_scale}"
|
292 |
+
)
|
293 |
+
|
294 |
+
def flops(self):
|
295 |
+
flops = 0
|
296 |
+
H, W = self.input_resolution
|
297 |
+
# norm1
|
298 |
+
flops += self.dim * H * W
|
299 |
+
# W-MSA/SW-MSA
|
300 |
+
nW = H * W / self.window_size / self.window_size
|
301 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
302 |
+
# mlp
|
303 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
304 |
+
# norm2
|
305 |
+
flops += self.dim * H * W
|
306 |
+
return flops
|
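The main structural difference from SwinTransformerBlockV1 is the residual ordering: V1 normalizes before the branch (pre-norm), while the block above normalizes the branch output (post-norm, SwinV2 style). A standalone sketch with a plain Linear standing in for the attention/MLP branch (an illustrative stand-in, not the real module):

import torch
import torch.nn as nn

dim = 8
x = torch.randn(2, 16, dim)
norm = nn.LayerNorm(dim)
branch = nn.Linear(dim, dim)          # stand-in for window attention or the MLP

pre_norm_out = x + branch(norm(x))    # ordering used by SwinTransformerBlockV1.forward
post_norm_out = x + norm(branch(x))   # ordering used by SwinTransformerBlockV2.forward
print(pre_norm_out.shape, post_norm_out.shape)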
architecture/grl_common/upsample.py
ADDED
@@ -0,0 +1,50 @@
1 |
+
import math
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
class Upsample(nn.Module):
|
7 |
+
"""Upsample module.
|
8 |
+
Args:
|
9 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
10 |
+
num_feat (int): Channel number of intermediate features.
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __init__(self, scale, num_feat):
|
14 |
+
super(Upsample, self).__init__()
|
15 |
+
m = []
|
16 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
17 |
+
for _ in range(int(math.log(scale, 2))):
|
18 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
19 |
+
m.append(nn.PixelShuffle(2))
|
20 |
+
elif scale == 3:
|
21 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
22 |
+
m.append(nn.PixelShuffle(3))
|
23 |
+
else:
|
24 |
+
raise ValueError(
|
25 |
+
f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
|
26 |
+
)
|
27 |
+
self.up = nn.Sequential(*m)
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
return self.up(x)
|
31 |
+
|
32 |
+
|
33 |
+
class UpsampleOneStep(nn.Module):
|
34 |
+
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
|
35 |
+
Used in lightweight SR to save parameters.
|
36 |
+
Args:
|
37 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
38 |
+
num_feat (int): Channel number of intermediate features.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(self, scale, num_feat, num_out_ch):
|
42 |
+
super(UpsampleOneStep, self).__init__()
|
43 |
+
self.num_feat = num_feat
|
44 |
+
m = []
|
45 |
+
m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
|
46 |
+
m.append(nn.PixelShuffle(scale))
|
47 |
+
self.up = nn.Sequential(*m)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
return self.up(x)
|
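Both upsamplers are built on nn.PixelShuffle: a 3x3 convolution expands the channels by scale**2 and PixelShuffle trades those channels for spatial resolution. A standalone sketch of the 4x (= 2 x 2) case assembled by Upsample:

import torch
import torch.nn as nn

num_feat = 64
x = torch.randn(1, num_feat, 32, 32)
up = nn.Sequential(
    nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2),   # 32 -> 64
    nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2),   # 64 -> 128
)
print(up(x).shape)   # torch.Size([1, 64, 128, 128])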
architecture/rrdb.py
ADDED
@@ -0,0 +1,218 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Paper Github Repository: https://github.com/xinntao/Real-ESRGAN
|
4 |
+
# Code snippet from: https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/archs/rrdbnet_arch.py
|
5 |
+
# Paper: https://arxiv.org/pdf/2107.10833.pdf
|
6 |
+
|
7 |
+
import os, sys
|
8 |
+
import torch
|
9 |
+
from torch import nn as nn
|
10 |
+
from torch.nn import functional as F
|
11 |
+
from itertools import repeat
|
12 |
+
from torch.nn import init as init
|
13 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
14 |
+
|
15 |
+
|
16 |
+
def pixel_unshuffle(x, scale):
|
17 |
+
""" Pixel unshuffle.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
x (Tensor): Input feature with shape (b, c, hh, hw).
|
21 |
+
scale (int): Downsample ratio.
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
Tensor: the pixel unshuffled feature.
|
25 |
+
"""
|
26 |
+
b, c, hh, hw = x.size()
|
27 |
+
out_channel = c * (scale**2)
|
28 |
+
assert hh % scale == 0 and hw % scale == 0
|
29 |
+
h = hh // scale
|
30 |
+
w = hw // scale
|
31 |
+
x_view = x.view(b, c, h, scale, w, scale)
|
32 |
+
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
|
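A standalone check (the function is restated so the snippet is self-contained) that pixel_unshuffle above is the exact inverse of torch.nn.PixelShuffle:

import torch
import torch.nn as nn

def pixel_unshuffle(x, scale):
    # Same logic as the pixel_unshuffle defined above.
    b, c, hh, hw = x.size()
    x_view = x.view(b, c, hh // scale, scale, hw // scale, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, c * scale**2, hh // scale, hw // scale)

x = torch.randn(1, 3, 8, 8)
assert torch.equal(nn.PixelShuffle(2)(pixel_unshuffle(x, 2)), x)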
33 |
+
|
34 |
+
def make_layer(basic_block, num_basic_block, **kwarg):
|
35 |
+
"""Make layers by stacking the same blocks.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
basic_block (nn.module): nn.module class for basic block.
|
39 |
+
num_basic_block (int): number of blocks.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
nn.Sequential: Stacked blocks in nn.Sequential.
|
43 |
+
"""
|
44 |
+
layers = []
|
45 |
+
for _ in range(num_basic_block):
|
46 |
+
layers.append(basic_block(**kwarg))
|
47 |
+
return nn.Sequential(*layers)
|
48 |
+
|
49 |
+
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
|
50 |
+
"""Initialize network weights.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
|
54 |
+
scale (float): Scale initialized weights, especially for residual
|
55 |
+
blocks. Default: 1.
|
56 |
+
bias_fill (float): The value to fill bias. Default: 0
|
57 |
+
kwargs (dict): Other arguments for initialization function.
|
58 |
+
"""
|
59 |
+
if not isinstance(module_list, list):
|
60 |
+
module_list = [module_list]
|
61 |
+
for module in module_list:
|
62 |
+
for m in module.modules():
|
63 |
+
if isinstance(m, nn.Conv2d):
|
64 |
+
init.kaiming_normal_(m.weight, **kwargs)
|
65 |
+
m.weight.data *= scale
|
66 |
+
if m.bias is not None:
|
67 |
+
m.bias.data.fill_(bias_fill)
|
68 |
+
elif isinstance(m, nn.Linear):
|
69 |
+
init.kaiming_normal_(m.weight, **kwargs)
|
70 |
+
m.weight.data *= scale
|
71 |
+
if m.bias is not None:
|
72 |
+
m.bias.data.fill_(bias_fill)
|
73 |
+
elif isinstance(m, _BatchNorm):
|
74 |
+
init.constant_(m.weight, 1)
|
75 |
+
if m.bias is not None:
|
76 |
+
m.bias.data.fill_(bias_fill)
|
77 |
+
|
78 |
+
class ResidualDenseBlock(nn.Module):
|
79 |
+
"""Residual Dense Block.
|
80 |
+
|
81 |
+
Used in RRDB block in ESRGAN.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
num_feat (int): Channel number of intermediate features.
|
85 |
+
num_grow_ch (int): Channels for each growth.
|
86 |
+
"""
|
87 |
+
|
88 |
+
def __init__(self, num_feat=64, num_grow_ch=32):
|
89 |
+
super(ResidualDenseBlock, self).__init__()
|
90 |
+
self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
|
91 |
+
self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
|
92 |
+
self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
93 |
+
self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
94 |
+
self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
|
95 |
+
|
96 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
97 |
+
|
98 |
+
# initialization
|
99 |
+
default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
100 |
+
|
101 |
+
def forward(self, x):
|
102 |
+
x1 = self.lrelu(self.conv1(x))
|
103 |
+
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
104 |
+
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
105 |
+
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
106 |
+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
107 |
+
# Empirically, we use 0.2 to scale the residual for better performance
|
108 |
+
return x5 * 0.2 + x
|
109 |
+
|
110 |
+
|
111 |
+
class RRDB(nn.Module):
|
112 |
+
"""Residual in Residual Dense Block.
|
113 |
+
|
114 |
+
Used in RRDB-Net in ESRGAN.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
num_feat (int): Channel number of intermediate features.
|
118 |
+
num_grow_ch (int): Channels for each growth.
|
119 |
+
"""
|
120 |
+
|
121 |
+
def __init__(self, num_feat, num_grow_ch=32):
|
122 |
+
super(RRDB, self).__init__()
|
123 |
+
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
|
124 |
+
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
|
125 |
+
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
|
126 |
+
|
127 |
+
def forward(self, x):
|
128 |
+
out = self.rdb1(x)
|
129 |
+
out = self.rdb2(out)
|
130 |
+
out = self.rdb3(out)
|
131 |
+
# Empirically, we use 0.2 to scale the residual for better performance
|
132 |
+
return out * 0.2 + x
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
class RRDBNet(nn.Module):
|
137 |
+
"""Networks consisting of Residual in Residual Dense Block, which is used
|
138 |
+
in ESRGAN.
|
139 |
+
|
140 |
+
ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
|
141 |
+
|
142 |
+
We extend ESRGAN for scale x2 and scale x1.
|
143 |
+
Note: This is one option for scale 1, scale 2 in RRDBNet.
|
144 |
+
We first employ pixel-unshuffle (an inverse operation of pixel-shuffle) to reduce the spatial size
|
145 |
+
and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
num_in_ch (int): Channel number of inputs.
|
149 |
+
num_out_ch (int): Channel number of outputs.
|
150 |
+
num_feat (int): Channel number of intermediate features.
|
151 |
+
Default: 64
|
152 |
+
num_block (int): Number of blocks in the trunk network. Default: 6 for our anime training cases.
|
153 |
+
num_grow_ch (int): Channels for each growth. Default: 32.
|
154 |
+
"""
|
155 |
+
|
156 |
+
def __init__(self, num_in_ch, num_out_ch, scale, num_feat=64, num_block=6, num_grow_ch=32):
|
157 |
+
|
158 |
+
super(RRDBNet, self).__init__()
|
159 |
+
self.scale = scale
|
160 |
+
if scale == 2:
|
161 |
+
num_in_ch = num_in_ch * 4
|
162 |
+
elif scale == 1:
|
163 |
+
num_in_ch = num_in_ch * 16
|
164 |
+
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
165 |
+
self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
|
166 |
+
self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
167 |
+
# upsample
|
168 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
169 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
170 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
171 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
172 |
+
|
173 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
174 |
+
|
175 |
+
def forward(self, x):
|
176 |
+
if self.scale == 2:
|
177 |
+
feat = pixel_unshuffle(x, scale=2)
|
178 |
+
elif self.scale == 1:
|
179 |
+
feat = pixel_unshuffle(x, scale=4)
|
180 |
+
else:
|
181 |
+
feat = x
|
182 |
+
feat = self.conv_first(feat)
|
183 |
+
body_feat = self.conv_body(self.body(feat))
|
184 |
+
feat = feat + body_feat
|
185 |
+
# upsample
|
186 |
+
feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
187 |
+
feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
188 |
+
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
189 |
+
return out
|
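A usage sketch (an assumption: this repository is importable as a package). RRDBNet always upsamples its internal features by 4x; scale=2 and scale=1 are handled by pixel-unshuffling the input by 2x and 4x first, exactly as in forward() above.

import torch
from architecture.rrdb import RRDBNet

x = torch.randn(1, 3, 64, 64)
for scale in (1, 2, 4):
    model = RRDBNet(num_in_ch=3, num_out_ch=3, scale=scale, num_block=6)
    print(scale, model(x).shape)
# 1 torch.Size([1, 3, 64, 64])
# 2 torch.Size([1, 3, 128, 128])
# 4 torch.Size([1, 3, 256, 256])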
190 |
+
|
191 |
+
|
192 |
+
|
193 |
+
def main():
|
194 |
+
root_path = os.path.abspath('.')
|
195 |
+
sys.path.append(root_path)
|
196 |
+
|
197 |
+
from opt import opt  # manages which GPU to use
|
198 |
+
from pthflops import count_ops
|
199 |
+
from torchsummary import summary
|
200 |
+
import time
|
201 |
+
|
202 |
+
# We use RRDB 6Blocks by default.
|
203 |
+
model = RRDBNet(3, 3, scale=4).cuda()  # scale is a required argument; use 4x here
|
204 |
+
pytorch_total_params = sum(p.numel() for p in model.parameters())
|
205 |
+
print(f"RRDB has param {pytorch_total_params//1000} K params")
|
206 |
+
|
207 |
+
|
208 |
+
# Run a timed forward pass to double check the output size
|
209 |
+
x = torch.randn((1, 3, 180, 180)).cuda()
|
210 |
+
start = time.time()
|
211 |
+
x = model(x)
|
212 |
+
print("output size is ", x.shape)
|
213 |
+
total = time.time() - start
|
214 |
+
print(total)
|
215 |
+
|
216 |
+
|
217 |
+
if __name__ == "__main__":
|
218 |
+
main()
|
architecture/swinir.py
ADDED
@@ -0,0 +1,874 @@
1 |
+
# -----------------------------------------------------------------------------------
|
2 |
+
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
|
3 |
+
# Originally Written by Ze Liu, Modified by Jingyun Liang.
|
4 |
+
# -----------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.utils.checkpoint as checkpoint
|
11 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
12 |
+
|
13 |
+
|
14 |
+
class Mlp(nn.Module):
|
15 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
16 |
+
super().__init__()
|
17 |
+
out_features = out_features or in_features
|
18 |
+
hidden_features = hidden_features or in_features
|
19 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
20 |
+
self.act = act_layer()
|
21 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
22 |
+
self.drop = nn.Dropout(drop)
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
x = self.fc1(x)
|
26 |
+
x = self.act(x)
|
27 |
+
x = self.drop(x)
|
28 |
+
x = self.fc2(x)
|
29 |
+
x = self.drop(x)
|
30 |
+
return x
|
31 |
+
|
32 |
+
|
33 |
+
def window_partition(x, window_size):
|
34 |
+
"""
|
35 |
+
Args:
|
36 |
+
x: (B, H, W, C)
|
37 |
+
window_size (int): window size
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
windows: (num_windows*B, window_size, window_size, C)
|
41 |
+
"""
|
42 |
+
B, H, W, C = x.shape
|
43 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
44 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
45 |
+
return windows
|
46 |
+
|
47 |
+
|
48 |
+
def window_reverse(windows, window_size, H, W):
|
49 |
+
"""
|
50 |
+
Args:
|
51 |
+
windows: (num_windows*B, window_size, window_size, C)
|
52 |
+
window_size (int): Window size
|
53 |
+
H (int): Height of image
|
54 |
+
W (int): Width of image
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
x: (B, H, W, C)
|
58 |
+
"""
|
59 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
60 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
61 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
class WindowAttention(nn.Module):
|
66 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
67 |
+
It supports both shifted and non-shifted windows.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
dim (int): Number of input channels.
|
71 |
+
window_size (tuple[int]): The height and width of the window.
|
72 |
+
num_heads (int): Number of attention heads.
|
73 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
74 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
75 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
76 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
80 |
+
|
81 |
+
super().__init__()
|
82 |
+
self.dim = dim
|
83 |
+
self.window_size = window_size # Wh, Ww
|
84 |
+
self.num_heads = num_heads
|
85 |
+
head_dim = dim // num_heads
|
86 |
+
self.scale = qk_scale or head_dim ** -0.5
|
87 |
+
|
88 |
+
# define a parameter table of relative position bias
|
89 |
+
self.relative_position_bias_table = nn.Parameter(
|
90 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
91 |
+
|
92 |
+
# get pair-wise relative position index for each token inside the window
|
93 |
+
coords_h = torch.arange(self.window_size[0])
|
94 |
+
coords_w = torch.arange(self.window_size[1])
|
95 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
96 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
97 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
98 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
99 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
100 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
101 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
102 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
103 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
104 |
+
|
105 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
106 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
107 |
+
self.proj = nn.Linear(dim, dim)
|
108 |
+
|
109 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
110 |
+
|
111 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
112 |
+
self.softmax = nn.Softmax(dim=-1)
|
113 |
+
|
114 |
+
def forward(self, x, mask=None):
|
115 |
+
"""
|
116 |
+
Args:
|
117 |
+
x: input features with shape of (num_windows*B, N, C)
|
118 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
119 |
+
"""
|
120 |
+
B_, N, C = x.shape
|
121 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
122 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
123 |
+
|
124 |
+
q = q * self.scale
|
125 |
+
attn = (q @ k.transpose(-2, -1))
|
126 |
+
|
127 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
128 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
129 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
130 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
131 |
+
|
132 |
+
if mask is not None:
|
133 |
+
nW = mask.shape[0]
|
134 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
135 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
136 |
+
attn = self.softmax(attn)
|
137 |
+
else:
|
138 |
+
attn = self.softmax(attn)
|
139 |
+
|
140 |
+
attn = self.attn_drop(attn)
|
141 |
+
|
142 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
143 |
+
x = self.proj(x)
|
144 |
+
x = self.proj_drop(x)
|
145 |
+
return x
|
146 |
+
|
147 |
+
def extra_repr(self) -> str:
|
148 |
+
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
149 |
+
|
150 |
+
def flops(self, N):
|
151 |
+
# calculate flops for 1 window with token length of N
|
152 |
+
flops = 0
|
153 |
+
# qkv = self.qkv(x)
|
154 |
+
flops += N * self.dim * 3 * self.dim
|
155 |
+
# attn = (q @ k.transpose(-2, -1))
|
156 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
157 |
+
# x = (attn @ v)
|
158 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
159 |
+
# x = self.proj(x)
|
160 |
+
flops += N * self.dim * self.dim
|
161 |
+
return flops
|
162 |
+
|
163 |
+
|
164 |
+
class SwinTransformerBlock(nn.Module):
|
165 |
+
r""" Swin Transformer Block.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
dim (int): Number of input channels.
|
169 |
+
input_resolution (tuple[int]): Input resolution.
|
170 |
+
num_heads (int): Number of attention heads.
|
171 |
+
window_size (int): Window size.
|
172 |
+
shift_size (int): Shift size for SW-MSA.
|
173 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
174 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
175 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
176 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
177 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
178 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
179 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
180 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
184 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
185 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
186 |
+
super().__init__()
|
187 |
+
self.dim = dim
|
188 |
+
self.input_resolution = input_resolution
|
189 |
+
self.num_heads = num_heads
|
190 |
+
self.window_size = window_size
|
191 |
+
self.shift_size = shift_size
|
192 |
+
self.mlp_ratio = mlp_ratio
|
193 |
+
if min(self.input_resolution) <= self.window_size:
|
194 |
+
# if window size is larger than input resolution, we don't partition windows
|
195 |
+
self.shift_size = 0
|
196 |
+
self.window_size = min(self.input_resolution)
|
197 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
|
198 |
+
|
199 |
+
self.norm1 = norm_layer(dim)
|
200 |
+
self.attn = WindowAttention(
|
201 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
202 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
203 |
+
|
204 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
205 |
+
self.norm2 = norm_layer(dim)
|
206 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
207 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
208 |
+
|
209 |
+
if self.shift_size > 0:
|
210 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
211 |
+
else:
|
212 |
+
attn_mask = None
|
213 |
+
|
214 |
+
self.register_buffer("attn_mask", attn_mask)
|
215 |
+
|
216 |
+
def calculate_mask(self, x_size):
|
217 |
+
# calculate attention mask for SW-MSA
|
218 |
+
H, W = x_size
|
219 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
220 |
+
h_slices = (slice(0, -self.window_size),
|
221 |
+
slice(-self.window_size, -self.shift_size),
|
222 |
+
slice(-self.shift_size, None))
|
223 |
+
w_slices = (slice(0, -self.window_size),
|
224 |
+
slice(-self.window_size, -self.shift_size),
|
225 |
+
slice(-self.shift_size, None))
|
226 |
+
cnt = 0
|
227 |
+
for h in h_slices:
|
228 |
+
for w in w_slices:
|
229 |
+
img_mask[:, h, w, :] = cnt
|
230 |
+
cnt += 1
|
231 |
+
|
232 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
233 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
234 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
235 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
236 |
+
|
237 |
+
return attn_mask
|
238 |
+
|
239 |
+
def forward(self, x, x_size):
|
240 |
+
H, W = x_size
|
241 |
+
B, L, C = x.shape
|
242 |
+
# assert L == H * W, "input feature has wrong size"
|
243 |
+
|
244 |
+
shortcut = x
|
245 |
+
x = self.norm1(x)
|
246 |
+
x = x.view(B, H, W, C)
|
247 |
+
|
248 |
+
# cyclic shift
|
249 |
+
if self.shift_size > 0:
|
250 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
251 |
+
else:
|
252 |
+
shifted_x = x
|
253 |
+
|
254 |
+
# partition windows
|
255 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
256 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
257 |
+
|
258 |
+
# W-MSA/SW-MSA (to be compatible with testing on images whose sizes are not multiples of the window size)
|
259 |
+
if self.input_resolution == x_size:
|
260 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
261 |
+
else:
|
262 |
+
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
263 |
+
|
264 |
+
# merge windows
|
265 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
266 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
267 |
+
|
268 |
+
# reverse cyclic shift
|
269 |
+
if self.shift_size > 0:
|
270 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
271 |
+
else:
|
272 |
+
x = shifted_x
|
273 |
+
x = x.view(B, H * W, C)
|
274 |
+
|
275 |
+
# FFN
|
276 |
+
x = shortcut + self.drop_path(x)
|
277 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
278 |
+
|
279 |
+
return x
|
280 |
+
|
281 |
+
def extra_repr(self) -> str:
|
282 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
283 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
284 |
+
|
285 |
+
def flops(self):
|
286 |
+
flops = 0
|
287 |
+
H, W = self.input_resolution
|
288 |
+
# norm1
|
289 |
+
flops += self.dim * H * W
|
290 |
+
# W-MSA/SW-MSA
|
291 |
+
nW = H * W / self.window_size / self.window_size
|
292 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
293 |
+
# mlp
|
294 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
295 |
+
# norm2
|
296 |
+
flops += self.dim * H * W
|
297 |
+
return flops
|
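The shifted-window masking built by calculate_mask above can be checked in isolation; the sketch below re-implements window_partition locally for a toy 8x8 resolution with window_size=4 and shift_size=2 (all values illustrative).

import torch

def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows*B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

H = W = 8
window_size, shift_size = 4, 2
img_mask = torch.zeros((1, H, W, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt   # label the 9 regions created by the cyclic shift
        cnt += 1

mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)   # torch.Size([4, 16, 16]): one 16x16 mask per 4x4 window

Token pairs that come from different pre-shift regions receive a -100 bias, which effectively zeroes their attention weights after the softmax.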
298 |
+
|
299 |
+
|
300 |
+
class PatchMerging(nn.Module):
|
301 |
+
r""" Patch Merging Layer.
|
302 |
+
|
303 |
+
Args:
|
304 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
305 |
+
dim (int): Number of input channels.
|
306 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
307 |
+
"""
|
308 |
+
|
309 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
310 |
+
super().__init__()
|
311 |
+
self.input_resolution = input_resolution
|
312 |
+
self.dim = dim
|
313 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
314 |
+
self.norm = norm_layer(4 * dim)
|
315 |
+
|
316 |
+
def forward(self, x):
|
317 |
+
"""
|
318 |
+
x: B, H*W, C
|
319 |
+
"""
|
320 |
+
H, W = self.input_resolution
|
321 |
+
B, L, C = x.shape
|
322 |
+
assert L == H * W, "input feature has wrong size"
|
323 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
|
324 |
+
|
325 |
+
x = x.view(B, H, W, C)
|
326 |
+
|
327 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
328 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
329 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
330 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
331 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
332 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
333 |
+
|
334 |
+
x = self.norm(x)
|
335 |
+
x = self.reduction(x)
|
336 |
+
|
337 |
+
return x
|
338 |
+
|
339 |
+
def extra_repr(self) -> str:
|
340 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
341 |
+
|
342 |
+
def flops(self):
|
343 |
+
H, W = self.input_resolution
|
344 |
+
flops = H * W * self.dim
|
345 |
+
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
346 |
+
return flops
|
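A quick shape check of the 2x2 patch-merging rearrangement above (a sketch with arbitrary B=2, H=W=8, C=96, not values taken from this file):

import torch
import torch.nn as nn

B, H, W, C = 2, 8, 8, 96
x = torch.randn(B, H * W, C).view(B, H, W, C)
merged = torch.cat([x[:, 0::2, 0::2, :], x[:, 1::2, 0::2, :],
                    x[:, 0::2, 1::2, :], x[:, 1::2, 1::2, :]], -1).view(B, -1, 4 * C)
reduced = nn.Linear(4 * C, 2 * C, bias=False)(nn.LayerNorm(4 * C)(merged))
print(merged.shape, reduced.shape)   # torch.Size([2, 16, 384]) torch.Size([2, 16, 192])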
347 |
+
|
348 |
+
|
349 |
+
class BasicLayer(nn.Module):
|
350 |
+
""" A basic Swin Transformer layer for one stage.
|
351 |
+
|
352 |
+
Args:
|
353 |
+
dim (int): Number of input channels.
|
354 |
+
input_resolution (tuple[int]): Input resolution.
|
355 |
+
depth (int): Number of blocks.
|
356 |
+
num_heads (int): Number of attention heads.
|
357 |
+
window_size (int): Local window size.
|
358 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
359 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
360 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
361 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
362 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
363 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
364 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
365 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
366 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
367 |
+
"""
|
368 |
+
|
369 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
370 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
371 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
372 |
+
|
373 |
+
super().__init__()
|
374 |
+
self.dim = dim
|
375 |
+
self.input_resolution = input_resolution
|
376 |
+
self.depth = depth
|
377 |
+
self.use_checkpoint = use_checkpoint
|
378 |
+
|
379 |
+
# build blocks
|
380 |
+
self.blocks = nn.ModuleList([
|
381 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
382 |
+
num_heads=num_heads, window_size=window_size,
|
383 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
384 |
+
mlp_ratio=mlp_ratio,
|
385 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
386 |
+
drop=drop, attn_drop=attn_drop,
|
387 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
388 |
+
norm_layer=norm_layer)
|
389 |
+
for i in range(depth)])
|
390 |
+
|
391 |
+
# patch merging layer
|
392 |
+
if downsample is not None:
|
393 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
394 |
+
else:
|
395 |
+
self.downsample = None
|
396 |
+
|
397 |
+
def forward(self, x, x_size):
|
398 |
+
for blk in self.blocks:
|
399 |
+
if self.use_checkpoint:
|
400 |
+
x = checkpoint.checkpoint(blk, x, x_size)
|
401 |
+
else:
|
402 |
+
x = blk(x, x_size)
|
403 |
+
if self.downsample is not None:
|
404 |
+
x = self.downsample(x)
|
405 |
+
return x
|
406 |
+
|
407 |
+
def extra_repr(self) -> str:
|
408 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
409 |
+
|
410 |
+
def flops(self):
|
411 |
+
flops = 0
|
412 |
+
for blk in self.blocks:
|
413 |
+
flops += blk.flops()
|
414 |
+
if self.downsample is not None:
|
415 |
+
flops += self.downsample.flops()
|
416 |
+
return flops
|
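The depth loop above interleaves W-MSA (shift_size=0) and SW-MSA (shift_size=window_size//2) blocks; a one-line sketch of that pattern for an assumed depth=6, window_size=8:

depth, window_size = 6, 8
print([0 if i % 2 == 0 else window_size // 2 for i in range(depth)])   # [0, 4, 0, 4, 0, 4]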
417 |
+
|
418 |
+
|
419 |
+
class RSTB(nn.Module):
|
420 |
+
"""Residual Swin Transformer Block (RSTB).
|
421 |
+
|
422 |
+
Args:
|
423 |
+
dim (int): Number of input channels.
|
424 |
+
input_resolution (tuple[int]): Input resolution.
|
425 |
+
depth (int): Number of blocks.
|
426 |
+
num_heads (int): Number of attention heads.
|
427 |
+
window_size (int): Local window size.
|
428 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
429 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
430 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
431 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
432 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
433 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
434 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
435 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
436 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
437 |
+
img_size: Input image size.
|
438 |
+
patch_size: Patch size.
|
439 |
+
resi_connection: The convolutional block before residual connection.
|
440 |
+
"""
|
441 |
+
|
442 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
443 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
444 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
445 |
+
img_size=224, patch_size=4, resi_connection='1conv'):
|
446 |
+
super(RSTB, self).__init__()
|
447 |
+
|
448 |
+
self.dim = dim
|
449 |
+
self.input_resolution = input_resolution
|
450 |
+
|
451 |
+
self.residual_group = BasicLayer(dim=dim,
|
452 |
+
input_resolution=input_resolution,
|
453 |
+
depth=depth,
|
454 |
+
num_heads=num_heads,
|
455 |
+
window_size=window_size,
|
456 |
+
mlp_ratio=mlp_ratio,
|
457 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
458 |
+
drop=drop, attn_drop=attn_drop,
|
459 |
+
drop_path=drop_path,
|
460 |
+
norm_layer=norm_layer,
|
461 |
+
downsample=downsample,
|
462 |
+
use_checkpoint=use_checkpoint)
|
463 |
+
|
464 |
+
if resi_connection == '1conv':
|
465 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
466 |
+
elif resi_connection == '3conv':
|
467 |
+
# to save parameters and memory
|
468 |
+
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
469 |
+
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
|
470 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
471 |
+
nn.Conv2d(dim // 4, dim, 3, 1, 1))
|
472 |
+
|
473 |
+
self.patch_embed = PatchEmbed(
|
474 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
475 |
+
norm_layer=None)
|
476 |
+
|
477 |
+
self.patch_unembed = PatchUnEmbed(
|
478 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
479 |
+
norm_layer=None)
|
480 |
+
|
481 |
+
def forward(self, x, x_size):
|
482 |
+
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
|
483 |
+
|
484 |
+
def flops(self):
|
485 |
+
flops = 0
|
486 |
+
flops += self.residual_group.flops()
|
487 |
+
H, W = self.input_resolution
|
488 |
+
flops += H * W * self.dim * self.dim * 9
|
489 |
+
flops += self.patch_embed.flops()
|
490 |
+
flops += self.patch_unembed.flops()
|
491 |
+
|
492 |
+
return flops
|
493 |
+
|
494 |
+
|
495 |
+
class PatchEmbed(nn.Module):
|
496 |
+
r""" Image to Patch Embedding
|
497 |
+
|
498 |
+
Args:
|
499 |
+
img_size (int): Image size. Default: 224.
|
500 |
+
patch_size (int): Patch token size. Default: 4.
|
501 |
+
in_chans (int): Number of input image channels. Default: 3.
|
502 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
503 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
504 |
+
"""
|
505 |
+
|
506 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
507 |
+
super().__init__()
|
508 |
+
img_size = to_2tuple(img_size)
|
509 |
+
patch_size = to_2tuple(patch_size)
|
510 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
511 |
+
self.img_size = img_size
|
512 |
+
self.patch_size = patch_size
|
513 |
+
self.patches_resolution = patches_resolution
|
514 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
515 |
+
|
516 |
+
self.in_chans = in_chans
|
517 |
+
self.embed_dim = embed_dim
|
518 |
+
|
519 |
+
if norm_layer is not None:
|
520 |
+
self.norm = norm_layer(embed_dim)
|
521 |
+
else:
|
522 |
+
self.norm = None
|
523 |
+
|
524 |
+
def forward(self, x):
|
525 |
+
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
|
526 |
+
if self.norm is not None:
|
527 |
+
x = self.norm(x)
|
528 |
+
return x
|
529 |
+
|
530 |
+
def flops(self):
|
531 |
+
flops = 0
|
532 |
+
H, W = self.img_size
|
533 |
+
if self.norm is not None:
|
534 |
+
flops += H * W * self.embed_dim
|
535 |
+
return flops
|
536 |
+
|
537 |
+
|
538 |
+
class PatchUnEmbed(nn.Module):
|
539 |
+
r""" Image to Patch Unembedding
|
540 |
+
|
541 |
+
Args:
|
542 |
+
img_size (int): Image size. Default: 224.
|
543 |
+
patch_size (int): Patch token size. Default: 4.
|
544 |
+
in_chans (int): Number of input image channels. Default: 3.
|
545 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
546 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
547 |
+
"""
|
548 |
+
|
549 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
550 |
+
super().__init__()
|
551 |
+
img_size = to_2tuple(img_size)
|
552 |
+
patch_size = to_2tuple(patch_size)
|
553 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
554 |
+
self.img_size = img_size
|
555 |
+
self.patch_size = patch_size
|
556 |
+
self.patches_resolution = patches_resolution
|
557 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
558 |
+
|
559 |
+
self.in_chans = in_chans
|
560 |
+
self.embed_dim = embed_dim
|
561 |
+
|
562 |
+
def forward(self, x, x_size):
|
563 |
+
B, HW, C = x.shape
|
564 |
+
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B C Ph Pw
|
565 |
+
return x
|
566 |
+
|
567 |
+
def flops(self):
|
568 |
+
flops = 0
|
569 |
+
return flops
|
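With patch_size=1, as SwinIR uses these modules, PatchEmbed and PatchUnEmbed reduce to a flatten/transpose and its inverse; a minimal round-trip check with arbitrary shapes:

import torch

B, C, H, W = 1, 96, 12, 12
feat = torch.randn(B, C, H, W)
tokens = feat.flatten(2).transpose(1, 2)              # PatchEmbed.forward without norm
restored = tokens.transpose(1, 2).view(B, C, H, W)    # PatchUnEmbed.forward
print(torch.equal(feat, restored))                    # True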
570 |
+
|
571 |
+
|
572 |
+
class Upsample(nn.Sequential):
|
573 |
+
"""Upsample module.
|
574 |
+
|
575 |
+
Args:
|
576 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
577 |
+
num_feat (int): Channel number of intermediate features.
|
578 |
+
"""
|
579 |
+
|
580 |
+
def __init__(self, scale, num_feat):
|
581 |
+
m = []
|
582 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
583 |
+
for _ in range(int(math.log(scale, 2))):
|
584 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
585 |
+
m.append(nn.PixelShuffle(2))
|
586 |
+
elif scale == 3:
|
587 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
588 |
+
m.append(nn.PixelShuffle(3))
|
589 |
+
else:
|
590 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
591 |
+
super(Upsample, self).__init__(*m)
|
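A shape-only sketch of the pixel-shuffle upsampling above for scale=4, i.e. two conv + PixelShuffle(2) stages; num_feat=64 matches the default used later in SwinIR's reconstruction head, and the input size is arbitrary:

import torch
import torch.nn as nn

num_feat = 64
stages = []
for _ in range(2):                                    # int(math.log(4, 2)) = 2 stages
    stages += [nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2)]
up = nn.Sequential(*stages)
print(up(torch.randn(1, num_feat, 32, 32)).shape)     # torch.Size([1, 64, 128, 128])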
592 |
+
|
593 |
+
|
594 |
+
class UpsampleOneStep(nn.Sequential):
|
595 |
+
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
|
596 |
+
Used in lightweight SR to save parameters.
|
597 |
+
|
598 |
+
Args:
|
599 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
600 |
+
num_feat (int): Channel number of intermediate features.
|
601 |
+
|
602 |
+
"""
|
603 |
+
|
604 |
+
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
|
605 |
+
self.num_feat = num_feat
|
606 |
+
self.input_resolution = input_resolution
|
607 |
+
m = []
|
608 |
+
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
|
609 |
+
m.append(nn.PixelShuffle(scale))
|
610 |
+
super(UpsampleOneStep, self).__init__(*m)
|
611 |
+
|
612 |
+
def flops(self):
|
613 |
+
H, W = self.input_resolution
|
614 |
+
flops = H * W * self.num_feat * 3 * 9
|
615 |
+
return flops
|
616 |
+
|
617 |
+
|
618 |
+
class SwinIR(nn.Module):
|
619 |
+
r""" SwinIR
|
620 |
+
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
|
621 |
+
|
622 |
+
Args:
|
623 |
+
img_size (int | tuple(int)): Input image size. Default 64
|
624 |
+
patch_size (int | tuple(int)): Patch size. Default: 1
|
625 |
+
in_chans (int): Number of input image channels. Default: 3
|
626 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
627 |
+
depths (tuple(int)): Depth of each Swin Transformer layer.
|
628 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
629 |
+
window_size (int): Window size. Default: 7
|
630 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
631 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
632 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
633 |
+
drop_rate (float): Dropout rate. Default: 0
|
634 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
635 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
636 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
637 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
638 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
639 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
640 |
+
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
|
641 |
+
img_range: Image range. 1. or 255.
|
642 |
+
upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
643 |
+
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
|
644 |
+
"""
|
645 |
+
|
646 |
+
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
647 |
+
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
648 |
+
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
649 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
650 |
+
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
651 |
+
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
652 |
+
**kwargs):
|
653 |
+
super(SwinIR, self).__init__()
|
654 |
+
num_in_ch = in_chans
|
655 |
+
num_out_ch = in_chans
|
656 |
+
num_feat = 64
|
657 |
+
self.img_range = img_range
|
658 |
+
if in_chans == 3:
|
659 |
+
rgb_mean = (0.4488, 0.4371, 0.4040)
|
660 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
661 |
+
else:
|
662 |
+
self.mean = torch.zeros(1, 1, 1, 1)
|
663 |
+
self.upscale = upscale
|
664 |
+
self.upsampler = upsampler
|
665 |
+
self.window_size = window_size
|
666 |
+
|
667 |
+
#####################################################################################################
|
668 |
+
################################### 1, shallow feature extraction ###################################
|
669 |
+
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
670 |
+
|
671 |
+
#####################################################################################################
|
672 |
+
################################### 2, deep feature extraction ######################################
|
673 |
+
self.num_layers = len(depths)
|
674 |
+
self.embed_dim = embed_dim
|
675 |
+
self.ape = ape
|
676 |
+
self.patch_norm = patch_norm
|
677 |
+
self.num_features = embed_dim
|
678 |
+
self.mlp_ratio = mlp_ratio
|
679 |
+
|
680 |
+
# split image into non-overlapping patches
|
681 |
+
self.patch_embed = PatchEmbed(
|
682 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
683 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
684 |
+
num_patches = self.patch_embed.num_patches
|
685 |
+
patches_resolution = self.patch_embed.patches_resolution
|
686 |
+
self.patches_resolution = patches_resolution
|
687 |
+
|
688 |
+
# merge non-overlapping patches into image
|
689 |
+
self.patch_unembed = PatchUnEmbed(
|
690 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
691 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
692 |
+
|
693 |
+
# absolute position embedding
|
694 |
+
if self.ape:
|
695 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
696 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
697 |
+
|
698 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
699 |
+
|
700 |
+
# stochastic depth
|
701 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
702 |
+
|
703 |
+
# build Residual Swin Transformer blocks (RSTB)
|
704 |
+
self.layers = nn.ModuleList()
|
705 |
+
for i_layer in range(self.num_layers):
|
706 |
+
layer = RSTB(dim=embed_dim,
|
707 |
+
input_resolution=(patches_resolution[0],
|
708 |
+
patches_resolution[1]),
|
709 |
+
depth=depths[i_layer],
|
710 |
+
num_heads=num_heads[i_layer],
|
711 |
+
window_size=window_size,
|
712 |
+
mlp_ratio=self.mlp_ratio,
|
713 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
714 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
715 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
716 |
+
norm_layer=norm_layer,
|
717 |
+
downsample=None,
|
718 |
+
use_checkpoint=use_checkpoint,
|
719 |
+
img_size=img_size,
|
720 |
+
patch_size=patch_size,
|
721 |
+
resi_connection=resi_connection
|
722 |
+
|
723 |
+
)
|
724 |
+
self.layers.append(layer)
|
725 |
+
self.norm = norm_layer(self.num_features)
|
726 |
+
|
727 |
+
# build the last conv layer in deep feature extraction
|
728 |
+
if resi_connection == '1conv':
|
729 |
+
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
730 |
+
elif resi_connection == '3conv':
|
731 |
+
# to save parameters and memory
|
732 |
+
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
|
733 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
734 |
+
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
|
735 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
736 |
+
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
|
737 |
+
|
738 |
+
#####################################################################################################
|
739 |
+
################################ 3, high quality image reconstruction ################################
|
740 |
+
if self.upsampler == 'pixelshuffle':
|
741 |
+
# for classical SR
|
742 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
743 |
+
nn.LeakyReLU(inplace=True))
|
744 |
+
self.upsample = Upsample(upscale, num_feat)
|
745 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
746 |
+
elif self.upsampler == 'pixelshuffledirect':
|
747 |
+
# for lightweight SR (to save parameters)
|
748 |
+
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
749 |
+
(patches_resolution[0], patches_resolution[1]))
|
750 |
+
elif self.upsampler == 'nearest+conv':
|
751 |
+
# for real-world SR (less artifacts)
|
752 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
753 |
+
nn.LeakyReLU(inplace=True))
|
754 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
755 |
+
if self.upscale == 4:
|
756 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
757 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
758 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
759 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
760 |
+
else:
|
761 |
+
# for image denoising and JPEG compression artifact reduction
|
762 |
+
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
|
763 |
+
|
764 |
+
self.apply(self._init_weights)
|
765 |
+
|
766 |
+
def _init_weights(self, m):
|
767 |
+
if isinstance(m, nn.Linear):
|
768 |
+
trunc_normal_(m.weight, std=.02)
|
769 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
770 |
+
nn.init.constant_(m.bias, 0)
|
771 |
+
elif isinstance(m, nn.LayerNorm):
|
772 |
+
nn.init.constant_(m.bias, 0)
|
773 |
+
nn.init.constant_(m.weight, 1.0)
|
774 |
+
|
775 |
+
@torch.jit.ignore
|
776 |
+
def no_weight_decay(self):
|
777 |
+
return {'absolute_pos_embed'}
|
778 |
+
|
779 |
+
@torch.jit.ignore
|
780 |
+
def no_weight_decay_keywords(self):
|
781 |
+
return {'relative_position_bias_table'}
|
782 |
+
|
783 |
+
def check_image_size(self, x):
|
784 |
+
_, _, h, w = x.size()
|
785 |
+
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
786 |
+
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
787 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
788 |
+
return x
|
789 |
+
|
790 |
+
def forward_features(self, x):
|
791 |
+
x_size = (x.shape[2], x.shape[3])
|
792 |
+
x = self.patch_embed(x)
|
793 |
+
if self.ape:
|
794 |
+
x = x + self.absolute_pos_embed
|
795 |
+
x = self.pos_drop(x)
|
796 |
+
|
797 |
+
for layer in self.layers:
|
798 |
+
x = layer(x, x_size)
|
799 |
+
|
800 |
+
x = self.norm(x) # B L C
|
801 |
+
x = self.patch_unembed(x, x_size)
|
802 |
+
|
803 |
+
return x
|
804 |
+
|
805 |
+
def forward(self, x):
|
806 |
+
H, W = x.shape[2:]
|
807 |
+
x = self.check_image_size(x)
|
808 |
+
|
809 |
+
self.mean = self.mean.type_as(x)
|
810 |
+
x = (x - self.mean) * self.img_range
|
811 |
+
|
812 |
+
if self.upsampler == 'pixelshuffle':
|
813 |
+
# for classical SR
|
814 |
+
x = self.conv_first(x)
|
815 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
816 |
+
x = self.conv_before_upsample(x)
|
817 |
+
x = self.conv_last(self.upsample(x))
|
818 |
+
elif self.upsampler == 'pixelshuffledirect':
|
819 |
+
# for lightweight SR
|
820 |
+
x = self.conv_first(x)
|
821 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
822 |
+
x = self.upsample(x)
|
823 |
+
elif self.upsampler == 'nearest+conv':
|
824 |
+
# for real-world SR
|
825 |
+
x = self.conv_first(x)
|
826 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
827 |
+
x = self.conv_before_upsample(x)
|
828 |
+
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
829 |
+
if self.upscale == 4:
|
830 |
+
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
831 |
+
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
832 |
+
else:
|
833 |
+
# for image denoising and JPEG compression artifact reduction
|
834 |
+
x_first = self.conv_first(x)
|
835 |
+
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
836 |
+
x = x + self.conv_last(res)
|
837 |
+
|
838 |
+
x = x / self.img_range + self.mean
|
839 |
+
|
840 |
+
return x[:, :, :H*self.upscale, :W*self.upscale]
|
841 |
+
|
842 |
+
def flops(self):
|
843 |
+
flops = 0
|
844 |
+
H, W = self.patches_resolution
|
845 |
+
flops += H * W * 3 * self.embed_dim * 9
|
846 |
+
flops += self.patch_embed.flops()
|
847 |
+
for i, layer in enumerate(self.layers):
|
848 |
+
flops += layer.flops()
|
849 |
+
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
850 |
+
flops += self.upsample.flops()
|
851 |
+
return flops
|
852 |
+
|
853 |
+
|
854 |
+
if __name__ == '__main__':
|
855 |
+
upscale = 4
|
856 |
+
window_size = 8
|
857 |
+
height = (1024 // upscale // window_size + 1) * window_size
|
858 |
+
width = (720 // upscale // window_size + 1) * window_size
|
859 |
+
model = SwinIR(upscale=2, img_size=(height, width),
|
860 |
+
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
|
861 |
+
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect').cuda()
|
862 |
+
print(model)
|
863 |
+
|
864 |
+
pytorch_total_params = sum(p.numel() for p in model.parameters())
|
865 |
+
print(f"pathGAN has param {pytorch_total_params//1000} K params")
|
866 |
+
|
867 |
+
|
868 |
+
# Count the time
|
869 |
+
import time
|
870 |
+
x = torch.randn((1, 3, 180, 180)).cuda()
|
871 |
+
start = time.time()
|
872 |
+
x = model(x)
|
873 |
+
total = time.time() - start
|
874 |
+
print("total time spent is ", total)
|
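For reference, a minimal CPU-only sketch of running the real-world SR branch of this network; it assumes the repository root is on PYTHONPATH so the file is importable as architecture.swinir, and the hyper-parameters below are illustrative rather than the released APISR configuration.

import torch
from architecture.swinir import SwinIR   # import path assumed from this commit's layout

model = SwinIR(upscale=4, img_size=64, window_size=8, img_range=1.,
               depths=[6, 6], embed_dim=60, num_heads=[6, 6],
               mlp_ratio=2, upsampler='nearest+conv').eval()

with torch.no_grad():
    lr = torch.rand(1, 3, 70, 70)   # deliberately not a multiple of window_size
    sr = model(lr)                  # check_image_size pads, forward crops back to H*4 x W*4
print(sr.shape)                     # torch.Size([1, 3, 280, 280])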
dataset_curation_pipeline/IC9600/ICNet.py
ADDED
@@ -0,0 +1,151 @@
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
class slam(nn.Module):
|
9 |
+
def __init__(self, spatial_dim):
|
10 |
+
super(slam,self).__init__()
|
11 |
+
self.spatial_dim = spatial_dim
|
12 |
+
self.linear = nn.Sequential(
|
13 |
+
nn.Linear(spatial_dim**2,512),
|
14 |
+
nn.ReLU(),
|
15 |
+
nn.Linear(512,1),
|
16 |
+
nn.Sigmoid()
|
17 |
+
)
|
18 |
+
|
19 |
+
def forward(self, feature):
|
20 |
+
n,c,h,w = feature.shape
|
21 |
+
if (h != self.spatial_dim):
|
22 |
+
x = F.interpolate(feature,size=(self.spatial_dim,self.spatial_dim),mode= "bilinear", align_corners=True)
|
23 |
+
else:
|
24 |
+
x = feature
|
25 |
+
|
26 |
+
|
27 |
+
x = x.view(n,c,-1)
|
28 |
+
x = self.linear(x)
|
29 |
+
x = x.unsqueeze(dim =3)
|
30 |
+
out = x.expand_as(feature)*feature
|
31 |
+
|
32 |
+
return out
|
33 |
+
|
34 |
+
|
35 |
+
class to_map(nn.Module):
|
36 |
+
def __init__(self,channels):
|
37 |
+
super(to_map,self).__init__()
|
38 |
+
self.to_map = nn.Sequential(
|
39 |
+
nn.Conv2d(in_channels=channels,out_channels=1, kernel_size=1,stride=1),
|
40 |
+
nn.Sigmoid()
|
41 |
+
)
|
42 |
+
|
43 |
+
def forward(self,feature):
|
44 |
+
return self.to_map(feature)
|
45 |
+
|
46 |
+
|
47 |
+
class conv_bn_relu(nn.Module):
|
48 |
+
def __init__(self,in_channels, out_channels, kernel_size = 3, padding = 1, stride = 1):
|
49 |
+
super(conv_bn_relu,self).__init__()
|
50 |
+
self.conv = nn.Conv2d(in_channels= in_channels, out_channels= out_channels, kernel_size= kernel_size, padding= padding, stride = stride)
|
51 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
52 |
+
self.relu = nn.ReLU()
|
53 |
+
|
54 |
+
def forward(self,x):
|
55 |
+
x = self.conv(x)
|
56 |
+
x = self.bn(x)
|
57 |
+
x = self.relu(x)
|
58 |
+
return x
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
class up_conv_bn_relu(nn.Module):
|
63 |
+
def __init__(self,up_size, in_channels, out_channels = 64, kernal_size = 1, padding =0, stride = 1):
|
64 |
+
super(up_conv_bn_relu,self).__init__()
|
65 |
+
self.upSample = nn.Upsample(size = (up_size,up_size),mode="bilinear",align_corners=True)
|
66 |
+
self.conv = nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size = kernal_size, stride = stride, padding= padding)
|
67 |
+
self.bn = nn.BatchNorm2d(num_features=out_channels)
|
68 |
+
self.act = nn.ReLU()
|
69 |
+
|
70 |
+
def forward(self,x):
|
71 |
+
x = self.upSample(x)
|
72 |
+
x = self.conv(x)
|
73 |
+
x = self.bn(x)
|
74 |
+
x = self.act(x)
|
75 |
+
return x
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
class ICNet(nn.Module):
|
80 |
+
def __init__(self, is_pretrain = True, size1 = 512, size2 = 256):
|
81 |
+
super(ICNet,self).__init__()
|
82 |
+
resnet18Pretrained1 = torchvision.models.resnet18(pretrained= is_pretrain)
|
83 |
+
resnet18Pretrained2 = torchvision.models.resnet18(pretrained= is_pretrain)
|
84 |
+
|
85 |
+
self.size1 = size1
|
86 |
+
self.size2 = size2
|
87 |
+
|
88 |
+
## detail branch
|
89 |
+
self.b1_1 = nn.Sequential(*list(resnet18Pretrained1.children())[:5])
|
90 |
+
self.b1_1_slam = slam(32)
|
91 |
+
|
92 |
+
self.b1_2 = list(resnet18Pretrained1.children())[5]
|
93 |
+
self.b1_2_slam = slam(32)
|
94 |
+
|
95 |
+
## context branch
|
96 |
+
self.b2_1 = nn.Sequential(*list(resnet18Pretrained2.children())[:5])
|
97 |
+
self.b2_1_slam = slam(32)
|
98 |
+
|
99 |
+
self.b2_2 = list(resnet18Pretrained2.children())[5]
|
100 |
+
self.b2_2_slam = slam(32)
|
101 |
+
|
102 |
+
self.b2_3 = list(resnet18Pretrained2.children())[6]
|
103 |
+
self.b2_3_slam = slam(16)
|
104 |
+
|
105 |
+
self.b2_4 = list(resnet18Pretrained2.children())[7]
|
106 |
+
self.b2_4_slam = slam(8)
|
107 |
+
|
108 |
+
## upsample
|
109 |
+
self.upsize = size1 // 8
|
110 |
+
self.up1 = up_conv_bn_relu(up_size = self.upsize, in_channels = 128, out_channels = 256)
|
111 |
+
self.up2 = up_conv_bn_relu(up_size = self.upsize, in_channels = 512, out_channels = 256)
|
112 |
+
|
113 |
+
## map prediction head
|
114 |
+
self.to_map_f = conv_bn_relu(256*2,256*2)
|
115 |
+
self.to_map_f_slam = slam(32)
|
116 |
+
self.to_map = to_map(256*2)
|
117 |
+
|
118 |
+
## score prediction head
|
119 |
+
self.to_score_f = conv_bn_relu(256*2,256*2)
|
120 |
+
self.to_score_f_slam = slam(32)
|
121 |
+
self.head = nn.Sequential(
|
122 |
+
nn.Linear(256*2,512),
|
123 |
+
nn.ReLU(),
|
124 |
+
nn.Linear(512,1),
|
125 |
+
nn.Sigmoid()
|
126 |
+
)
|
127 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
|
128 |
+
|
129 |
+
|
130 |
+
def forward(self,x1):
|
131 |
+
assert(x1.shape[2] == x1.shape[3] == self.size1)
|
132 |
+
x2 = F.interpolate(x1, size= (self.size2,self.size2), mode = "bilinear", align_corners= True)
|
133 |
+
|
134 |
+
x1 = self.b1_2_slam(self.b1_2(self.b1_1_slam(self.b1_1(x1))))
|
135 |
+
x2 = self.b2_2_slam(self.b2_2(self.b2_1_slam(self.b2_1(x2))))
|
136 |
+
x2 = self.b2_4_slam(self.b2_4(self.b2_3_slam(self.b2_3(x2))))
|
137 |
+
|
138 |
+
|
139 |
+
x1 = self.up1(x1)
|
140 |
+
x2 = self.up2(x2)
|
141 |
+
x_cat = torch.cat((x1,x2),dim = 1)
|
142 |
+
|
143 |
+
cly_map = self.to_map(self.to_map_f_slam(self.to_map_f(x_cat)))
|
144 |
+
|
145 |
+
score_feature = self.to_score_f_slam(self.to_score_f(x_cat))
|
146 |
+
score_feature = self.avgpool(score_feature)
|
147 |
+
score_feature = score_feature.squeeze()
|
148 |
+
score = self.head(score_feature)
|
149 |
+
score = score.squeeze()
|
150 |
+
|
151 |
+
return score,cly_map
|
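A hedged usage sketch of the complexity scorer above; is_pretrain=False skips the ImageNet download for a dry run, whereas the real pipeline loads the IC9600 checkpoint as gene.py does. The import path assumes the repository root is on PYTHONPATH.

import torch
from dataset_curation_pipeline.IC9600.ICNet import ICNet

model = ICNet(is_pretrain=False).eval()
x = torch.rand(2, 3, 512, 512)             # input must be size1 x size1 (512 by default)
with torch.no_grad():
    score, cly_map = model(x)
print(score.shape, cly_map.shape)          # torch.Size([2]) torch.Size([2, 1, 64, 64])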
dataset_curation_pipeline/IC9600/gene.py
ADDED
@@ -0,0 +1,113 @@
1 |
+
import argparse
|
2 |
+
import os, sys
|
3 |
+
import torch
|
4 |
+
import cv2
|
5 |
+
from torchvision import transforms
|
6 |
+
from PIL import Image
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import numpy as np
|
9 |
+
from matplotlib import pyplot as plt
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
# Import files from the local folder
|
13 |
+
root_path = os.path.abspath('.')
|
14 |
+
sys.path.append(root_path)
|
15 |
+
from opt import opt
|
16 |
+
from dataset_curation_pipeline.IC9600.ICNet import ICNet
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
inference_transform = transforms.Compose([
|
21 |
+
transforms.Resize((512,512)),
|
22 |
+
transforms.ToTensor(),
|
23 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
24 |
+
])
|
25 |
+
|
26 |
+
def blend(ori_img, ic_img, alpha = 0.8, cm = plt.get_cmap("magma")):
|
27 |
+
cm_ic_map = cm(ic_img)
|
28 |
+
heatmap = Image.fromarray((cm_ic_map[:, :, -2::-1]*255).astype(np.uint8))
|
29 |
+
ori_img = Image.fromarray(ori_img)
|
30 |
+
blend = Image.blend(ori_img,heatmap,alpha=alpha)
|
31 |
+
blend = np.array(blend)
|
32 |
+
return blend
|
33 |
+
|
34 |
+
|
35 |
+
def infer_one_image(model, img_path):
|
36 |
+
with torch.no_grad():
|
37 |
+
ori_img = Image.open(img_path).convert("RGB")
|
38 |
+
ori_height = ori_img.height
|
39 |
+
ori_width = ori_img.width
|
40 |
+
img = inference_transform(ori_img)
|
41 |
+
img = img.cuda()
|
42 |
+
img = img.unsqueeze(0)
|
43 |
+
ic_score, ic_map = model(img)
|
44 |
+
ic_score = ic_score.item()
|
45 |
+
|
46 |
+
|
47 |
+
# ic_map = F.interpolate(ic_map, (ori_height, ori_width), mode = 'bilinear')
|
48 |
+
|
49 |
+
## gene ic map
|
50 |
+
# ic_map_np = ic_map.squeeze().detach().cpu().numpy()
|
51 |
+
# out_ic_map_name = os.path.basename(img_path).split('.')[0] + '_' + str(ic_score)[:7] + '.npy'
|
52 |
+
# out_ic_map_path = os.path.join(args.output, out_ic_map_name)
|
53 |
+
# np.save(out_ic_map_path, ic_map_np)
|
54 |
+
|
55 |
+
## gene blend map
|
56 |
+
# ic_map_img = (ic_map * 255).round().squeeze().detach().cpu().numpy().astype('uint8')
|
57 |
+
# blend_img = blend(np.array(ori_img), ic_map_img)
|
58 |
+
# out_blend_img_name = os.path.basename(img_path).split('.')[0] + '.png'
|
59 |
+
# out_blend_img_path = os.path.join(args.output, out_blend_img_name)
|
60 |
+
# cv2.imwrite(out_blend_img_path, blend_img)
|
61 |
+
return ic_score
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
def infer_directory(img_dir):
|
66 |
+
imgs = sorted(os.listdir(img_dir))
|
67 |
+
scores = []
|
68 |
+
for img in tqdm(imgs):
|
69 |
+
img_path = os.path.join(img_dir, img)
|
70 |
+
score = infer_one_image(model, img_path)
|
71 |
+
|
72 |
+
scores.append((score, img_path))
|
73 |
+
print(img_path, score)
|
74 |
+
|
75 |
+
scores = sorted(scores, key=lambda x: x[0])
|
76 |
+
scores = scores[::-1]
|
77 |
+
|
78 |
+
for score in scores[:50]:
|
79 |
+
print(score)
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
parser = argparse.ArgumentParser()
|
84 |
+
parser.add_argument('-i', '--input', type = str, default = './example')
|
85 |
+
parser.add_argument('-o', '--output', type = str, default = './out')
|
86 |
+
parser.add_argument('-d', '--device', type = int, default=0)
|
87 |
+
|
88 |
+
args = parser.parse_args()
|
89 |
+
|
90 |
+
model = ICNet()
|
91 |
+
model.load_state_dict(torch.load('./checkpoint/ck.pth',map_location=torch.device('cpu')))
|
92 |
+
model.eval()
|
93 |
+
device = torch.device(args.device)
|
94 |
+
model.to(device)
|
95 |
+
|
96 |
+
inference_transform = transforms.Compose([
|
97 |
+
transforms.Resize((512,512)),
|
98 |
+
transforms.ToTensor(),
|
99 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
100 |
+
])
|
101 |
+
if os.path.isfile(args.input):
|
102 |
+
infer_one_image(model, args.input)
|
103 |
+
else:
|
104 |
+
infer_directory(args.input)
|
105 |
+
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
|
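Scoring a single image programmatically follows the same steps as the __main__ block above; a hedged sketch, assuming the repository root is on PYTHONPATH, a CUDA device (infer_one_image moves inputs to .cuda()), and the checkpoint path pretrained/ck.pth used by collect.py's defaults.

import torch
from dataset_curation_pipeline.IC9600.ICNet import ICNet
from dataset_curation_pipeline.IC9600.gene import infer_one_image

model = ICNet()
model.load_state_dict(torch.load('pretrained/ck.pth', map_location=torch.device('cpu')))
model.eval().cuda()

score = infer_one_image(model, '__assets__/lr_inputs/f91.jpg')   # any RGB image path works
print(f"IC9600 complexity score: {score:.4f}")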
dataset_curation_pipeline/collect.py
ADDED
@@ -0,0 +1,222 @@
1 |
+
'''
|
2 |
+
This file implements the whole dataset curation pipeline: it collects the least compressed and most informative frames from video sources.
|
3 |
+
'''
|
4 |
+
import os, time, sys
|
5 |
+
import shutil
|
6 |
+
import cv2
|
7 |
+
import torch
|
8 |
+
import argparse
|
9 |
+
|
10 |
+
# Import files from the local folder
|
11 |
+
root_path = os.path.abspath('.')
|
12 |
+
sys.path.append(root_path)
|
13 |
+
from opt import opt
|
14 |
+
from dataset_curation_pipeline.IC9600.gene import infer_one_image
|
15 |
+
from dataset_curation_pipeline.IC9600.ICNet import ICNet
|
16 |
+
|
17 |
+
|
18 |
+
class video_scoring:
|
19 |
+
|
20 |
+
def __init__(self, IC9600_pretrained_weight_path) -> None:
|
21 |
+
|
22 |
+
# Init the model
|
23 |
+
self.scorer = ICNet()
|
24 |
+
self.scorer.load_state_dict(torch.load(IC9600_pretrained_weight_path, map_location=torch.device('cpu')))
|
25 |
+
self.scorer.eval().cuda()
|
26 |
+
|
27 |
+
|
28 |
+
def select_frame(self, skip_num, img_lists, target_frame_num, save_dir, output_name_head, partition_idx):
|
29 |
+
''' Score all I-frames in img_lists and save the top target_frame_num frames
|
30 |
+
Args:
|
31 |
+
skip_num (int): Only 1 in skip_num will be chosen to accelerate.
|
32 |
+
img_lists (list): The list of image paths we want to process
|
33 |
+
target_frame_num (int): The number of frames we need to choose
|
34 |
+
save_dir (str): The path where we save those images
|
35 |
+
output_name_head (str): This is the input video name head
|
36 |
+
partition_idx (int): The partition idx
|
37 |
+
'''
|
38 |
+
|
39 |
+
stores = []
|
40 |
+
for idx, image_path in enumerate(sorted(img_lists)):
|
41 |
+
if idx % skip_num != 0:
|
42 |
+
# We only process 1 in skip_num frames to accelerate and to reduce the chance of repeated scenes.
|
43 |
+
continue
|
44 |
+
|
45 |
+
|
46 |
+
# Evaluate the image complexity score for this image
|
47 |
+
score = infer_one_image(self.scorer, image_path)
|
48 |
+
|
49 |
+
if verbose:
|
50 |
+
print(image_path, score)
|
51 |
+
stores.append((score, image_path))
|
52 |
+
|
53 |
+
if verbose:
|
54 |
+
print(image_path, score)
|
55 |
+
|
56 |
+
|
57 |
+
# Find the top most scores' images
|
58 |
+
stores.sort(key=lambda x:x[0])
|
59 |
+
selected = stores[-target_frame_num:]
|
60 |
+
# print(len(stores), len(selected))
|
61 |
+
if verbose:
|
62 |
+
print("The lowest selected score is ", selected[0]) # This is a kind of info
|
63 |
+
|
64 |
+
|
65 |
+
# Store the selected images
|
66 |
+
for idx, (score, img_path) in enumerate(selected):
|
67 |
+
output_name = output_name_head + "_" +str(partition_idx)+ "_" + str(idx) + ".png"
|
68 |
+
output_path = os.path.join(save_dir, output_name)
|
69 |
+
shutil.copyfile(img_path, output_path)
|
70 |
+
|
71 |
+
|
72 |
+
def run(self, skip_num, img_folder, target_frame_num, save_dir, output_name_head, partition_num):
|
73 |
+
''' Score all I-frames in img_folder and select target_frame_num frames to save
|
74 |
+
Args:
|
75 |
+
skip_num (int): Only 1 in skip_num will be chosen to accelerate.
|
76 |
+
img_folder (str): The image folder of all I-Frames we need to process
|
77 |
+
target_frame_num (int): The number of frames we need to choose
|
78 |
+
save_dir (str): The path where we save those images
|
79 |
+
output_name_head (str): This is the input video name head
|
80 |
+
partition_num (int): The number of partitions we want to split the video into
|
81 |
+
'''
|
82 |
+
assert(target_frame_num%partition_num == 0)
|
83 |
+
|
84 |
+
img_lists = []
|
85 |
+
for img_name in sorted(os.listdir(img_folder)):
|
86 |
+
path = os.path.join(img_folder, img_name)
|
87 |
+
img_lists.append(path)
|
88 |
+
length = len(img_lists)
|
89 |
+
unit_length = (length // partition_num)
|
90 |
+
target_partition_num = target_frame_num // partition_num
|
91 |
+
|
92 |
+
# Cut the folder to several partition and select those with the highest score
|
93 |
+
for idx in range(partition_num):
|
94 |
+
select_lists = img_lists[unit_length*idx : unit_length*(idx+1)]
|
95 |
+
self.select_frame(skip_num, select_lists, target_partition_num, save_dir, output_name_head, idx)
|
96 |
+
|
97 |
+
|
98 |
+
class frame_collector:
|
99 |
+
|
100 |
+
def __init__(self, IC9600_pretrained_weight_path, verbose) -> None:
|
101 |
+
|
102 |
+
self.scoring = video_scoring(IC9600_pretrained_weight_path)
|
103 |
+
self.verbose = verbose
|
104 |
+
|
105 |
+
|
106 |
+
def video_split_by_IFrame(self, video_path, tmp_path):
|
107 |
+
''' Split the video to its I-Frames format
|
108 |
+
Args:
|
109 |
+
video_path (str): The directory to a single video
|
110 |
+
tmp_path (str): A temporary working directory that will be deleted at the end
|
111 |
+
'''
|
112 |
+
|
113 |
+
# Prepare the work folder needed
|
114 |
+
if os.path.exists(tmp_path):
|
115 |
+
shutil.rmtree(tmp_path)
|
116 |
+
os.makedirs(tmp_path)
|
117 |
+
|
118 |
+
|
119 |
+
# Split Video I-frame
|
120 |
+
cmd = "ffmpeg -i " + video_path + " -loglevel error -vf select='eq(pict_type\,I)' -vsync 2 -f image2 -q:v 1 " + tmp_path + "/image-%06d.png" # At most support 100K I-Frames per video
|
121 |
+
|
122 |
+
if self.verbose:
|
123 |
+
print(cmd)
|
124 |
+
os.system(cmd)
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
def collect_frames(self, video_folder_dir, save_dir, tmp_path, skip_num, target_frames, partition_num):
|
129 |
+
''' Automatically collect frames from the video dir
|
130 |
+
Args:
|
131 |
+
video_folder_dir (str): The directory of all videos input
|
132 |
+
save_dir (str): The directory we will store the selected frames
|
133 |
+
tmp_path (str): A temporary working directory that will be deleted at the end
|
134 |
+
skip_num (int): Only 1 in skip_num will be chosen to accelerate.
|
135 |
+
target_frames (list): [# of frames for video under 30 min, # of frames for video over 30 min]
|
136 |
+
partition_num (int): The number of partitions we want to split the video into
|
137 |
+
'''
|
138 |
+
|
139 |
+
# Iterate all video under video_folder_dir
|
140 |
+
for video_name in sorted(os.listdir(video_folder_dir)):
|
141 |
+
# Sanity check for this video file format
|
142 |
+
info = video_name.split('.')
|
143 |
+
if info[-1] not in ['mp4', 'mkv', '']:
|
144 |
+
continue
|
145 |
+
output_name_head, extension = info
|
146 |
+
|
147 |
+
|
148 |
+
# Get info of this video
|
149 |
+
video_path = os.path.join(video_folder_dir, video_name)
|
150 |
+
duration = get_duration(video_path) # unit in minutes
|
151 |
+
print("We are processing " + video_path + " with duration " + str(duration) + " min")
|
152 |
+
|
153 |
+
|
154 |
+
# Split the video to I-frame
|
155 |
+
self.video_split_by_IFrame(video_path, tmp_path)
|
156 |
+
|
157 |
+
|
158 |
+
# Score the frames and select those top scored frames we need
|
159 |
+
if duration <= 30:
|
160 |
+
target_frame_num = target_frames[0]
|
161 |
+
else:
|
162 |
+
target_frame_num = target_frames[1]
|
163 |
+
|
164 |
+
self.scoring.run(skip_num, tmp_path, target_frame_num, save_dir, output_name_head, partition_num)
|
165 |
+
|
166 |
+
|
167 |
+
# Remove folders if needed
|
168 |
+
|
169 |
+
|
170 |
+
def get_duration(filename):
|
171 |
+
video = cv2.VideoCapture(filename)
|
172 |
+
fps = video.get(cv2.CAP_PROP_FPS)
|
173 |
+
frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
|
174 |
+
seconds = frame_count / fps
|
175 |
+
minutes = int(seconds / 60)
|
176 |
+
return minutes
|
177 |
+
|
178 |
+
|
179 |
+
if __name__ == "__main__":
|
180 |
+
|
181 |
+
# Fundamental setting
|
182 |
+
parser = argparse.ArgumentParser()
|
183 |
+
parser.add_argument('--video_folder_dir', type = str, default = '../anime_videos', help = "A folder with video sources")
|
184 |
+
parser.add_argument('--IC9600_pretrained_weight_path', type = str, default = "pretrained/ck.pth", help = "The pretrained IC9600 weight")
|
185 |
+
parser.add_argument('--save_dir', type = str, default = 'APISR_dataset', help = "The folder to store filtered dataset")
|
186 |
+
parser.add_argument('--skip_num', type = int, default = 5, help = "Only 1 in skip_num will be chosen in sequential I-frames to accelerate.")
|
187 |
+
parser.add_argument('--target_frames', type = list, default = [16, 24], help = "[# of frames for video under 30 min, # of frames for video over 30 min]")
|
188 |
+
parser.add_argument('--partition_num', type = int, default = 8, help = "The number of partitions we want to split the video into, to increase sampling diversity")
|
189 |
+
parser.add_argument('--verbose', type = bool, default = True, help = "Whether we print log message")
|
190 |
+
args = parser.parse_args()
|
191 |
+
|
192 |
+
|
193 |
+
# Transform to variable
|
194 |
+
video_folder_dir = args.video_folder_dir
|
195 |
+
IC9600_pretrained_weight_path = args.IC9600_pretrained_weight_path
|
196 |
+
save_dir = args.save_dir
|
197 |
+
skip_num = args.skip_num
|
198 |
+
target_frames = args.target_frames # [# of frames for video under 30 min, # of frames for video over 30 min]
|
199 |
+
partition_num = args.partition_num
|
200 |
+
verbose = args.verbose
|
201 |
+
|
202 |
+
|
203 |
+
# Secondary setting
|
204 |
+
tmp_path = "tmp_dataset"
|
205 |
+
|
206 |
+
|
207 |
+
# Prepare
|
208 |
+
if os.path.exists(save_dir):
|
209 |
+
shutil.rmtree(save_dir)
|
210 |
+
os.makedirs(save_dir)
|
211 |
+
|
212 |
+
|
213 |
+
# Process
|
214 |
+
start = time.time()
|
215 |
+
|
216 |
+
obj = frame_collector(IC9600_pretrained_weight_path, verbose)
|
217 |
+
obj.collect_frames(video_folder_dir, save_dir, tmp_path, skip_num, target_frames, partition_num)
|
218 |
+
|
219 |
+
total_time = (time.time() - start)//60
|
220 |
+
print("Total time spent is {} min".format(total_time))
|
221 |
+
|
222 |
+
shutil.rmtree(tmp_path)
|
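A typical end-to-end invocation, assuming MP4/MKV sources under ../anime_videos, the IC9600 checkpoint at pretrained/ck.pth (the argparse defaults above), and ffmpeg available on PATH for the I-frame split:

python dataset_curation_pipeline/collect.py --video_folder_dir ../anime_videos --IC9600_pretrained_weight_path pretrained/ck.pth --save_dir APISR_dataset --skip_num 5 --partition_num 8

Selected frames are written to --save_dir as <video_name>_<partition>_<rank>.png, and the temporary I-frame folder is removed at the end.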
degradation/ESR/degradation_esr_shared.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import os, shutil, time
|
8 |
+
import sys, random
|
9 |
+
from multiprocessing import Pool
|
10 |
+
from os import path as osp
|
11 |
+
from tqdm import tqdm
|
12 |
+
from math import log10, sqrt
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
root_path = os.path.abspath('.')
|
16 |
+
sys.path.append(root_path)
|
17 |
+
from degradation.ESR.degradations_functionality import *
|
18 |
+
from degradation.ESR.diffjpeg import *
|
19 |
+
from degradation.ESR.utils import filter2D
|
20 |
+
from degradation.image_compression.jpeg import JPEG
|
21 |
+
from degradation.image_compression.webp import WEBP
|
22 |
+
from degradation.image_compression.heif import HEIF
|
23 |
+
from degradation.image_compression.avif import AVIF
|
24 |
+
from opt import opt
|
25 |
+
|
26 |
+
|
27 |
+
def PSNR(original, compressed):
|
28 |
+
mse = np.mean((original - compressed) ** 2)
|
29 |
+
if(mse == 0): # An MSE of zero means no noise is present in the signal,
|
30 |
+
# so PSNR is meaningless; return a large value instead.
|
31 |
+
return 100
|
32 |
+
max_pixel = 255.0
|
33 |
+
psnr = 20 * log10(max_pixel / sqrt(mse))
|
34 |
+
return psnr
|
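A quick numeric check of the helper above: a constant error of 10 on 8-bit data gives MSE = 100, hence 20*log10(255/10) ≈ 28.13 dB (values chosen only for illustration).

import numpy as np
from math import log10, sqrt

original = np.zeros((16, 16))
compressed = original + 10.0                 # constant error of 10 -> MSE = 100
mse = np.mean((original - compressed) ** 2)
print(20 * log10(255.0 / sqrt(mse)))         # 28.13..., matching PSNR(original, compressed)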
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
def downsample_1st(out, opt):
|
39 |
+
# Resize with different mode
|
40 |
+
updown_type = random.choices(['up', 'down', 'keep'], opt['resize_prob'])[0]
|
41 |
+
if updown_type == 'up':
|
42 |
+
scale = np.random.uniform(1, opt['resize_range'][1])
|
43 |
+
elif updown_type == 'down':
|
44 |
+
scale = np.random.uniform(opt['resize_range'][0], 1)
|
45 |
+
else:
|
46 |
+
scale = 1
|
47 |
+
mode = random.choice(opt['resize_options'])
|
48 |
+
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
49 |
+
|
50 |
+
return out
|
51 |
+
|
52 |
+
|
53 |
+
def downsample_2nd(out, opt, ori_h, ori_w):
|
54 |
+
# Second Resize for 4x scaling
|
55 |
+
if opt['scale'] == 4:
|
56 |
+
updown_type = random.choices(['up', 'down', 'keep'], opt['resize_prob2'])[0]
|
57 |
+
if updown_type == 'up':
|
58 |
+
scale = np.random.uniform(1, opt['resize_range2'][1])
|
59 |
+
elif updown_type == 'down':
|
60 |
+
scale = np.random.uniform(opt['resize_range2'][0], 1)
|
61 |
+
else:
|
62 |
+
scale = 1
|
63 |
+
mode = random.choice(opt['resize_options'])
|
64 |
+
# Revert the resize here to the original version; no longer use consecutive resizes
|
65 |
+
# out = F.interpolate(out, scale_factor=scale, mode=mode)
|
66 |
+
out = F.interpolate(
|
67 |
+
out, size=(int(ori_h / opt['scale'] * scale), int(ori_w / opt['scale'] * scale)), mode=mode
|
68 |
+
)
|
69 |
+
|
70 |
+
return out
|
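A self-contained sketch of the random up/down/keep resize drawn by downsample_1st above; the opt values here (resize_prob, resize_range, resize_options) are illustrative stand-ins for the project's real configuration in opt.py.

import random
import numpy as np
import torch
import torch.nn.functional as F

opt = {'resize_prob': [0.2, 0.7, 0.1],                 # up / down / keep (assumed weights)
       'resize_range': [0.15, 1.5],
       'resize_options': ['area', 'bilinear', 'bicubic']}

out = torch.rand(1, 3, 64, 64)
updown_type = random.choices(['up', 'down', 'keep'], opt['resize_prob'])[0]
if updown_type == 'up':
    scale = np.random.uniform(1, opt['resize_range'][1])
elif updown_type == 'down':
    scale = np.random.uniform(opt['resize_range'][0], 1)
else:
    scale = 1
out = F.interpolate(out, scale_factor=scale, mode=random.choice(opt['resize_options']))
print(updown_type, out.shape)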
71 |
+
|
72 |
+
|
73 |
+
def common_degradation(out, opt, kernels, process_id, verbose = False):
|
74 |
+
jpeger = DiffJPEG(differentiable=False).cuda()
|
75 |
+
kernel1, kernel2 = kernels
|
76 |
+
|
77 |
+
|
78 |
+
downsample_1st_position = random.choices([0, 1, 2])[0]
|
79 |
+
if opt['scale'] == 4:
|
80 |
+
# Only do the second downsample at 4x scale
|
81 |
+
downsample_2nd_position = random.choices([0, 1, 2])[0]
|
82 |
+
else:
|
83 |
+
# print("We don't use the second resize")
|
84 |
+
downsample_2nd_position = -1
|
85 |
+
|
86 |
+
|
87 |
+
####---------------------------- First Degradation ----------------------------------####
|
88 |
+
batch_size, _, ori_h, ori_w = out.size()
|
89 |
+
|
90 |
+
if downsample_1st_position == 0:
|
91 |
+
out = downsample_1st(out, opt)
|
92 |
+
|
93 |
+
# Bluring kernel
|
94 |
+
out = filter2D(out, kernel1)
|
95 |
+
if verbose: print(f"(1st) blur noise")
|
96 |
+
|
97 |
+
|
98 |
+
if downsample_1st_position == 1:
|
99 |
+
out = downsample_1st(out, opt)
|
100 |
+
|
101 |
+
|
102 |
+
# Noise effect (gaussian / poisson)
|
103 |
+
gray_noise_prob = opt['gray_noise_prob']
|
104 |
+
if np.random.uniform() < opt['gaussian_noise_prob']:
|
105 |
+
# Gaussian noise
|
106 |
+
out = random_add_gaussian_noise_pt(
|
107 |
+
out, sigma_range=opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
108 |
+
name = "gaussian_noise"
|
109 |
+
else:
|
110 |
+
# Poisson noise
|
111 |
+
out = random_add_poisson_noise_pt(
|
112 |
+
out,
|
113 |
+
scale_range=opt['poisson_scale_range'],
|
114 |
+
gray_prob=gray_noise_prob,
|
115 |
+
clip=True,
|
116 |
+
rounds=False)
|
117 |
+
name = "poisson_noise"
|
118 |
+
if verbose: print("(1st) " + str(name))
|
119 |
+
|
120 |
+
|
121 |
+
if downsample_1st_position == 2:
|
122 |
+
out = downsample_1st(out, opt)
|
123 |
+
|
124 |
+
|
125 |
+
# Choose an image compression codec (all degradation batches use the same codec)
|
126 |
+
image_codec = random.choices(opt['compression_codec1'], opt['compression_codec_prob1'])[0] # All lower case
|
127 |
+
if image_codec == "jpeg":
|
128 |
+
out = JPEG.compress_tensor(out)
|
129 |
+
elif image_codec == "webp":
|
130 |
+
try:
|
131 |
+
out = WEBP.compress_tensor(out, idx=process_id)
|
132 |
+
except Exception:
|
133 |
+
print("There is exception again in webp!")
|
134 |
+
out = WEBP.compress_tensor(out, idx=process_id)
|
135 |
+
elif image_codec == "heif":
|
136 |
+
out = HEIF.compress_tensor(out, idx=process_id)
|
137 |
+
elif image_codec == "avif":
|
138 |
+
out = AVIF.compress_tensor(out, idx=process_id)
|
139 |
+
else:
|
140 |
+
raise NotImplementedError("We don't have such image compression designed!")
|
141 |
+
# ##########################################################################################
|
142 |
+
|
143 |
+
|
144 |
+
# ####---------------------------- Second Degradation ----------------------------------####
|
145 |
+
if downsample_2nd_position == 0:
|
146 |
+
out = downsample_2nd(out, opt, ori_h, ori_w)
|
147 |
+
|
148 |
+
|
149 |
+
# Add blur 2nd time
|
150 |
+
if np.random.uniform() < opt['second_blur_prob']:
|
151 |
+
# This blurring is not always triggered
|
152 |
+
if verbose: print("(2nd) blur noise")
|
153 |
+
out = filter2D(out, kernel2)
|
154 |
+
|
155 |
+
|
156 |
+
if downsample_2nd_position == 1:
|
157 |
+
out = downsample_2nd(out, opt, ori_h, ori_w)
|
158 |
+
|
159 |
+
|
160 |
+
# Add noise 2nd time
|
161 |
+
gray_noise_prob = opt['gray_noise_prob2']
|
162 |
+
if np.random.uniform() < opt['gaussian_noise_prob2']:
|
163 |
+
# gaussian noise
|
164 |
+
if verbose: print("(2nd) gaussian noise")
|
165 |
+
out = random_add_gaussian_noise_pt(
|
166 |
+
out, sigma_range=opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
167 |
+
name = "gaussian_noise"
|
168 |
+
else:
|
169 |
+
# poisson noise
|
170 |
+
if verbose: print("(2nd) poisson noise")
|
171 |
+
out = random_add_poisson_noise_pt(
|
172 |
+
out, scale_range=opt['poisson_scale_range2'], gray_prob=gray_noise_prob, clip=True, rounds=False)
|
173 |
+
name = "poisson_noise"
|
174 |
+
|
175 |
+
|
176 |
+
if downsample_2nd_position == 2:
|
177 |
+
out = downsample_2nd(out, opt, ori_h, ori_w)
|
178 |
+
|
179 |
+
|
180 |
+
return out
|
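For orientation, here is a minimal sketch of the options this two-stage pipeline reads. The keys mirror the ones accessed above; the values and the `degradation_pipeline` call are illustrative assumptions, not the project's actual defaults.

# Hypothetical option dict (values are illustrative, not the repo's defaults)
opt = {
    'gray_noise_prob': 0.4, 'gaussian_noise_prob': 0.5, 'noise_range': [1, 30],
    'poisson_scale_range': [0.05, 3.0],
    'compression_codec1': ['jpeg', 'webp', 'heif', 'avif'],
    'compression_codec_prob1': [0.4, 0.4, 0.1, 0.1],
    'second_blur_prob': 0.8, 'gray_noise_prob2': 0.4, 'gaussian_noise_prob2': 0.5,
    'noise_range2': [1, 25], 'poisson_scale_range2': [0.05, 2.5],
}
# lq = degradation_pipeline(gt_batch, kernel1, kernel2, opt, process_id=0, verbose=True)  # hypothetical entry point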
degradation/ESR/degradations_functionality.py
ADDED
@@ -0,0 +1,785 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import math
|
5 |
+
import numpy as np
|
6 |
+
import random
|
7 |
+
import torch
|
8 |
+
from scipy import special
|
9 |
+
from scipy.stats import multivariate_normal
|
10 |
+
from torchvision.transforms.functional_tensor import rgb_to_grayscale
|
11 |
+
|
12 |
+
# -------------------------------------------------------------------- #
|
13 |
+
# --------------------------- blur kernels --------------------------- #
|
14 |
+
# -------------------------------------------------------------------- #
|
15 |
+
|
16 |
+
|
17 |
+
# --------------------------- util functions --------------------------- #
|
18 |
+
def sigma_matrix2(sig_x, sig_y, theta):
|
19 |
+
"""Calculate the rotated sigma matrix (two dimensional matrix).
|
20 |
+
|
21 |
+
Args:
|
22 |
+
sig_x (float):
|
23 |
+
sig_y (float):
|
24 |
+
theta (float): Radian measurement.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
ndarray: Rotated sigma matrix.
|
28 |
+
"""
|
29 |
+
d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
|
30 |
+
u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
|
31 |
+
return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
|
32 |
+
|
33 |
+
|
34 |
+
def mesh_grid(kernel_size):
|
35 |
+
"""Generate the mesh grid, centering at zero.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
kernel_size (int):
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
xy (ndarray): with the shape (kernel_size, kernel_size, 2)
|
42 |
+
xx (ndarray): with the shape (kernel_size, kernel_size)
|
43 |
+
yy (ndarray): with the shape (kernel_size, kernel_size)
|
44 |
+
"""
|
45 |
+
ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
|
46 |
+
xx, yy = np.meshgrid(ax, ax)
|
47 |
+
xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
|
48 |
+
1))).reshape(kernel_size, kernel_size, 2)
|
49 |
+
return xy, xx, yy
|
50 |
+
|
51 |
+
|
52 |
+
def pdf2(sigma_matrix, grid):
|
53 |
+
"""Calculate PDF of the bivariate Gaussian distribution.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
sigma_matrix (ndarray): with the shape (2, 2)
|
57 |
+
grid (ndarray): generated by :func:`mesh_grid`,
|
58 |
+
with the shape (K, K, 2), K is the kernel size.
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
        kernel (ndarray): un-normalized kernel.
|
62 |
+
"""
|
63 |
+
inverse_sigma = np.linalg.inv(sigma_matrix)
|
64 |
+
kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
|
65 |
+
return kernel
|
66 |
+
|
67 |
+
|
68 |
+
def cdf2(d_matrix, grid):
|
69 |
+
"""Calculate the CDF of the standard bivariate Gaussian distribution.
|
70 |
+
Used in skewed Gaussian distribution.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
        d_matrix (ndarray): skew matrix.
|
74 |
+
grid (ndarray): generated by :func:`mesh_grid`,
|
75 |
+
with the shape (K, K, 2), K is the kernel size.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
cdf (ndarray): skewed cdf.
|
79 |
+
"""
|
80 |
+
rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
|
81 |
+
grid = np.dot(grid, d_matrix)
|
82 |
+
cdf = rv.cdf(grid)
|
83 |
+
return cdf
|
84 |
+
|
85 |
+
|
86 |
+
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
|
87 |
+
"""Generate a bivariate isotropic or anisotropic Gaussian kernel.
|
88 |
+
|
89 |
+
    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
kernel_size (int):
|
93 |
+
sig_x (float):
|
94 |
+
sig_y (float):
|
95 |
+
theta (float): Radian measurement.
|
96 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
97 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
98 |
+
isotropic (bool):
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
kernel (ndarray): normalized kernel.
|
102 |
+
"""
|
103 |
+
if grid is None:
|
104 |
+
grid, _, _ = mesh_grid(kernel_size)
|
105 |
+
if isotropic:
|
106 |
+
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
|
107 |
+
else:
|
108 |
+
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
|
109 |
+
kernel = pdf2(sigma_matrix, grid)
|
110 |
+
kernel = kernel / np.sum(kernel)
|
111 |
+
return kernel
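A quick sanity check of the kernel generator above, assuming `cv2` and `numpy` as already imported in this module; the image here is a random stand-in.

# Illustrative usage: build a 21x21 isotropic Gaussian kernel and blur an image with it
kernel = bivariate_Gaussian(kernel_size=21, sig_x=2.0, sig_y=2.0, theta=0, isotropic=True)
assert abs(kernel.sum() - 1.0) < 1e-6                  # the kernel is normalized
img = np.random.rand(64, 64, 3).astype(np.float32)     # stand-in image in [0, 1]
blurred = cv2.filter2D(img, -1, kernel)                # the blur this pipeline simulates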
|
112 |
+
|
113 |
+
|
114 |
+
def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
|
115 |
+
"""Generate a bivariate generalized Gaussian kernel.
|
116 |
+
Described in `Parameter Estimation For Multivariate Generalized
|
117 |
+
Gaussian Distributions`_
|
118 |
+
by Pascal et. al (2013).
|
119 |
+
|
120 |
+
    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
kernel_size (int):
|
124 |
+
sig_x (float):
|
125 |
+
sig_y (float):
|
126 |
+
theta (float): Radian measurement.
|
127 |
+
beta (float): shape parameter, beta = 1 is the normal distribution.
|
128 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
129 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
130 |
+
|
131 |
+
Returns:
|
132 |
+
kernel (ndarray): normalized kernel.
|
133 |
+
|
134 |
+
.. _Parameter Estimation For Multivariate Generalized Gaussian
|
135 |
+
Distributions: https://arxiv.org/abs/1302.6498
|
136 |
+
"""
|
137 |
+
if grid is None:
|
138 |
+
grid, _, _ = mesh_grid(kernel_size)
|
139 |
+
if isotropic:
|
140 |
+
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
|
141 |
+
else:
|
142 |
+
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
|
143 |
+
inverse_sigma = np.linalg.inv(sigma_matrix)
|
144 |
+
kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
|
145 |
+
kernel = kernel / np.sum(kernel)
|
146 |
+
return kernel
|
147 |
+
|
148 |
+
|
149 |
+
def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
|
150 |
+
"""Generate a plateau-like anisotropic kernel.
|
151 |
+
1 / (1+x^(beta))
|
152 |
+
|
153 |
+
Ref: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
|
154 |
+
|
155 |
+
    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
kernel_size (int):
|
159 |
+
sig_x (float):
|
160 |
+
sig_y (float):
|
161 |
+
theta (float): Radian measurement.
|
162 |
+
beta (float): shape parameter, beta = 1 is the normal distribution.
|
163 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
164 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
kernel (ndarray): normalized kernel.
|
168 |
+
"""
|
169 |
+
if grid is None:
|
170 |
+
grid, _, _ = mesh_grid(kernel_size)
|
171 |
+
if isotropic:
|
172 |
+
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
|
173 |
+
else:
|
174 |
+
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
|
175 |
+
inverse_sigma = np.linalg.inv(sigma_matrix)
|
176 |
+
kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
|
177 |
+
kernel = kernel / np.sum(kernel)
|
178 |
+
return kernel
|
179 |
+
|
180 |
+
|
181 |
+
def random_bivariate_Gaussian(kernel_size,
|
182 |
+
sigma_x_range,
|
183 |
+
sigma_y_range,
|
184 |
+
rotation_range,
|
185 |
+
noise_range=None,
|
186 |
+
isotropic=True):
|
187 |
+
"""Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
|
188 |
+
|
189 |
+
    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
|
190 |
+
|
191 |
+
Args:
|
192 |
+
kernel_size (int):
|
193 |
+
sigma_x_range (tuple): [0.6, 5]
|
194 |
+
sigma_y_range (tuple): [0.6, 5]
|
195 |
+
rotation range (tuple): [-math.pi, math.pi]
|
196 |
+
noise_range(tuple, optional): multiplicative kernel noise,
|
197 |
+
[0.75, 1.25]. Default: None
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
kernel (ndarray):
|
201 |
+
"""
|
202 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
203 |
+
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
|
204 |
+
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
|
205 |
+
if isotropic is False:
|
206 |
+
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
|
207 |
+
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
|
208 |
+
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
|
209 |
+
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
|
210 |
+
else:
|
211 |
+
sigma_y = sigma_x
|
212 |
+
rotation = 0
|
213 |
+
|
214 |
+
kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
|
215 |
+
|
216 |
+
# add multiplicative noise
|
217 |
+
if noise_range is not None:
|
218 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
219 |
+
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
|
220 |
+
kernel = kernel * noise
|
221 |
+
kernel = kernel / np.sum(kernel)
|
222 |
+
return kernel
|
223 |
+
|
224 |
+
|
225 |
+
def random_bivariate_generalized_Gaussian(kernel_size,
|
226 |
+
sigma_x_range,
|
227 |
+
sigma_y_range,
|
228 |
+
rotation_range,
|
229 |
+
beta_range,
|
230 |
+
noise_range=None,
|
231 |
+
isotropic=True):
|
232 |
+
"""Randomly generate bivariate generalized Gaussian kernels.
|
233 |
+
|
234 |
+
    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
|
235 |
+
|
236 |
+
Args:
|
237 |
+
kernel_size (int):
|
238 |
+
sigma_x_range (tuple): [0.6, 5]
|
239 |
+
sigma_y_range (tuple): [0.6, 5]
|
240 |
+
rotation range (tuple): [-math.pi, math.pi]
|
241 |
+
beta_range (tuple): [0.5, 8]
|
242 |
+
noise_range(tuple, optional): multiplicative kernel noise,
|
243 |
+
[0.75, 1.25]. Default: None
|
244 |
+
|
245 |
+
Returns:
|
246 |
+
kernel (ndarray):
|
247 |
+
"""
|
248 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
249 |
+
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
|
250 |
+
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
|
251 |
+
if isotropic is False:
|
252 |
+
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
|
253 |
+
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
|
254 |
+
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
|
255 |
+
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
|
256 |
+
else:
|
257 |
+
sigma_y = sigma_x
|
258 |
+
rotation = 0
|
259 |
+
|
260 |
+
# assume beta_range[0] < 1 < beta_range[1]
|
261 |
+
if np.random.uniform() < 0.5:
|
262 |
+
beta = np.random.uniform(beta_range[0], 1)
|
263 |
+
else:
|
264 |
+
beta = np.random.uniform(1, beta_range[1])
|
265 |
+
|
266 |
+
kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
|
267 |
+
|
268 |
+
# add multiplicative noise
|
269 |
+
if noise_range is not None:
|
270 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
271 |
+
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
|
272 |
+
kernel = kernel * noise
|
273 |
+
kernel = kernel / np.sum(kernel)
|
274 |
+
return kernel
|
275 |
+
|
276 |
+
|
277 |
+
def random_bivariate_plateau(kernel_size,
|
278 |
+
sigma_x_range,
|
279 |
+
sigma_y_range,
|
280 |
+
rotation_range,
|
281 |
+
beta_range,
|
282 |
+
noise_range=None,
|
283 |
+
isotropic=True):
|
284 |
+
"""Randomly generate bivariate plateau kernels.
|
285 |
+
|
286 |
+
    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
|
287 |
+
|
288 |
+
Args:
|
289 |
+
kernel_size (int):
|
290 |
+
sigma_x_range (tuple): [0.6, 5]
|
291 |
+
sigma_y_range (tuple): [0.6, 5]
|
292 |
+
rotation range (tuple): [-math.pi/2, math.pi/2]
|
293 |
+
beta_range (tuple): [1, 4]
|
294 |
+
noise_range(tuple, optional): multiplicative kernel noise,
|
295 |
+
[0.75, 1.25]. Default: None
|
296 |
+
|
297 |
+
Returns:
|
298 |
+
kernel (ndarray):
|
299 |
+
"""
|
300 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
301 |
+
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
|
302 |
+
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
|
303 |
+
if isotropic is False:
|
304 |
+
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
|
305 |
+
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
|
306 |
+
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
|
307 |
+
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
|
308 |
+
else:
|
309 |
+
sigma_y = sigma_x
|
310 |
+
rotation = 0
|
311 |
+
|
312 |
+
# TODO: this may be not proper
|
313 |
+
if np.random.uniform() < 0.5:
|
314 |
+
beta = np.random.uniform(beta_range[0], 1)
|
315 |
+
else:
|
316 |
+
beta = np.random.uniform(1, beta_range[1])
|
317 |
+
|
318 |
+
kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
|
319 |
+
# add multiplicative noise
|
320 |
+
if noise_range is not None:
|
321 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
322 |
+
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
|
323 |
+
kernel = kernel * noise
|
324 |
+
kernel = kernel / np.sum(kernel)
|
325 |
+
|
326 |
+
return kernel
|
327 |
+
|
328 |
+
|
329 |
+
def random_mixed_kernels(kernel_list,
|
330 |
+
kernel_prob,
|
331 |
+
kernel_size=21,
|
332 |
+
sigma_x_range=(0.6, 5),
|
333 |
+
sigma_y_range=(0.6, 5),
|
334 |
+
rotation_range=(-math.pi, math.pi),
|
335 |
+
betag_range=(0.5, 8),
|
336 |
+
betap_range=(0.5, 8),
|
337 |
+
noise_range=None):
|
338 |
+
"""Randomly generate mixed kernels.
|
339 |
+
|
340 |
+
Args:
|
341 |
+
kernel_list (tuple): a list name of kernel types,
|
342 |
+
support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
|
343 |
+
'plateau_aniso']
|
344 |
+
kernel_prob (tuple): corresponding kernel probability for each
|
345 |
+
kernel type
|
346 |
+
kernel_size (int):
|
347 |
+
sigma_x_range (tuple): [0.6, 5]
|
348 |
+
sigma_y_range (tuple): [0.6, 5]
|
349 |
+
rotation range (tuple): [-math.pi, math.pi]
|
350 |
+
beta_range (tuple): [0.5, 8]
|
351 |
+
noise_range(tuple, optional): multiplicative kernel noise,
|
352 |
+
[0.75, 1.25]. Default: None
|
353 |
+
|
354 |
+
Returns:
|
355 |
+
kernel (ndarray):
|
356 |
+
"""
|
357 |
+
kernel_type = random.choices(kernel_list, kernel_prob)[0]
|
358 |
+
if kernel_type == 'iso':
|
359 |
+
kernel = random_bivariate_Gaussian(
|
360 |
+
kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
|
361 |
+
elif kernel_type == 'aniso':
|
362 |
+
kernel = random_bivariate_Gaussian(
|
363 |
+
kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
|
364 |
+
elif kernel_type == 'generalized_iso':
|
365 |
+
kernel = random_bivariate_generalized_Gaussian(
|
366 |
+
kernel_size,
|
367 |
+
sigma_x_range,
|
368 |
+
sigma_y_range,
|
369 |
+
rotation_range,
|
370 |
+
betag_range,
|
371 |
+
noise_range=noise_range,
|
372 |
+
isotropic=True)
|
373 |
+
elif kernel_type == 'generalized_aniso':
|
374 |
+
kernel = random_bivariate_generalized_Gaussian(
|
375 |
+
kernel_size,
|
376 |
+
sigma_x_range,
|
377 |
+
sigma_y_range,
|
378 |
+
rotation_range,
|
379 |
+
betag_range,
|
380 |
+
noise_range=noise_range,
|
381 |
+
isotropic=False)
|
382 |
+
elif kernel_type == 'plateau_iso':
|
383 |
+
kernel = random_bivariate_plateau(
|
384 |
+
kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
|
385 |
+
elif kernel_type == 'plateau_aniso':
|
386 |
+
kernel = random_bivariate_plateau(
|
387 |
+
kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
|
388 |
+
return kernel
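A sketch of how this sampler might be called; the probabilities and ranges below are illustrative, not the values this repo trains with.

kernel = random_mixed_kernels(
    kernel_list=['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
    kernel_prob=[0.45, 0.25, 0.12, 0.03, 0.12, 0.03],   # illustrative weights
    kernel_size=21,
    sigma_x_range=(0.2, 3),
    sigma_y_range=(0.2, 3),
    rotation_range=(-math.pi, math.pi),
    betag_range=(0.5, 4),
    betap_range=(1, 2),
    noise_range=None)
# kernel is a normalized 21x21 ndarray ready for filter2D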
|
389 |
+
|
390 |
+
|
391 |
+
np.seterr(divide='ignore', invalid='ignore')
|
392 |
+
|
393 |
+
|
394 |
+
def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
|
395 |
+
"""2D sinc filter, ref: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
|
396 |
+
=====》 这个地方好好调研一下,能做出来的效果决定了后面的上线!
|
397 |
+
Args:
|
398 |
+
cutoff (float): cutoff frequency in radians (pi is max)
|
399 |
+
kernel_size (int): horizontal and vertical size, must be odd.
|
400 |
+
pad_to (int): pad kernel size to desired size, must be odd or zero.
|
401 |
+
"""
|
402 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
403 |
+
kernel = np.fromfunction(
|
404 |
+
lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
|
405 |
+
(x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
|
406 |
+
(x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
|
407 |
+
kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
|
408 |
+
kernel = kernel / np.sum(kernel)
|
409 |
+
if pad_to > kernel_size:
|
410 |
+
pad_size = (pad_to - kernel_size) // 2
|
411 |
+
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
|
412 |
+
return kernel
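For example, a sinc kernel with a random cutoff can be padded to match the common blur-kernel size (values illustrative):

omega_c = np.random.uniform(np.pi / 3, np.pi)            # random cutoff frequency
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size=13, pad_to=21)
assert sinc_kernel.shape == (21, 21)                     # padded to the common kernel size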
|
413 |
+
|
414 |
+
|
415 |
+
# ------------------------------------------------------------- #
|
416 |
+
# --------------------------- noise --------------------------- #
|
417 |
+
# ------------------------------------------------------------- #
|
418 |
+
|
419 |
+
# ----------------------- Gaussian Noise ----------------------- #
|
420 |
+
|
421 |
+
|
422 |
+
def generate_gaussian_noise(img, sigma=10, gray_noise=False):
|
423 |
+
"""Generate Gaussian noise.
|
424 |
+
|
425 |
+
Args:
|
426 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
427 |
+
sigma (float): Noise scale (measured in range 255). Default: 10.
|
428 |
+
|
429 |
+
Returns:
|
430 |
+
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
|
431 |
+
float32.
|
432 |
+
"""
|
433 |
+
if gray_noise:
|
434 |
+
noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
|
435 |
+
noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
|
436 |
+
else:
|
437 |
+
noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
|
438 |
+
return noise
|
439 |
+
|
440 |
+
|
441 |
+
def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
|
442 |
+
"""Add Gaussian noise.
|
443 |
+
|
444 |
+
Args:
|
445 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
446 |
+
sigma (float): Noise scale (measured in range 255). Default: 10.
|
447 |
+
|
448 |
+
Returns:
|
449 |
+
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
|
450 |
+
float32.
|
451 |
+
"""
|
452 |
+
noise = generate_gaussian_noise(img, sigma, gray_noise)
|
453 |
+
out = img + noise
|
454 |
+
if clip and rounds:
|
455 |
+
out = np.clip((out * 255.0).round(), 0, 255) / 255.
|
456 |
+
elif clip:
|
457 |
+
out = np.clip(out, 0, 1)
|
458 |
+
elif rounds:
|
459 |
+
out = (out * 255.0).round() / 255.
|
460 |
+
return out
|
461 |
+
|
462 |
+
|
463 |
+
def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
|
464 |
+
"""Add Gaussian noise (PyTorch version).
|
465 |
+
|
466 |
+
Args:
|
467 |
+
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
|
468 |
+
        sigma (float | Tensor): noise sigma; each sample in the batch gets one (shared) value.
        gray_noise (float | Tensor): either 1 or 0 per sample.
|
470 |
+
|
471 |
+
Returns:
|
472 |
+
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
|
473 |
+
float32.
|
474 |
+
"""
|
475 |
+
b, _, h, w = img.size()
|
476 |
+
if not isinstance(sigma, (float, int)):
|
477 |
+
sigma = sigma.view(img.size(0), 1, 1, 1)
|
478 |
+
if isinstance(gray_noise, (float, int)):
|
479 |
+
cal_gray_noise = gray_noise > 0
|
480 |
+
else:
|
481 |
+
gray_noise = gray_noise.view(b, 1, 1, 1)
|
482 |
+
cal_gray_noise = torch.sum(gray_noise) > 0
|
483 |
+
|
484 |
+
if cal_gray_noise:
|
485 |
+
noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
|
486 |
+
noise_gray = noise_gray.view(b, 1, h, w)
|
487 |
+
|
488 |
+
# always calculate color noise
|
489 |
+
noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
|
490 |
+
|
491 |
+
if cal_gray_noise:
|
492 |
+
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
|
493 |
+
return noise
|
494 |
+
|
495 |
+
|
496 |
+
def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
|
497 |
+
"""Add Gaussian noise (PyTorch version).
|
498 |
+
|
499 |
+
Args:
|
500 |
+
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
|
501 |
+
scale (float | Tensor): Noise scale. Default: 1.0.
|
502 |
+
|
503 |
+
Returns:
|
504 |
+
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
|
505 |
+
float32.
|
506 |
+
"""
|
507 |
+
    noise = generate_gaussian_noise_pt(img, sigma, gray_noise)  # sigma sets the noise strength; gray_noise switches to grayscale noise
|
508 |
+
out = img + noise
|
509 |
+
if clip and rounds:
|
510 |
+
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
511 |
+
elif clip:
|
512 |
+
out = torch.clamp(out, 0, 1)
|
513 |
+
elif rounds:
|
514 |
+
out = (out * 255.0).round() / 255.
|
515 |
+
return out
|
516 |
+
|
517 |
+
|
518 |
+
# ----------------------- Random Gaussian Noise ----------------------- #
|
519 |
+
def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
|
520 |
+
sigma = np.random.uniform(sigma_range[0], sigma_range[1])
|
521 |
+
if np.random.uniform() < gray_prob:
|
522 |
+
gray_noise = True
|
523 |
+
else:
|
524 |
+
gray_noise = False
|
525 |
+
return generate_gaussian_noise(img, sigma, gray_noise)
|
526 |
+
|
527 |
+
|
528 |
+
def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
|
529 |
+
noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
|
530 |
+
out = img + noise
|
531 |
+
if clip and rounds:
|
532 |
+
out = np.clip((out * 255.0).round(), 0, 255) / 255.
|
533 |
+
elif clip:
|
534 |
+
out = np.clip(out, 0, 1)
|
535 |
+
elif rounds:
|
536 |
+
out = (out * 255.0).round() / 255.
|
537 |
+
return out
|
538 |
+
|
539 |
+
|
540 |
+
def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
|
541 |
+
sigma = torch.rand(img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
|
542 |
+
|
543 |
+
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
|
544 |
+
gray_noise = (gray_noise < gray_prob).float()
|
545 |
+
return generate_gaussian_noise_pt(img, sigma, gray_noise)
|
546 |
+
|
547 |
+
|
548 |
+
def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
|
549 |
+
    # sigma_range bounds the sampled noise strength
|
550 |
+
noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
|
551 |
+
out = img + noise
|
552 |
+
if clip and rounds:
|
553 |
+
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
554 |
+
elif clip:
|
555 |
+
out = torch.clamp(out, 0, 1)
|
556 |
+
elif rounds:
|
557 |
+
out = (out * 255.0).round() / 255.
|
558 |
+
return out
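A minimal sketch of applying this to a batch, assuming `torch` as already imported in this module; the tensor is a random stand-in.

imgs = torch.rand(4, 3, 64, 64)                          # stand-in batch in [0, 1]
noisy = random_add_gaussian_noise_pt(
    imgs, sigma_range=(1, 30), gray_prob=0.4, clip=True, rounds=False)
assert noisy.shape == imgs.shape
assert noisy.min() >= 0 and noisy.max() <= 1             # clip=True keeps the range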
|
559 |
+
|
560 |
+
|
561 |
+
# ----------------------- Poisson (Shot) Noise ----------------------- #
|
562 |
+
|
563 |
+
|
564 |
+
def generate_poisson_noise(img, scale=1.0, gray_noise=False):
|
565 |
+
"""Generate poisson noise.
|
566 |
+
|
567 |
+
Ref: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
|
568 |
+
|
569 |
+
Args:
|
570 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
571 |
+
scale (float): Noise scale. Default: 1.0.
|
572 |
+
gray_noise (bool): Whether generate gray noise. Default: False.
|
573 |
+
|
574 |
+
Returns:
|
575 |
+
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
|
576 |
+
float32.
|
577 |
+
"""
|
578 |
+
if gray_noise:
|
579 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
580 |
+
# round and clip image for counting vals correctly
|
581 |
+
img = np.clip((img * 255.0).round(), 0, 255) / 255.
|
582 |
+
vals = len(np.unique(img))
|
583 |
+
vals = 2**np.ceil(np.log2(vals))
|
584 |
+
out = np.float32(np.random.poisson(img * vals) / float(vals))
|
585 |
+
noise = out - img
|
586 |
+
if gray_noise:
|
587 |
+
noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
|
588 |
+
return noise * scale
|
589 |
+
|
590 |
+
|
591 |
+
def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
|
592 |
+
"""Add poisson noise.
|
593 |
+
|
594 |
+
Args:
|
595 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
596 |
+
scale (float): Noise scale. Default: 1.0.
|
597 |
+
gray_noise (bool): Whether generate gray noise. Default: False.
|
598 |
+
|
599 |
+
Returns:
|
600 |
+
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
|
601 |
+
float32.
|
602 |
+
"""
|
603 |
+
noise = generate_poisson_noise(img, scale, gray_noise)
|
604 |
+
out = img + noise
|
605 |
+
if clip and rounds:
|
606 |
+
out = np.clip((out * 255.0).round(), 0, 255) / 255.
|
607 |
+
elif clip:
|
608 |
+
out = np.clip(out, 0, 1)
|
609 |
+
elif rounds:
|
610 |
+
out = (out * 255.0).round() / 255.
|
611 |
+
return out
|
612 |
+
|
613 |
+
|
614 |
+
def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
    """Generate a batch of poisson noise (PyTorch version)

    Args:
        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
            Default: 1.0. A per-sample Tensor is also accepted.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0. A per-sample Tensor is also accepted.

    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
            float32.
    """
    b, _, h, w = img.size()
    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        # Slightly different from the original paper: with a large batch size (e.g. 128),
        # almost every batch ends up containing at least one gray-noise sample.
        cal_gray_noise = torch.sum(gray_noise) > 0
    if cal_gray_noise:
        # Not the most efficient path: the gray branch is computed for the whole batch
        # even though the gray-noise probability is usually low.
        img_gray = rgb_to_grayscale(img, num_output_channels=1)  # returns only the luminance channel
        # round and clip image for counting vals correctly, ensure that it only has 256 possible floats at the end
        img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
        # use for-loop to get the unique values for each sample

        # Note: the noise strength depends on the color diversity of each individual image,
        # which explains why flat, low-diversity images show more visible noise.
        vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
        vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
        vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)

        # Since the img is in range [0,1], the noise from the poisson distribution should also lie in [0,1]
        # Note: for monotone images the poisson noise concentrates around one peak, while images
        # with many unique values spread it more widely (visible in the poisson distribution plots).
        out = torch.poisson(img_gray * vals) / vals
        noise_gray = out - img_gray
        noise_gray = noise_gray.expand(b, 3, h, w)

    # always calculate color noise
    # round and clip image for counting vals correctly
    img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
    # use for-loop to get the unique values for each sample
    vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
    vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
    vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
    out = torch.poisson(img * vals) / vals  # the sampled values are still non-negative
    noise = out - img  # this difference can be negative
    if cal_gray_noise:
        # Note: per sample, either the full color noise or the full gray noise is used
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise  # this step has occasionally run out of memory
    if not isinstance(scale, (float, int)):
        scale = scale.view(b, 1, 1, 1)

    # Note: the noise values lie roughly in -0.x .. +0.x; negative values weaken the pixel values
    # print("poisson noise range is ", sorted(torch.unique(noise))[:10])
    # print(sorted(torch.unique(noise))[-10:])
    return noise * scale
|
674 |
+
|
675 |
+
|
676 |
+
def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
|
677 |
+
"""Add poisson noise to a batch of images (PyTorch version).
|
678 |
+
|
679 |
+
Args:
|
680 |
+
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
|
681 |
+
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
|
682 |
+
Default: 1.0.
|
683 |
+
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
|
684 |
+
0 for False, 1 for True. Default: 0.
|
685 |
+
|
686 |
+
Returns:
|
687 |
+
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
|
688 |
+
float32.
|
689 |
+
"""
|
690 |
+
noise = generate_poisson_noise_pt(img, scale, gray_noise)
|
691 |
+
out = img + noise
|
692 |
+
if clip and rounds:
|
693 |
+
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
694 |
+
elif clip:
|
695 |
+
out = torch.clamp(out, 0, 1)
|
696 |
+
elif rounds:
|
697 |
+
out = (out * 255.0).round() / 255.
|
698 |
+
return out
|
699 |
+
|
700 |
+
|
701 |
+
# ----------------------- Random Poisson (Shot) Noise ----------------------- #
|
702 |
+
|
703 |
+
|
704 |
+
def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
|
705 |
+
scale = np.random.uniform(scale_range[0], scale_range[1])
|
706 |
+
if np.random.uniform() < gray_prob:
|
707 |
+
gray_noise = True
|
708 |
+
else:
|
709 |
+
gray_noise = False
|
710 |
+
return generate_poisson_noise(img, scale, gray_noise)
|
711 |
+
|
712 |
+
|
713 |
+
def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
|
714 |
+
noise = random_generate_poisson_noise(img, scale_range, gray_prob)
|
715 |
+
out = img + noise
|
716 |
+
if clip and rounds:
|
717 |
+
out = np.clip((out * 255.0).round(), 0, 255) / 255.
|
718 |
+
elif clip:
|
719 |
+
out = np.clip(out, 0, 1)
|
720 |
+
elif rounds:
|
721 |
+
out = (out * 255.0).round() / 255.
|
722 |
+
return out
|
723 |
+
|
724 |
+
|
725 |
+
def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
    # scale_range bounds the sampled noise scale
    # img.size(0): every image in the batch gets its own scale level
    scale = torch.rand(img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]

    gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
    gray_noise = (gray_noise < gray_prob).float()
    return generate_poisson_noise_pt(img, scale, gray_noise)  # scale and gray_noise are per-sample tensors
|
733 |
+
|
734 |
+
|
735 |
+
def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
|
736 |
+
noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
|
737 |
+
out = img + noise
|
738 |
+
if clip and rounds:
|
739 |
+
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
740 |
+
elif clip:
|
741 |
+
out = torch.clamp(out, 0, 1)
|
742 |
+
elif rounds:
|
743 |
+
out = (out * 255.0).round() / 255.
|
744 |
+
return out
|
745 |
+
|
746 |
+
|
747 |
+
# ------------------------------------------------------------------------ #
|
748 |
+
# --------------------------- JPEG compression --------------------------- #
|
749 |
+
# ------------------------------------------------------------------------ #
|
750 |
+
|
751 |
+
|
752 |
+
def add_jpg_compression(img, quality=90):
|
753 |
+
"""Add JPG compression artifacts.
|
754 |
+
|
755 |
+
Args:
|
756 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
757 |
+
quality (float): JPG compression quality. 0 for lowest quality, 100 for
|
758 |
+
best quality. Default: 90.
|
759 |
+
|
760 |
+
Returns:
|
761 |
+
(Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
|
762 |
+
float32.
|
763 |
+
"""
|
764 |
+
img = np.clip(img, 0, 1)
|
765 |
+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
|
766 |
+
_, encimg = cv2.imencode('.jpg', img * 255., encode_param)
|
767 |
+
img = np.float32(cv2.imdecode(encimg, 1)) / 255.
|
768 |
+
return img
|
769 |
+
|
770 |
+
|
771 |
+
def random_add_jpg_compression(img, quality_range=(90, 100)):
|
772 |
+
"""Randomly add JPG compression artifacts.
|
773 |
+
|
774 |
+
Args:
|
775 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
776 |
+
quality_range (tuple[float] | list[float]): JPG compression quality
|
777 |
+
range. 0 for lowest quality, 100 for best quality.
|
778 |
+
Default: (90, 100).
|
779 |
+
|
780 |
+
Returns:
|
781 |
+
(Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
|
782 |
+
float32.
|
783 |
+
"""
|
784 |
+
quality = np.random.uniform(quality_range[0], quality_range[1])
|
785 |
+
return add_jpg_compression(img, quality)
|
degradation/ESR/diffjpeg.py
ADDED
@@ -0,0 +1,517 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
"""
|
4 |
+
Modified from https://github.com/mlomnitz/DiffJPEG
|
5 |
+
|
6 |
+
For images not divisible by 8
|
7 |
+
https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
|
8 |
+
"""
|
9 |
+
import itertools
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.nn import functional as F
|
14 |
+
|
15 |
+
# ------------------------ utils ------------------------#
|
16 |
+
y_table = np.array(
|
17 |
+
[[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56],
|
18 |
+
[14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92],
|
19 |
+
[49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
|
20 |
+
dtype=np.float32).T
|
21 |
+
y_table = nn.Parameter(torch.from_numpy(y_table))
|
22 |
+
c_table = np.empty((8, 8), dtype=np.float32)
|
23 |
+
c_table.fill(99)
|
24 |
+
c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T
|
25 |
+
c_table = nn.Parameter(torch.from_numpy(c_table))
|
26 |
+
|
27 |
+
|
28 |
+
def diff_round(x):
|
29 |
+
""" Differentiable rounding function
|
30 |
+
"""
|
31 |
+
return torch.round(x) + (x - torch.round(x))**3
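The cubic correction keeps the forward pass close to hard rounding while leaving a nonzero gradient, so the JPEG quantization step stays differentiable. A quick check (illustrative):

x = torch.tensor([0.3, 1.7], requires_grad=True)
y = diff_round(x).sum()
y.backward()
# d/dx [round(x) + (x - round(x))**3] = 3 * (x - round(x))**2, since round() contributes zero gradient
print(x.grad)  # tensor([0.2700, 0.2700])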
|
32 |
+
|
33 |
+
|
34 |
+
def quality_to_factor(quality):
|
35 |
+
""" Calculate factor corresponding to quality
|
36 |
+
|
37 |
+
Args:
|
38 |
+
quality(float): Quality for jpeg compression.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
float: Compression factor.
|
42 |
+
"""
|
43 |
+
if quality < 50:
|
44 |
+
quality = 5000. / quality
|
45 |
+
else:
|
46 |
+
quality = 200. - quality * 2
|
47 |
+
return quality / 100.
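For concreteness: quality 20 gives 5000/20 = 250, i.e. factor 2.5 (coarser quantization tables), while quality 90 gives 200 - 180 = 20, i.e. factor 0.2 (finer tables).

assert quality_to_factor(20) == 2.5   # low quality -> larger quantization factor
assert quality_to_factor(90) == 0.2   # high quality -> smaller factor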
|
48 |
+
|
49 |
+
|
50 |
+
# ------------------------ compression ------------------------#
|
51 |
+
class RGB2YCbCrJpeg(nn.Module):
|
52 |
+
""" Converts RGB image to YCbCr
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(self):
|
56 |
+
super(RGB2YCbCrJpeg, self).__init__()
|
57 |
+
matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
|
58 |
+
dtype=np.float32).T
|
59 |
+
self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
|
60 |
+
self.matrix = nn.Parameter(torch.from_numpy(matrix))
|
61 |
+
|
62 |
+
def forward(self, image):
|
63 |
+
"""
|
64 |
+
Args:
|
65 |
+
image(Tensor): batch x 3 x height x width
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
Tensor: batch x height x width x 3
|
69 |
+
"""
|
70 |
+
image = image.permute(0, 2, 3, 1)
|
71 |
+
result = torch.tensordot(image, self.matrix, dims=1) + self.shift
|
72 |
+
return result.view(image.shape)
|
73 |
+
|
74 |
+
|
75 |
+
class ChromaSubsampling(nn.Module):
|
76 |
+
""" Chroma subsampling on CbCr channels
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self):
|
80 |
+
super(ChromaSubsampling, self).__init__()
|
81 |
+
|
82 |
+
def forward(self, image):
|
83 |
+
"""
|
84 |
+
Args:
|
85 |
+
image(tensor): batch x height x width x 3
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
y(tensor): batch x height x width
|
89 |
+
cb(tensor): batch x height/2 x width/2
|
90 |
+
cr(tensor): batch x height/2 x width/2
|
91 |
+
"""
|
92 |
+
image_2 = image.permute(0, 3, 1, 2).clone()
|
93 |
+
cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
|
94 |
+
cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
|
95 |
+
cb = cb.permute(0, 2, 3, 1)
|
96 |
+
cr = cr.permute(0, 2, 3, 1)
|
97 |
+
return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
|
98 |
+
|
99 |
+
|
100 |
+
class BlockSplitting(nn.Module):
|
101 |
+
""" Splitting image into patches
|
102 |
+
"""
|
103 |
+
|
104 |
+
def __init__(self):
|
105 |
+
super(BlockSplitting, self).__init__()
|
106 |
+
self.k = 8
|
107 |
+
|
108 |
+
def forward(self, image):
|
109 |
+
"""
|
110 |
+
Args:
|
111 |
+
image(tensor): batch x height x width
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
Tensor: batch x h*w/64 x h x w
|
115 |
+
"""
|
116 |
+
height, _ = image.shape[1:3]
|
117 |
+
batch_size = image.shape[0]
|
118 |
+
image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
|
119 |
+
image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
|
120 |
+
return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
|
121 |
+
|
122 |
+
|
123 |
+
class DCT8x8(nn.Module):
|
124 |
+
""" Discrete Cosine Transformation
|
125 |
+
"""
|
126 |
+
|
127 |
+
def __init__(self):
|
128 |
+
super(DCT8x8, self).__init__()
|
129 |
+
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
|
130 |
+
for x, y, u, v in itertools.product(range(8), repeat=4):
|
131 |
+
tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
|
132 |
+
alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
|
133 |
+
self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
|
134 |
+
self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())
|
135 |
+
|
136 |
+
def forward(self, image):
|
137 |
+
"""
|
138 |
+
Args:
|
139 |
+
image(tensor): batch x height x width
|
140 |
+
|
141 |
+
Returns:
|
142 |
+
Tensor: batch x height x width
|
143 |
+
"""
|
144 |
+
image = image - 128
|
145 |
+
result = self.scale * torch.tensordot(image, self.tensor, dims=2)
|
146 |
+
result.view(image.shape)
|
147 |
+
return result
|
148 |
+
|
149 |
+
|
150 |
+
class YQuantize(nn.Module):
|
151 |
+
""" JPEG Quantization for Y channel
|
152 |
+
|
153 |
+
Args:
|
154 |
+
rounding(function): rounding function to use
|
155 |
+
"""
|
156 |
+
|
157 |
+
def __init__(self, rounding):
|
158 |
+
super(YQuantize, self).__init__()
|
159 |
+
self.rounding = rounding
|
160 |
+
self.y_table = y_table
|
161 |
+
|
162 |
+
def forward(self, image, factor=1):
|
163 |
+
"""
|
164 |
+
Args:
|
165 |
+
image(tensor): batch x height x width
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
Tensor: batch x height x width
|
169 |
+
"""
|
170 |
+
if isinstance(factor, (int, float)):
|
171 |
+
image = image.float() / (self.y_table * factor)
|
172 |
+
else:
|
173 |
+
b = factor.size(0)
|
174 |
+
table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
|
175 |
+
image = image.float() / table
|
176 |
+
image = self.rounding(image)
|
177 |
+
return image
|
178 |
+
|
179 |
+
|
180 |
+
class CQuantize(nn.Module):
|
181 |
+
""" JPEG Quantization for CbCr channels
|
182 |
+
|
183 |
+
Args:
|
184 |
+
rounding(function): rounding function to use
|
185 |
+
"""
|
186 |
+
|
187 |
+
def __init__(self, rounding):
|
188 |
+
super(CQuantize, self).__init__()
|
189 |
+
self.rounding = rounding
|
190 |
+
self.c_table = c_table
|
191 |
+
|
192 |
+
def forward(self, image, factor=1):
|
193 |
+
"""
|
194 |
+
Args:
|
195 |
+
image(tensor): batch x height x width
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
Tensor: batch x height x width
|
199 |
+
"""
|
200 |
+
if isinstance(factor, (int, float)):
|
201 |
+
image = image.float() / (self.c_table * factor)
|
202 |
+
else:
|
203 |
+
b = factor.size(0)
|
204 |
+
table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
|
205 |
+
image = image.float() / table
|
206 |
+
image = self.rounding(image)
|
207 |
+
return image
|
208 |
+
|
209 |
+
|
210 |
+
class CompressJpeg(nn.Module):
|
211 |
+
"""Full JPEG compression algorithm
|
212 |
+
|
213 |
+
Args:
|
214 |
+
rounding(function): rounding function to use
|
215 |
+
"""
|
216 |
+
|
217 |
+
def __init__(self, rounding=torch.round):
|
218 |
+
super(CompressJpeg, self).__init__()
|
219 |
+
self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
|
220 |
+
self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
|
221 |
+
self.c_quantize = CQuantize(rounding=rounding)
|
222 |
+
self.y_quantize = YQuantize(rounding=rounding)
|
223 |
+
|
224 |
+
def forward(self, image, factor=1):
|
225 |
+
"""
|
226 |
+
Args:
|
227 |
+
image(tensor): batch x 3 x height x width
|
228 |
+
|
229 |
+
Returns:
|
230 |
+
dict(tensor): Compressed tensor with batch x h*w/64 x 8 x 8.
|
231 |
+
"""
|
232 |
+
y, cb, cr = self.l1(image * 255)
|
233 |
+
components = {'y': y, 'cb': cb, 'cr': cr}
|
234 |
+
for k in components.keys():
|
235 |
+
comp = self.l2(components[k])
|
236 |
+
if k in ('cb', 'cr'):
|
237 |
+
comp = self.c_quantize(comp, factor=factor)
|
238 |
+
else:
|
239 |
+
comp = self.y_quantize(comp, factor=factor)
|
240 |
+
|
241 |
+
components[k] = comp
|
242 |
+
|
243 |
+
return components['y'], components['cb'], components['cr']
|
244 |
+
|
245 |
+
|
246 |
+
# ------------------------ decompression ------------------------#
|
247 |
+
|
248 |
+
|
249 |
+
class YDequantize(nn.Module):
|
250 |
+
"""Dequantize Y channel
|
251 |
+
"""
|
252 |
+
|
253 |
+
def __init__(self):
|
254 |
+
super(YDequantize, self).__init__()
|
255 |
+
self.y_table = y_table
|
256 |
+
|
257 |
+
def forward(self, image, factor=1):
|
258 |
+
"""
|
259 |
+
Args:
|
260 |
+
image(tensor): batch x height x width
|
261 |
+
|
262 |
+
Returns:
|
263 |
+
Tensor: batch x height x width
|
264 |
+
"""
|
265 |
+
if isinstance(factor, (int, float)):
|
266 |
+
out = image * (self.y_table * factor)
|
267 |
+
else:
|
268 |
+
b = factor.size(0)
|
269 |
+
table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
|
270 |
+
out = image * table
|
271 |
+
return out
|
272 |
+
|
273 |
+
|
274 |
+
class CDequantize(nn.Module):
|
275 |
+
"""Dequantize CbCr channel
|
276 |
+
"""
|
277 |
+
|
278 |
+
def __init__(self):
|
279 |
+
super(CDequantize, self).__init__()
|
280 |
+
self.c_table = c_table
|
281 |
+
|
282 |
+
def forward(self, image, factor=1):
|
283 |
+
"""
|
284 |
+
Args:
|
285 |
+
image(tensor): batch x height x width
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
Tensor: batch x height x width
|
289 |
+
"""
|
290 |
+
if isinstance(factor, (int, float)):
|
291 |
+
out = image * (self.c_table * factor)
|
292 |
+
else:
|
293 |
+
b = factor.size(0)
|
294 |
+
table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
|
295 |
+
out = image * table
|
296 |
+
return out
|
297 |
+
|
298 |
+
|
299 |
+
class iDCT8x8(nn.Module):
|
300 |
+
"""Inverse discrete Cosine Transformation
|
301 |
+
"""
|
302 |
+
|
303 |
+
def __init__(self):
|
304 |
+
super(iDCT8x8, self).__init__()
|
305 |
+
alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
|
306 |
+
self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
|
307 |
+
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
|
308 |
+
for x, y, u, v in itertools.product(range(8), repeat=4):
|
309 |
+
tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
|
310 |
+
self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
|
311 |
+
|
312 |
+
def forward(self, image):
|
313 |
+
"""
|
314 |
+
Args:
|
315 |
+
image(tensor): batch x height x width
|
316 |
+
|
317 |
+
Returns:
|
318 |
+
Tensor: batch x height x width
|
319 |
+
"""
|
320 |
+
image = image * self.alpha
|
321 |
+
result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
|
322 |
+
result.view(image.shape)
|
323 |
+
return result
|
324 |
+
|
325 |
+
|
326 |
+
class BlockMerging(nn.Module):
|
327 |
+
"""Merge patches into image
|
328 |
+
"""
|
329 |
+
|
330 |
+
def __init__(self):
|
331 |
+
super(BlockMerging, self).__init__()
|
332 |
+
|
333 |
+
def forward(self, patches, height, width):
|
334 |
+
"""
|
335 |
+
Args:
|
336 |
+
patches(tensor) batch x height*width/64, height x width
|
337 |
+
height(int)
|
338 |
+
width(int)
|
339 |
+
|
340 |
+
Returns:
|
341 |
+
Tensor: batch x height x width
|
342 |
+
"""
|
343 |
+
k = 8
|
344 |
+
batch_size = patches.shape[0]
|
345 |
+
image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
|
346 |
+
image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
|
347 |
+
return image_transposed.contiguous().view(batch_size, height, width)
|
348 |
+
|
349 |
+
|
350 |
+
class ChromaUpsampling(nn.Module):
|
351 |
+
"""Upsample chroma layers
|
352 |
+
"""
|
353 |
+
|
354 |
+
def __init__(self):
|
355 |
+
super(ChromaUpsampling, self).__init__()
|
356 |
+
|
357 |
+
def forward(self, y, cb, cr):
|
358 |
+
"""
|
359 |
+
Args:
|
360 |
+
y(tensor): y channel image
|
361 |
+
cb(tensor): cb channel
|
362 |
+
cr(tensor): cr channel
|
363 |
+
|
364 |
+
Returns:
|
365 |
+
Tensor: batch x height x width x 3
|
366 |
+
"""
|
367 |
+
|
368 |
+
def repeat(x, k=2):
|
369 |
+
height, width = x.shape[1:3]
|
370 |
+
x = x.unsqueeze(-1)
|
371 |
+
x = x.repeat(1, 1, k, k)
|
372 |
+
x = x.view(-1, height * k, width * k)
|
373 |
+
return x
|
374 |
+
|
375 |
+
cb = repeat(cb)
|
376 |
+
cr = repeat(cr)
|
377 |
+
return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
|
378 |
+
|
379 |
+
|
380 |
+
class YCbCr2RGBJpeg(nn.Module):
|
381 |
+
"""Converts YCbCr image to RGB JPEG
|
382 |
+
"""
|
383 |
+
|
384 |
+
def __init__(self):
|
385 |
+
super(YCbCr2RGBJpeg, self).__init__()
|
386 |
+
|
387 |
+
matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
|
388 |
+
self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
|
389 |
+
self.matrix = nn.Parameter(torch.from_numpy(matrix))
|
390 |
+
|
391 |
+
def forward(self, image):
|
392 |
+
"""
|
393 |
+
Args:
|
394 |
+
image(tensor): batch x height x width x 3
|
395 |
+
|
396 |
+
Returns:
|
397 |
+
Tensor: batch x 3 x height x width
|
398 |
+
"""
|
399 |
+
result = torch.tensordot(image + self.shift, self.matrix, dims=1)
|
400 |
+
return result.view(image.shape).permute(0, 3, 1, 2)
|
401 |
+
|
402 |
+
|
403 |
+
class DeCompressJpeg(nn.Module):
|
404 |
+
"""Full JPEG decompression algorithm
|
405 |
+
|
406 |
+
Args:
|
407 |
+
rounding(function): rounding function to use
|
408 |
+
"""
|
409 |
+
|
410 |
+
def __init__(self, rounding=torch.round):
|
411 |
+
super(DeCompressJpeg, self).__init__()
|
412 |
+
self.c_dequantize = CDequantize()
|
413 |
+
self.y_dequantize = YDequantize()
|
414 |
+
self.idct = iDCT8x8()
|
415 |
+
self.merging = BlockMerging()
|
416 |
+
self.chroma = ChromaUpsampling()
|
417 |
+
self.colors = YCbCr2RGBJpeg()
|
418 |
+
|
419 |
+
def forward(self, y, cb, cr, imgh, imgw, factor=1):
|
420 |
+
"""
|
421 |
+
Args:
|
422 |
+
compressed(dict(tensor)): batch x h*w/64 x 8 x 8
|
423 |
+
imgh(int)
|
424 |
+
imgw(int)
|
425 |
+
factor(float)
|
426 |
+
|
427 |
+
Returns:
|
428 |
+
Tensor: batch x 3 x height x width
|
429 |
+
"""
|
430 |
+
components = {'y': y, 'cb': cb, 'cr': cr}
|
431 |
+
for k in components.keys():
|
432 |
+
if k in ('cb', 'cr'):
|
433 |
+
comp = self.c_dequantize(components[k], factor=factor)
|
434 |
+
height, width = int(imgh / 2), int(imgw / 2)
|
435 |
+
else:
|
436 |
+
comp = self.y_dequantize(components[k], factor=factor)
|
437 |
+
height, width = imgh, imgw
|
438 |
+
comp = self.idct(comp)
|
439 |
+
components[k] = self.merging(comp, height, width)
|
440 |
+
#
|
441 |
+
image = self.chroma(components['y'], components['cb'], components['cr'])
|
442 |
+
image = self.colors(image)
|
443 |
+
|
444 |
+
image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
|
445 |
+
return image / 255
|
446 |
+
|
447 |
+
|
448 |
+
# ------------------------ main DiffJPEG ------------------------ #
|
449 |
+
|
450 |
+
|
451 |
+
class DiffJPEG(nn.Module):
|
452 |
+
"""This JPEG algorithm result is slightly different from cv2.
|
453 |
+
DiffJPEG supports batch processing.
|
454 |
+
|
455 |
+
Args:
|
456 |
+
differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round
|
457 |
+
"""
|
458 |
+
|
459 |
+
def __init__(self, differentiable=True):
|
460 |
+
super(DiffJPEG, self).__init__()
|
461 |
+
if differentiable:
|
462 |
+
rounding = diff_round
|
463 |
+
else:
|
464 |
+
rounding = torch.round
|
465 |
+
|
466 |
+
self.compress = CompressJpeg(rounding=rounding)
|
467 |
+
self.decompress = DeCompressJpeg(rounding=rounding)
|
468 |
+
|
469 |
+
def forward(self, x, quality):
|
470 |
+
"""
|
471 |
+
Args:
|
472 |
+
x (Tensor): Input image, bchw, rgb, [0, 1]
|
473 |
+
quality(float): Quality factor for jpeg compression scheme.
|
474 |
+
"""
|
475 |
+
factor = quality
|
476 |
+
if isinstance(factor, (int, float)):
|
477 |
+
factor = quality_to_factor(factor)
|
478 |
+
else:
|
479 |
+
for i in range(factor.size(0)):
|
480 |
+
factor[i] = quality_to_factor(factor[i])
|
481 |
+
h, w = x.size()[-2:]
|
482 |
+
h_pad, w_pad = 0, 0
|
483 |
+
        # pad to a multiple of 16: chroma channels are subsampled by 2 and DCT blocks are 8x8, so each chroma block spans 16 luma pixels
|
484 |
+
if h % 16 != 0:
|
485 |
+
h_pad = 16 - h % 16
|
486 |
+
if w % 16 != 0:
|
487 |
+
w_pad = 16 - w % 16
|
488 |
+
x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)
|
489 |
+
|
490 |
+
y, cb, cr = self.compress(x, factor=factor)
|
491 |
+
recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
|
492 |
+
recovered = recovered[:, :, 0:h, 0:w]
|
493 |
+
return recovered
|
494 |
+
|
495 |
+
|
496 |
+
if __name__ == '__main__':
|
497 |
+
import cv2
|
498 |
+
|
499 |
+
from basicsr.utils import img2tensor, tensor2img
|
500 |
+
|
501 |
+
img_gt = cv2.imread('test.png') / 255.
|
502 |
+
|
503 |
+
# -------------- cv2 -------------- #
|
504 |
+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
|
505 |
+
_, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param)
|
506 |
+
img_lq = np.float32(cv2.imdecode(encimg, 1))
|
507 |
+
cv2.imwrite('cv2_JPEG_20.png', img_lq)
|
508 |
+
|
509 |
+
# -------------- DiffJPEG -------------- #
|
510 |
+
jpeger = DiffJPEG(differentiable=False).cuda()
|
511 |
+
img_gt = img2tensor(img_gt)
|
512 |
+
img_gt = torch.stack([img_gt, img_gt]).cuda()
|
513 |
+
quality = img_gt.new_tensor([20, 40])
|
514 |
+
out = jpeger(img_gt, quality=quality)
|
515 |
+
|
516 |
+
cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0]))
|
517 |
+
cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))
|
degradation/ESR/usm_sharp.py
ADDED
@@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-

import cv2
import numpy as np
import torch
from torch.nn import functional as F

import os, sys
root_path = os.path.abspath('.')
sys.path.append(root_path)
from degradation.ESR.utils import filter2D, np2tensor, tensor2np


def usm_sharp_func(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening.

    Input image: I; Blurry image: B.
    1. sharp = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur the mask to obtain a soft mask.
    4. Out = Mask * sharp + (1 - Mask) * I

    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int): Residual threshold (on the [0, 255] scale); pixels below it are left unsharpened.
    """
    if radius % 2 == 0:
        radius += 1
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype('float32')
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    sharp = img + weight * residual
    sharp = np.clip(sharp, 0, 1)
    return soft_mask * sharp + (1 - soft_mask) * img


class USMSharp(torch.nn.Module):

    def __init__(self, type, radius=50, sigma=0):
        super(USMSharp, self).__init__()
        if radius % 2 == 0:
            radius += 1
        self.radius = radius
        kernel = cv2.getGaussianKernel(radius, sigma)
        kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0).cuda()
        self.register_buffer('kernel', kernel)

        self.type = type

    def forward(self, img, weight=0.5, threshold=10, store=False):

        if self.type == "cv2":
            # pre-process numpy (cv2-style) input into a tensor
            img = np2tensor(img)

        blur = filter2D(img, self.kernel.cuda())
        if store:
            cv2.imwrite("blur.png", tensor2np(blur))

        residual = img - blur
        if store:
            cv2.imwrite("residual.png", tensor2np(residual))

        mask = torch.abs(residual) * 255 > threshold
        if store:
            cv2.imwrite("mask.png", tensor2np(mask))

        mask = mask.float()
        soft_mask = filter2D(mask, self.kernel.cuda())
        if store:
            cv2.imwrite("soft_mask.png", tensor2np(soft_mask))

        sharp = img + weight * residual
        sharp = torch.clip(sharp, 0, 1)
        if store:
            cv2.imwrite("sharp.png", tensor2np(sharp))

        output = soft_mask * sharp + (1 - soft_mask) * img
        if self.type == "cv2":
            output = tensor2np(output)

        return output


if __name__ == "__main__":

    usm_sharper = USMSharp(type="cv2")
    img = cv2.imread("sample3.png")
    print(img.shape)
    sharp_output = usm_sharper(img, store=False, threshold=10)
    cv2.imwrite(os.path.join("output.png"), sharp_output)


    # dir = r"C:\Users\HikariDawn\Desktop\Real-CUGAN\datasets\sample"
    # output_dir = r"C:\Users\HikariDawn\Desktop\Real-CUGAN\datasets\sharp_regular"
    # if not os.path.exists(output_dir):
    #     os.makedirs(output_dir)

    # for file_name in sorted(os.listdir(dir)):
    #     print(file_name)
    #     file = os.path.join(dir, file_name)
    #     img = cv2.imread(file)
    #     sharp_output = usm_sharper(img)
    #     cv2.imwrite(os.path.join(output_dir, file_name), sharp_output)
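Usage note (an illustration, not part of this commit): with type="cv2" the module accepts and returns numpy BGR images, converting through np2tensor/tensor2np internally; any other type string keeps the data as a Bx3xHxW CUDA tensor in [0, 1]. A minimal sketch, assuming degradation/ESR/usm_sharp.py is importable from the repo root and a CUDA device is available:

import cv2
import torch
from degradation.ESR.usm_sharp import USMSharp  # module path assumed

# numpy path: BGR image in, sharpened numpy image out
cv2_sharpener = USMSharp(type="cv2")
img = cv2.imread("sample3.png")
sharp_np = cv2_sharpener(img, weight=0.5, threshold=10)
cv2.imwrite("sample3_sharp.png", sharp_np)

# tensor path: any non-"cv2" type string keeps everything on the GPU
tensor_sharpener = USMSharp(type="tensor")
batch = torch.rand(2, 3, 128, 128, device="cuda")
sharp_t = tensor_sharpener(batch)   # same shape, values clipped to [0, 1]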
degradation/ESR/utils.py
ADDED
@@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-

'''
From ESRGAN
'''


import os, sys
import cv2
import numpy as np
import torch
from torch.nn import functional as F
from scipy import special
import random
import math
from torchvision.utils import make_grid

from degradation.ESR.degradations_functionality import *

root_path = os.path.abspath('.')
sys.path.append(root_path)


def np2tensor(np_frame):
    return torch.from_numpy(np.transpose(np_frame, (2, 0, 1))).unsqueeze(0).cuda().float()/255

def tensor2np(tensor):
    # the tensor should have batch size 1 and cannot be a grayscale input
    return (np.transpose(tensor.detach().squeeze(0).cpu().numpy(), (1, 2, 0))) * 255

def mass_tensor2np(tensor):
    ''' The input is a batched (massive) tensor
    '''
    return (np.transpose(tensor.detach().squeeze(0).cpu().numpy(), (0, 2, 3, 1))) * 255

def save_img(tensor, save_name):
    np_img = tensor2np(tensor)[:,:,16]
    # np_img = np.expand_dims(np_img, axis=2)
    cv2.imwrite(save_name, np_img)


def filter2D(img, kernel):
    """PyTorch version of cv2.filter2D

    Args:
        img (Tensor): (b, c, h, w)
        kernel (Tensor): (b, k, k)
    """
    k = kernel.size(-1)
    b, c, h, w = img.size()
    if k % 2 == 1:
        img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
    else:
        raise ValueError('Wrong kernel size')

    ph, pw = img.size()[-2:]

    if kernel.size(0) == 1:
        # apply the same kernel to all batch images
        img = img.view(b * c, 1, ph, pw)
        kernel = kernel.view(1, 1, k, k)
        return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
    else:
        img = img.view(1, b * c, ph, pw)
        kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
        return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)


def generate_kernels(opt):

    kernel_range = [2 * v + 1 for v in range(opt["kernel_range"][0], opt["kernel_range"][1])]

    # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
    kernel_size = random.choice(kernel_range)
    if np.random.uniform() < opt['sinc_prob']:
        # insert a sinc filter here, with probability sinc_prob (around 10%)
        # this sinc filter setting is for kernels ranging from [7, 21]
        if kernel_size < 13:
            omega_c = np.random.uniform(np.pi / 3, np.pi)
        else:
            omega_c = np.random.uniform(np.pi / 5, np.pi)
        kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
    else:
        kernel = random_mixed_kernels(
            opt['kernel_list'],
            opt['kernel_prob'],
            kernel_size,
            opt['blur_sigma'],
            opt['blur_sigma'], [-math.pi, math.pi],
            opt['betag_range'],
            opt['betap_range'],
            noise_range=None)
    # pad kernel to 21x21 (in v2 this padding is simply omitted)
    pad_size = (21 - kernel_size) // 2
    kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))


    # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
    kernel_size = random.choice(kernel_range)
    if np.random.uniform() < opt['sinc_prob2']:
        # insert a sinc filter here, with probability sinc_prob2 (around 10%)
        if kernel_size < 13:
            omega_c = np.random.uniform(np.pi / 3, np.pi)
        else:
            omega_c = np.random.uniform(np.pi / 5, np.pi)
        kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
    else:
        kernel2 = random_mixed_kernels(
            opt['kernel_list2'],
            opt['kernel_prob2'],
            kernel_size,
            opt['blur_sigma2'],
            opt['blur_sigma2'], [-math.pi, math.pi],
            opt['betag_range2'],
            opt['betap_range2'],
            noise_range=None)

    # pad kernel
    pad_size = (21 - kernel_size) // 2
    kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

    kernel = torch.FloatTensor(kernel)
    kernel2 = torch.FloatTensor(kernel2)
    return (kernel, kernel2)
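For reference (not in the diff): generate_kernels only needs a dict containing the keys it reads, and returns two kernels padded to 21x21 that can be fed straight into filter2D. The sketch below uses placeholder ranges, not the project's actual opt.py values:

import torch
from degradation.ESR.utils import generate_kernels, filter2D  # module path assumed

# Illustrative kernel-generation config; the real values live in opt.py
kernel_opt = {
    'kernel_range': [3, 11],            # odd sizes 7..21 via 2*v+1
    'sinc_prob': 0.1,
    'kernel_list': ['iso', 'aniso'],
    'kernel_prob': [0.5, 0.5],
    'blur_sigma': [0.2, 3.0],
    'betag_range': [0.5, 4.0],
    'betap_range': [1.0, 2.0],
    'sinc_prob2': 0.1,
    'kernel_list2': ['iso', 'aniso'],
    'kernel_prob2': [0.5, 0.5],
    'blur_sigma2': [0.2, 1.5],
    'betag_range2': [0.5, 4.0],
    'betap_range2': [1.0, 2.0],
}

kernel1, kernel2 = generate_kernels(kernel_opt)          # both padded to 21x21
imgs = torch.rand(2, 3, 64, 64).cuda()
blurred = filter2D(imgs, kernel1.unsqueeze(0).cuda())    # one shared kernel for the whole batch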
degradation/degradation_esr.py
ADDED
@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
import torch
import os
import sys
import random
import torch.nn.functional as F

root_path = os.path.abspath('.')
sys.path.append(root_path)
# Import files from the local folder
from opt import opt
from degradation.ESR.utils import generate_kernels, mass_tensor2np, tensor2np
from degradation.ESR.degradations_functionality import *
from degradation.ESR.degradation_esr_shared import common_degradation as regular_common_degradation
from degradation.image_compression.jpeg import JPEG # TODO: it would be cleaner to handle all of these codecs through a shared base class
from degradation.image_compression.webp import WEBP
from degradation.image_compression.heif import HEIF
from degradation.image_compression.avif import AVIF
from degradation.video_compression.h264 import H264
from degradation.video_compression.h265 import H265
from degradation.video_compression.mpeg2 import MPEG2
from degradation.video_compression.mpeg4 import MPEG4


class degradation_v1:
    def __init__(self):
        self.kernel1, self.kernel2, self.sinc_kernel = None, None, None
        self.queue_size = 160

        # Init the compression instance
        self.jpeg_instance = JPEG()
        self.webp_instance = WEBP()
        # self.heif_instance = HEIF()
        self.avif_instance = AVIF()
        self.H264_instance = H264()
        self.H265_instance = H265()
        self.MPEG2_instance = MPEG2()
        self.MPEG4_instance = MPEG4()


    def reset_kernels(self, opt):
        kernel1, kernel2 = generate_kernels(opt)
        self.kernel1 = kernel1.unsqueeze(0).cuda()
        self.kernel2 = kernel2.unsqueeze(0).cuda()


    @torch.no_grad()
    def degradate_process(self, out, opt, store_path, process_id, verbose = False):
        ''' ESR Degradation V1 mode (same as the original paper)
        Args:
            out (tensor): BxCxHxW All input images as tensor
            opt (dict): All configuration we need to process
            store_path (str): Store Directory
            process_id (int): The id we used to store temporary file
            verbose (bool): Whether print some information for auxiliary log (default: False)
        '''

        batch_size, _, ori_h, ori_w = out.size()

        # Shared degradation until the last step
        resize_mode = random.choice(opt['resize_options'])
        out = regular_common_degradation(out, opt, [self.kernel1, self.kernel2], process_id, verbose=verbose)


        # Resize back
        out = F.interpolate(out, size=(ori_h // opt['scale'], ori_w // opt['scale']), mode = resize_mode)
        out = torch.clamp(out, 0, 1)
        # TODO: the tensor-to-numpy conversion could be moved earlier and done in one pass to save time

        # Tensor2np
        np_frame = tensor2np(out)

        # Choose an image compression codec (the whole degradation batch uses the same codec)
        compression_codec = random.choices(opt['compression_codec2'], opt['compression_codec_prob2'])[0] # All lower case

        if compression_codec == "jpeg":
            self.jpeg_instance.compress_and_store(np_frame, store_path, process_id)

        elif compression_codec == "webp":
            try:
                self.webp_instance.compress_and_store(np_frame, store_path, process_id)
            except Exception:
                print("There appears to be an exception in webp again!")
                if os.path.exists(store_path):
                    os.remove(store_path)
                self.webp_instance.compress_and_store(np_frame, store_path, process_id)

        elif compression_codec == "avif":
            self.avif_instance.compress_and_store(np_frame, store_path, process_id)

        elif compression_codec == "h264":
            self.H264_instance.compress_and_store(np_frame, store_path, process_id)

        elif compression_codec == "h265":
            self.H265_instance.compress_and_store(np_frame, store_path, process_id)

        elif compression_codec == "mpeg2":
            self.MPEG2_instance.compress_and_store(np_frame, store_path, process_id)

        elif compression_codec == "mpeg4":
            self.MPEG4_instance.compress_and_store(np_frame, store_path, process_id)

        else:
            raise NotImplementedError("This compression codec is not supported! Please check the implementation!")
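A minimal driver sketch (illustrative only, the real caller lives elsewhere in this repo): kernels are resampled per batch with reset_kernels, then degradate_process blurs, noises and downscales the HR crop and writes the compressed LR frame to store_path. Note that the internal tensor2np call assumes a batch of one; the store path below is hypothetical.

import torch

from opt import opt                                    # the project's global config dict
from degradation.degradation_esr import degradation_v1

degrader = degradation_v1()

# One HR crop at a time, values in [0, 1]
hr_crop = torch.rand(1, 3, 256, 256).cuda()

degrader.reset_kernels(opt)                            # sample fresh blur kernels for this batch
degrader.degradate_process(hr_crop, opt,
                           store_path="tmp/lr_sample.png",
                           process_id=0, verbose=False)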
degradation/image_compression/avif.py
ADDED
@@ -0,0 +1,88 @@
import torch, sys, os, random
import torch.nn.functional as F
import numpy as np
import cv2
from multiprocessing import Process, Queue
from PIL import Image
import pillow_heif

root_path = os.path.abspath('.')
sys.path.append(root_path)
# Import files from the local folder
from opt import opt
from degradation.ESR.utils import tensor2np, np2tensor



class AVIF():
    def __init__(self) -> None:
        # Choose an image compression degradation
        pass

    def compress_and_store(self, np_frames, store_path, idx):
        ''' Compress and Store the whole batch as AVIF (~ AV1)
        Args:
            np_frames (numpy): The numpy format of the data (Shape:?)
            store_path (str): The store path
        Return:
            None
        '''
        # Init call for avif
        pillow_heif.register_avif_opener()

        single_frame = np_frames

        # Prepare
        essential_name = "tmp/temp_"+str(idx)

        # Choose the quality
        quality = random.randint(*opt['avif_quality_range2'])
        method = random.randint(*opt['avif_encode_speed2'])

        # Transform to PIL and then compress
        PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB')
        PIL_image.save(essential_name+'.avif', quality=quality, method=method)

        # Read as png
        avif_file = pillow_heif.open_heif(essential_name+'.avif', convert_hdr_to_8bit=False, bgr_mode=True)
        np_array = np.asarray(avif_file)
        cv2.imwrite(store_path, np_array)

        os.remove(essential_name+'.avif')


    @staticmethod
    def compress_tensor(tensor_frames, idx=0):
        ''' Compress tensor input to AVIF and then return it
        Args:
            tensor_frame (tensor): Tensor inputs
        Returns:
            result (tensor): Tensor outputs (same shape as input)
        '''
        # Init call for avif
        pillow_heif.register_avif_opener()

        # Prepare
        single_frame = tensor2np(tensor_frames)
        essential_name = "tmp/temp_"+str(idx)

        # Choose the quality
        quality = random.randint(*opt['avif_quality_range1'])
        method = random.randint(*opt['avif_encode_speed1'])

        # Transform to PIL and then compress
        PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB')
        PIL_image.save(essential_name+'.avif', quality=quality, method=method)

        # Transform as png format
        avif_file = pillow_heif.open_heif(essential_name+'.avif', convert_hdr_to_8bit=False, bgr_mode=True)
        decimg = np.asarray(avif_file)
        os.remove(essential_name+'.avif')

        # Read back
        result = np2tensor(decimg)

        return result
degradation/image_compression/heif.py
ADDED
@@ -0,0 +1,90 @@
import torch, sys, os, random
import torch.nn.functional as F
import numpy as np
import cv2
from multiprocessing import Process, Queue
from PIL import Image
from pillow_heif import register_heif_opener
import pillow_heif

root_path = os.path.abspath('.')
sys.path.append(root_path)
# Import files from the local folder
from opt import opt
from degradation.ESR.utils import tensor2np, np2tensor



class HEIF():
    def __init__(self) -> None:
        # Choose an image compression degradation
        pass

    def compress_and_store(self, np_frames, store_path):
        ''' Compress and Store the whole batch as HEIF (~ HEVC)
        Args:
            np_frames (numpy): The numpy format of the data (Shape:?)
            store_path (str): The store path
        Return:
            None
        '''
        # Init call for heif
        register_heif_opener()

        single_frame = np_frames

        # Prepare
        essential_name = store_path.split('.')[0]

        # Choose the quality
        quality = random.randint(*opt['heif_quality_range1'])
        method = random.randint(*opt['heif_encode_speed1'])

        # Transform to PIL and then compress
        PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB')
        PIL_image.save(essential_name+'.heic', quality=quality, method=method)

        # Transform as png format
        heif_file = pillow_heif.open_heif(essential_name+'.heic', convert_hdr_to_8bit=False, bgr_mode=True)
        np_array = np.asarray(heif_file)
        cv2.imwrite(store_path, np_array)

        os.remove(essential_name+'.heic')


    @staticmethod
    def compress_tensor(tensor_frames, idx=0):
        ''' Compress tensor input to HEIF and then return it
        Args:
            tensor_frame (tensor): Tensor inputs
        Returns:
            result (tensor): Tensor outputs (same shape as input)
        '''

        # Init call for heif
        register_heif_opener()

        # Prepare
        single_frame = tensor2np(tensor_frames)
        essential_name = "tmp/temp_"+str(idx)

        # Choose the quality
        quality = random.randint(*opt['heif_quality_range1'])
        method = random.randint(*opt['heif_encode_speed1'])

        # Transform to PIL and then compress
        PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB')
        PIL_image.save(essential_name+'.heic', quality=quality, method=method)

        # Transform as png format
        heif_file = pillow_heif.open_heif(essential_name+'.heic', convert_hdr_to_8bit=False, bgr_mode=True)
        decimg = np.asarray(heif_file)
        os.remove(essential_name+'.heic')

        # Read back
        result = np2tensor(decimg)

        return result
degradation/image_compression/jpeg.py
ADDED
@@ -0,0 +1,68 @@
import sys, os, random
import cv2, torch
from multiprocessing import Process, Queue

root_path = os.path.abspath('.')
sys.path.append(root_path)
# Import files from the local folder
from opt import opt
from degradation.ESR.utils import tensor2np, np2tensor



class JPEG():
    def __init__(self) -> None:
        # Choose an image compression degradation
        # self.jpeger = DiffJPEG(differentiable=False).cuda()
        pass

    def compress_and_store(self, np_frames, store_path, idx):
        ''' Compress and Store the whole batch as JPEG
        Args:
            np_frames (numpy): The numpy format of the data (Shape:?)
            store_path (str): The store path
        Return:
            None
        '''

        # Preparation
        single_frame = np_frames

        # Compress as JPEG
        jpeg_quality = random.randint(*opt['jpeg_quality_range2'])

        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality]
        _, encimg = cv2.imencode('.jpg', single_frame, encode_param)
        decimg = cv2.imdecode(encimg, 1)

        # Store the image with quality
        cv2.imwrite(store_path, decimg)


    @staticmethod
    def compress_tensor(tensor_frames):
        ''' Compress tensor input to JPEG and then return it
        Args:
            tensor_frame (tensor): Tensor inputs
        Returns:
            result (tensor): Tensor outputs (same shape as input)
        '''

        single_frame = tensor2np(tensor_frames)

        # Compress as JPEG
        jpeg_quality = random.randint(*opt['jpeg_quality_range1'])

        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality]
        _, encimg = cv2.imencode('.jpg', single_frame, encode_param)
        decimg = cv2.imdecode(encimg, 1)

        # Store the image with quality
        # cv2.imwrite(store_name, decimg)
        result = np2tensor(decimg)

        return result
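Quick illustration (not part of the commit): the image codecs in this folder expose the same pair of entry points, compress_tensor for the in-memory first-stage degradation and compress_and_store for writing the final LR frame to disk, with quality drawn from the corresponding opt ranges (jpeg_quality_range1/2 here). A hedged sketch with the JPEG class; the output path is illustrative:

import torch

from degradation.ESR.utils import tensor2np
from degradation.image_compression.jpeg import JPEG

lr = torch.rand(1, 3, 64, 64).cuda()        # batch-of-one tensor in [0, 1]

# First stage: in-memory round trip, quality sampled from opt['jpeg_quality_range1']
lr_jpeg = JPEG.compress_tensor(lr)          # same shape, now carrying JPEG artifacts

# Second stage: compress and write the frame to disk, quality from opt['jpeg_quality_range2']
JPEG().compress_and_store(tensor2np(lr), "tmp/lr_sample.png", 0)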
degradation/image_compression/webp.py
ADDED
@@ -0,0 +1,65 @@
import torch, sys, os, random
import torch.nn.functional as F
import numpy as np
import cv2
from multiprocessing import Process, Queue
from PIL import Image

root_path = os.path.abspath('.')
sys.path.append(root_path)
# Import files from the local folder
from opt import opt
from degradation.ESR.utils import tensor2np, np2tensor



class WEBP():
    def __init__(self) -> None:
        # Choose an image compression degradation
        pass

    def compress_and_store(self, np_frames, store_path, idx):
        ''' Compress and Store the whole batch as WebP (~ VP8)
        Args:
            np_frames (numpy): The numpy format of the data (Shape:?)
            store_path (str): The store path
        Return:
            None
        '''
        single_frame = np_frames

        # Choose the quality
        quality = random.randint(*opt['webp_quality_range2'])
        method = random.randint(*opt['webp_encode_speed2'])

        # Transform to PIL and then compress
        PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB')
        PIL_image.save(store_path, 'webp', quality=quality, method=method)


    @staticmethod
    def compress_tensor(tensor_frames, idx = 0):
        ''' Compress tensor input to WEBP and then return it
        Args:
            tensor_frame (tensor): Tensor inputs
        Returns:
            result (tensor): Tensor outputs (same shape as input)
        '''
        single_frame = tensor2np(tensor_frames)

        # Choose the quality
        quality = random.randint(*opt['webp_quality_range1'])
        method = random.randint(*opt['webp_encode_speed1'])

        # Transform to PIL and then compress
        PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB')
        store_path = os.path.join("tmp", "temp_"+str(idx)+".webp")
        PIL_image.save(store_path, 'webp', quality=quality, method=method)

        # Read back
        decimg = cv2.imread(store_path)
        result = np2tensor(decimg)
        os.remove(store_path)

        return result
degradation/video_compression/h264.py
ADDED
@@ -0,0 +1,73 @@
import torch, sys, os, random
import cv2
import shutil

root_path = os.path.abspath('.')
sys.path.append(root_path)
# Import files from the local folder
from opt import opt



class H264():
    def __init__(self) -> None:
        # Choose an image compression degradation
        pass

    def compress_and_store(self, single_frame, store_path, idx):
        ''' Compress and Store the whole batch as H.264 (for 2nd stage)
        Args:
            single_frame (numpy): The numpy format of the data (Shape:?)
            store_path (str): The store path
            idx (int): A unique process idx
        Return:
            None
        '''

        # Prepare
        temp_input_path = "tmp/input_"+str(idx)
        video_store_dir = "tmp/encoded_"+str(idx)+".mp4"
        temp_store_path = "tmp/output_"+str(idx)
        os.makedirs(temp_input_path)
        os.makedirs(temp_store_path)

        # Move frame
        cv2.imwrite(os.path.join(temp_input_path, "1.png"), single_frame)


        # Decide the quality
        crf = str(random.randint(*opt['h264_crf_range2']))
        preset = random.choices(opt['h264_preset_mode2'], opt['h264_preset_prob2'])[0]

        # Encode
        ffmpeg_encode_cmd = "ffmpeg -i " + temp_input_path + "/%d.png -vcodec libx264 -crf " + crf + " -preset " + preset + " -pix_fmt yuv420p " + video_store_dir + " -loglevel 0"
        os.system(ffmpeg_encode_cmd)


        # Decode
        ffmpeg_decode_cmd = "ffmpeg -i " + video_store_dir + " " + temp_store_path + "/%d.png -loglevel 0"
        os.system(ffmpeg_decode_cmd)
        if len(os.listdir(temp_store_path)) != 1:
            print("This is strange")
        assert(len(os.listdir(temp_store_path)) == 1)

        # Move frame to the target places
        shutil.copy(os.path.join(temp_store_path, "1.png"), store_path)

        # Clean temp files
        os.remove(video_store_dir)
        shutil.rmtree(temp_input_path)
        shutil.rmtree(temp_store_path)


    @staticmethod
    def compress_tensor(tensor_frames, idx=0):
        ''' Compress tensor input to H.264 and then return it (for 1st stage)
        Args:
            tensor_frame (tensor): Tensor inputs
        Returns:
            result (tensor): Tensor outputs (same shape as input)
        '''

        pass
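compress_tensor is left as a stub in this commit. One possible way to fill it in (an assumption, not the author's implementation) is to reuse the same ffmpeg x264 round trip that compress_and_store already performs on disk:

# Hypothetical sketch only; not part of this commit.
import os
import cv2

from degradation.ESR.utils import tensor2np, np2tensor
from degradation.video_compression.h264 import H264


def h264_round_trip(tensor_frame, idx=0):
    ''' Push a 1x3xHxW tensor (even H/W, as required by yuv420p) through the x264 encode/decode above. '''
    np_frame = tensor2np(tensor_frame)                   # single frame, values in [0, 255]
    out_png = "tmp/h264_rt_" + str(idx) + ".png"
    H264().compress_and_store(np_frame, out_png, idx)    # reuse the existing disk-based path
    result = np2tensor(cv2.imread(out_png))
    os.remove(out_png)
    return result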