0x90e committed on
Commit 0d5c9d2 · Parent: 0747f10

Upscaling works now

Files changed (7)
  1. README.md +174 -10
  2. RRDBNet_arch.py +0 -78
  3. architecture.py +38 -0
  4. block.py +261 -0
  5. models/README.md +0 -9
  6. test.py +21 -25
  7. transer_RRDB_models.py +0 -55
README.md CHANGED
@@ -1,10 +1,174 @@
- ---
- title: ESRGAN MANGA
- emoji: 🏃
- colorFrom: red
- colorTo: indigo
- sdk: gradio
- sdk_version: 3.12.0
- app_file: app.py
- pinned: false
- ---
+ # ESRGAN (Enhanced SRGAN) [[Paper]](https://arxiv.org/abs/1809.00219) [[BasicSR]](https://github.com/xinntao/BasicSR)
+ ## :smiley: Training codes are in the [BasicSR](https://github.com/xinntao/BasicSR) repo.
+ ### Enhanced Super-Resolution Generative Adversarial Networks
+ By Xintao Wang, [Ke Yu](https://yuke93.github.io/), Shixiang Wu, [Jinjin Gu](http://www.jasongt.com/), Yihao Liu, [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ&hl=en), [Yu Qiao](http://mmlab.siat.ac.cn/yuqiao/), [Chen Change Loy](http://personal.ie.cuhk.edu.hk/~ccloy/)
+
+ This repo only provides simple testing code, pretrained models, and the network interpolation demo.
+
+ ### **For full training and testing code, please refer to [BasicSR](https://github.com/xinntao/BasicSR).**
+
+ We won first place in the [PIRM2018-SR competition](https://www.pirm2018.org/PIRM-SR.html) (region 3) with the best perceptual index.
+ The paper was accepted to the [ECCV2018 PIRM Workshop](https://pirm2018.org/).
+
+ :triangular_flag_on_post: Added [Frequently Asked Questions](https://github.com/xinntao/ESRGAN/blob/master/QA.md).
+
+ > For instance,
+ > 1. How to reproduce your results in the PIRM18-SR Challenge (with a low perceptual index)?
+ > 2. How do you get the perceptual index in your ESRGAN paper?
+
+ #### BibTeX
+ <!--
+     @article{wang2018esrgan,
+         author={Wang, Xintao and Yu, Ke and Wu, Shixiang and Gu, Jinjin and Liu, Yihao and Dong, Chao and Loy, Chen Change and Qiao, Yu and Tang, Xiaoou},
+         title={ESRGAN: Enhanced super-resolution generative adversarial networks},
+         journal={arXiv preprint arXiv:1809.00219},
+         year={2018}
+     }
+ -->
+     @InProceedings{wang2018esrgan,
+         author = {Wang, Xintao and Yu, Ke and Wu, Shixiang and Gu, Jinjin and Liu, Yihao and Dong, Chao and Qiao, Yu and Loy, Chen Change},
+         title = {ESRGAN: Enhanced super-resolution generative adversarial networks},
+         booktitle = {The European Conference on Computer Vision Workshops (ECCVW)},
+         month = {September},
+         year = {2018}
+     }
+
+ <p align="center">
+   <img src="figures/baboon.jpg">
+ </p>
+
+ The **RRDB_PSNR** PSNR-oriented model trained with the DF2K dataset (a merged dataset of [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) and [Flickr2K](http://cv.snu.ac.kr/research/EDSR/Flickr2K.tar), proposed in [EDSR](https://github.com/LimBee/NTIRE2017)) is also able to achieve high PSNR performance.
+
+ | <sub>Method</sub> | <sub>Training dataset</sub> | <sub>Set5</sub> | <sub>Set14</sub> | <sub>BSD100</sub> | <sub>Urban100</sub> | <sub>Manga109</sub> |
+ |:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+ | <sub>[SRCNN](http://mmlab.ie.cuhk.edu.hk/projects/SRCNN.html)</sub> | <sub>291</sub> | <sub>30.48/0.8628</sub> | <sub>27.50/0.7513</sub> | <sub>26.90/0.7101</sub> | <sub>24.52/0.7221</sub> | <sub>27.58/0.8555</sub> |
+ | <sub>[EDSR](https://github.com/thstkdgus35/EDSR-PyTorch)</sub> | <sub>DIV2K</sub> | <sub>32.46/0.8968</sub> | <sub>28.80/0.7876</sub> | <sub>27.71/0.7420</sub> | <sub>26.64/0.8033</sub> | <sub>31.02/0.9148</sub> |
+ | <sub>[RCAN](https://github.com/yulunzhang/RCAN)</sub> | <sub>DIV2K</sub> | <sub>32.63/0.9002</sub> | <sub>28.87/0.7889</sub> | <sub>27.77/0.7436</sub> | <sub>26.82/0.8087</sub> | <sub>31.22/0.9173</sub> |
+ | <sub>RRDB (ours)</sub> | <sub>DF2K</sub> | <sub>**32.73/0.9011**</sub> | <sub>**28.99/0.7917**</sub> | <sub>**27.85/0.7455**</sub> | <sub>**27.03/0.8153**</sub> | <sub>**31.66/0.9196**</sub> |
+
49
+ ## Quick Test
50
+ #### Dependencies
51
+ - Python 3
52
+ - [PyTorch >= 0.4](https://pytorch.org/) (CUDA version >= 7.5 if installing with CUDA. [More details](https://pytorch.org/get-started/previous-versions/))
53
+ - Python packages: `pip install numpy opencv-python`
54
+
55
+ ### Test models
56
+ 1. Clone this github repo.
57
+ ```
58
+ git clone https://github.com/xinntao/ESRGAN
59
+ cd ESRGAN
60
+ ```
61
+ 2. Place your own **low-resolution images** in `./LR` folder. (There are two sample images - baboon and comic).
62
+ 3. Download pretrained models from [Google Drive](https://drive.google.com/drive/u/0/folders/17VYV_SoZZesU6mbxz2dMAIccSSlqLecY) or [Baidu Drive](https://pan.baidu.com/s/1-Lh6ma-wXzfH8NqeBtPaFQ). Place the models in `./models`. We provide two models with high perceptual quality and high PSNR performance (see [model list](https://github.com/xinntao/ESRGAN/tree/master/models)).
63
+ 4. Run test. We provide ESRGAN model and RRDB_PSNR model.
64
+ ```
65
+ python test.py models/RRDB_ESRGAN_x4.pth
66
+ python test.py models/RRDB_PSNR_x4.pth
67
+ ```
68
+ 5. The results are in `./results` folder.
69
+ ### Network interpolation demo
70
+ You can interpolate the RRDB_ESRGAN and RRDB_PSNR models with alpha in [0, 1].
71
+
72
+ 1. Run `python net_interp.py 0.8`, where *0.8* is the interpolation parameter and you can change it to any value in [0,1].
73
+ 2. Run `python test.py models/interp_08.pth`, where *models/interp_08.pth* is the model path.
74
+
75
+ <p align="center">
76
+ <img height="400" src="figures/43074.gif">
77
+ </p>
78
+
79
+ ## Perceptual-driven SR Results
80
+
81
+ You can download all the resutls from [Google Drive](https://drive.google.com/drive/folders/1iaM-c6EgT1FNoJAOKmDrK7YhEhtlKcLx?usp=sharing). (:heavy_check_mark: included; :heavy_minus_sign: not included; :o: TODO)
82
+
83
+ HR images can be downloaed from [BasicSR-Datasets](https://github.com/xinntao/BasicSR#datasets).
84
+
85
+ | Datasets |LR | [*ESRGAN*](https://arxiv.org/abs/1809.00219) | [SRGAN](https://arxiv.org/abs/1609.04802) | [EnhanceNet](http://openaccess.thecvf.com/content_ICCV_2017/papers/Sajjadi_EnhanceNet_Single_Image_ICCV_2017_paper.pdf) | [CX](https://arxiv.org/abs/1803.04626) |
86
+ |:---:|:---:|:---:|:---:|:---:|:---:|
87
+ | Set5 |:heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:| :o: |
88
+ | Set14 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:| :o: |
89
+ | BSDS100 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:| :o: |
90
+ | [PIRM](https://pirm.github.io/) <br><sup>(val, test)</sup> | :heavy_check_mark: | :heavy_check_mark: | :heavy_minus_sign: | :heavy_check_mark:| :heavy_check_mark: |
91
+ | [OST300](https://arxiv.org/pdf/1804.02815.pdf) |:heavy_check_mark: | :heavy_check_mark: | :heavy_minus_sign: | :heavy_check_mark:| :o: |
92
+ | urban100 | :heavy_check_mark: | :heavy_check_mark: | :heavy_minus_sign: | :heavy_check_mark:| :o: |
93
+ | [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) <br><sup>(val, test)</sup> | :heavy_check_mark: | :heavy_check_mark: | :heavy_minus_sign: | :heavy_check_mark:| :o: |
94
+
95
+ ## ESRGAN
96
+ We improve the [SRGAN](https://arxiv.org/abs/1609.04802) from three aspects:
97
+ 1. adopt a deeper model using Residual-in-Residual Dense Block (RRDB) without batch normalization layers.
98
+ 2. employ [Relativistic average GAN](https://ajolicoeur.wordpress.com/relativisticgan/) instead of the vanilla GAN.
99
+ 3. improve the perceptual loss by using the features before activation.
100
+
101
+ In contrast to SRGAN, which claimed that **deeper models are increasingly difficult to train**, our deeper ESRGAN model shows its superior performance with easy training.
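On aspect 2: the relativistic average discriminator predicts whether a real image is relatively more realistic than the generated ones on average, rather than judging real vs. fake in isolation. A hedged sketch of the losses (the actual training code lives in BasicSR; `real_logits` and `fake_logits` are assumed here to be raw, pre-sigmoid outputs of the discriminator):

```
import torch
import torch.nn.functional as F

def ragan_d_loss(real_logits, fake_logits):
    # D_Ra(x_r, x_f) = sigmoid(C(x_r) - E[C(x_f)]): real should score above the fake mean
    loss_real = F.binary_cross_entropy_with_logits(
        real_logits - fake_logits.mean(), torch.ones_like(real_logits))
    loss_fake = F.binary_cross_entropy_with_logits(
        fake_logits - real_logits.mean(), torch.zeros_like(fake_logits))
    return (loss_real + loss_fake) / 2

def ragan_g_loss(real_logits, fake_logits):
    # symmetric form: push fakes above the real mean and reals below the fake mean
    loss_real = F.binary_cross_entropy_with_logits(
        real_logits - fake_logits.mean(), torch.zeros_like(real_logits))
    loss_fake = F.binary_cross_entropy_with_logits(
        fake_logits - real_logits.mean(), torch.ones_like(fake_logits))
    return (loss_real + loss_fake) / 2
```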
+
+ <p align="center">
+   <img height="120" src="figures/architecture.jpg">
+ </p>
+ <p align="center">
+   <img height="180" src="figures/RRDB.png">
+ </p>
+
+ ## Network Interpolation
+ We propose the **network interpolation strategy** to balance visual quality and PSNR.
+
+ <p align="center">
+   <img height="500" src="figures/net_interp.jpg">
+ </p>
+
+ We show a smooth animation with the interpolation parameter changing from 0 to 1.
+ Interestingly, the network interpolation strategy provides smooth control between the RRDB_PSNR model and the fine-tuned ESRGAN model.
+
+ <p align="center">
+   <img height="480" src="figures/81.gif">
+   &nbsp;&nbsp;
+   <img height="480" src="figures/102061.gif">
+ </p>
+
+ ## Qualitative Results
+ PSNR (evaluated on the Y channel) and the perceptual index used in the PIRM-SR challenge are also provided for reference.
+
+ <p align="center">
+   <img src="figures/qualitative_cmp_01.jpg">
+ </p>
+ <p align="center">
+   <img src="figures/qualitative_cmp_02.jpg">
+ </p>
+ <p align="center">
+   <img src="figures/qualitative_cmp_03.jpg">
+ </p>
+ <p align="center">
+   <img src="figures/qualitative_cmp_04.jpg">
+ </p>
+
+ ## Ablation Study
+ Overall visual comparisons showing the effects of each component in ESRGAN. Each column represents a model, with its configuration at the top. The red sign indicates the main improvement compared with the previous model.
+ <p align="center">
+   <img src="figures/abalation_study.png">
+ </p>
+
+ ## BN artifacts
+ We empirically observe that BN layers tend to introduce artifacts. These artifacts, namely BN artifacts, occasionally appear across iterations and settings, violating the need for stable performance during training. We find that the network depth, BN position, training dataset and training loss all influence the occurrence of BN artifacts.
+ <p align="center">
+   <img src="figures/BN_artifacts.jpg">
+ </p>
+
+ ## Useful techniques to train a very deep network
+ We find that residual scaling and smaller initialization can help to train a very deep network; a minimal sketch follows the figures below. More details are in the supplementary file of our [paper](https://arxiv.org/abs/1809.00219).
+
+ <p align="center">
+   <img height="250" src="figures/train_deeper_neta.png">
+   <img height="250" src="figures/train_deeper_netb.png">
+ </p>
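A minimal sketch of the two tricks combined (illustrative only: the `ScaledResidualBlock` name and the ×0.1 rescale factor are assumptions for the example; the 0.2 residual scale mirrors the `mul(0.2)` used by the RRDB blocks in this commit's `block.py`):

```
import torch.nn as nn

class ScaledResidualBlock(nn.Module):
    def __init__(self, nf=64, res_scale=0.2):
        super(ScaledResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)
        self.res_scale = res_scale
        for m in (self.conv1, self.conv2):
            nn.init.kaiming_normal_(m.weight)
            m.weight.data *= 0.1   # smaller initialization
            nn.init.zeros_(m.bias)

    def forward(self, x):
        # scale the residual branch before the skip addition
        return x + self.conv2(self.lrelu(self.conv1(x))).mul(self.res_scale)
```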
+
+ ## The influence of training patch size
+ We observe that training a deeper network benefits from a larger patch size. Moreover, the deeper model achieves more improvement (∼0.12 dB) than the shallower one (∼0.04 dB), since larger model capacity is capable of taking full advantage of a larger training patch size. (Evaluated on the Set5 dataset with RGB channels.)
+ <p align="center">
+   <img height="250" src="figures/patch_a.png">
+   <img height="250" src="figures/patch_b.png">
+ </p>
RRDBNet_arch.py DELETED
@@ -1,78 +0,0 @@
- import functools
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
-
- def make_layer(block, n_layers):
-     layers = []
-     for _ in range(n_layers):
-         layers.append(block())
-     return nn.Sequential(*layers)
-
-
- class ResidualDenseBlock_5C(nn.Module):
-     def __init__(self, nf=64, gc=32, bias=True):
-         super(ResidualDenseBlock_5C, self).__init__()
-         # gc: growth channel, i.e. intermediate channels
-         self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
-         self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
-         self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
-         self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
-         self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
-         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
-         # initialization
-         # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
-
-     def forward(self, x):
-         x1 = self.lrelu(self.conv1(x))
-         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
-         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
-         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
-         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
-         return x5 * 0.2 + x
-
-
- class RRDB(nn.Module):
-     '''Residual in Residual Dense Block'''
-
-     def __init__(self, nf, gc=32):
-         super(RRDB, self).__init__()
-         self.RDB1 = ResidualDenseBlock_5C(nf, gc)
-         self.RDB2 = ResidualDenseBlock_5C(nf, gc)
-         self.RDB3 = ResidualDenseBlock_5C(nf, gc)
-
-     def forward(self, x):
-         out = self.RDB1(x)
-         out = self.RDB2(out)
-         out = self.RDB3(out)
-         return out * 0.2 + x
-
-
- class RRDBNet(nn.Module):
-     def __init__(self, in_nc, out_nc, nf, nb, gc=32):
-         super(RRDBNet, self).__init__()
-         RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
-
-         self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
-         self.RRDB_trunk = make_layer(RRDB_block_f, nb)
-         self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-         #### upsampling
-         self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-         self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-         self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-         self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
-
-         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
-     def forward(self, x):
-         fea = self.conv_first(x)
-         trunk = self.trunk_conv(self.RRDB_trunk(fea))
-         fea = fea + trunk
-
-         fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
-         fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
-         out = self.conv_last(self.lrelu(self.HRconv(fea)))
-
-         return out
architecture.py ADDED
@@ -0,0 +1,38 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import block as B
+
+
+ class RRDB_Net(nn.Module):
+     def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \
+             mode='CNA', res_scale=1, upsample_mode='upconv'):
+         super(RRDB_Net, self).__init__()
+         n_upscale = int(math.log(upscale, 2))
+         if upscale == 3:
+             n_upscale = 1
+
+         fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
+         rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
+             norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]
+         LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
+
+         if upsample_mode == 'upconv':
+             upsample_block = B.upconv_blcok
+         elif upsample_mode == 'pixelshuffle':
+             upsample_block = B.pixelshuffle_block
+         else:
+             raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
+         if upscale == 3:
+             upsampler = upsample_block(nf, nf, 3, act_type=act_type)
+         else:
+             upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
+         HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
+         HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
+
+         self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), \
+             *upsampler, HR_conv0, HR_conv1)
+
+     def forward(self, x):
+         x = self.model(x)
+         return x
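To sanity-check the new module, a quick smoke test (a sketch, not part of the commit; the 1-channel ×4 configuration mirrors what the updated `test.py` uses):

```
import torch
import architecture as arch

model = arch.RRDB_Net(1, 1, 64, 23, gc=32, upscale=4, norm_type=None,
                      act_type='leakyrelu', mode='CNA', upsample_mode='upconv')
model.eval()

with torch.no_grad():
    lr = torch.rand(1, 1, 32, 32)   # dummy grayscale low-resolution input
    sr = model(lr)
print(sr.shape)  # expected: torch.Size([1, 1, 128, 128])
```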
block.py ADDED
@@ -0,0 +1,261 @@
+ from collections import OrderedDict
+ import torch
+ import torch.nn as nn
+
+ ####################
+ # Basic blocks
+ ####################
+
+
+ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
+     # helper selecting activation
+     # neg_slope: for leakyrelu and init of prelu
+     # n_prelu: for p_relu num_parameters
+     act_type = act_type.lower()
+     if act_type == 'relu':
+         layer = nn.ReLU(inplace)
+     elif act_type == 'leakyrelu':
+         layer = nn.LeakyReLU(neg_slope, inplace)
+     elif act_type == 'prelu':
+         layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+     else:
+         raise NotImplementedError('activation layer [%s] is not found' % act_type)
+     return layer
+
+
+ def norm(norm_type, nc):
+     # helper selecting normalization layer
+     norm_type = norm_type.lower()
+     if norm_type == 'batch':
+         layer = nn.BatchNorm2d(nc, affine=True)
+     elif norm_type == 'instance':
+         layer = nn.InstanceNorm2d(nc, affine=False)
+     else:
+         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
+     return layer
+
+
+ def pad(pad_type, padding):
+     # helper selecting padding layer
+     # if padding is 'zero', do by conv layers
+     pad_type = pad_type.lower()
+     if padding == 0:
+         return None
+     if pad_type == 'reflect':
+         layer = nn.ReflectionPad2d(padding)
+     elif pad_type == 'replicate':
+         layer = nn.ReplicationPad2d(padding)
+     else:
+         raise NotImplementedError('padding layer [%s] is not implemented' % pad_type)
+     return layer
+
+
+ def get_valid_padding(kernel_size, dilation):
+     kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+     padding = (kernel_size - 1) // 2
+     return padding
+
+
+ class ConcatBlock(nn.Module):
+     # Concat the output of a submodule to its input
+     def __init__(self, submodule):
+         super(ConcatBlock, self).__init__()
+         self.sub = submodule
+
+     def forward(self, x):
+         output = torch.cat((x, self.sub(x)), dim=1)
+         return output
+
+     def __repr__(self):
+         tmpstr = 'Identity .. \n|'
+         modstr = self.sub.__repr__().replace('\n', '\n|')
+         tmpstr = tmpstr + modstr
+         return tmpstr
+
+
+ class ShortcutBlock(nn.Module):
+     # Elementwise sum the output of a submodule to its input
+     def __init__(self, submodule):
+         super(ShortcutBlock, self).__init__()
+         self.sub = submodule
+
+     def forward(self, x):
+         output = x + self.sub(x)
+         return output
+
+     def __repr__(self):
+         tmpstr = 'Identity + \n|'
+         modstr = self.sub.__repr__().replace('\n', '\n|')
+         tmpstr = tmpstr + modstr
+         return tmpstr
+
+
+ def sequential(*args):
+     # Flatten Sequential. It unwraps nn.Sequential.
+     if len(args) == 1:
+         if isinstance(args[0], OrderedDict):
+             raise NotImplementedError('sequential does not support OrderedDict input.')
+         return args[0]  # No sequential is needed.
+     modules = []
+     for module in args:
+         if isinstance(module, nn.Sequential):
+             for submodule in module.children():
+                 modules.append(submodule)
+         elif isinstance(module, nn.Module):
+             modules.append(module)
+     return nn.Sequential(*modules)
+
+
+ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
+                pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
+     """
+     Conv layer with padding, normalization, activation
+     mode: CNA --> Conv -> Norm -> Act
+           NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
+     """
+     assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [%s]' % mode
+     padding = get_valid_padding(kernel_size, dilation)
+     p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
+     padding = padding if pad_type == 'zero' else 0
+
+     c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
+         dilation=dilation, bias=bias, groups=groups)
+     a = act(act_type) if act_type else None
+     if 'CNA' in mode:
+         n = norm(norm_type, out_nc) if norm_type else None
+         return sequential(p, c, n, a)
+     elif mode == 'NAC':
+         if norm_type is None and act_type is not None:
+             a = act(act_type, inplace=False)
+             # Important!
+             # input----ReLU(inplace)----Conv--+----output
+             #        |________________________|
+             # inplace ReLU will modify the input, therefore wrong output
+         n = norm(norm_type, in_nc) if norm_type else None
+         return sequential(n, a, p, c)
+
+
+ ####################
+ # Useful blocks
+ ####################
+
+
+ class ResNetBlock(nn.Module):
+     """
+     ResNet Block, 3-3 style
+     with extra residual scaling used in EDSR
+     (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
+     """
+
+     def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1, \
+             bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA', res_scale=1):
+         super(ResNetBlock, self).__init__()
+         conv0 = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
+             norm_type, act_type, mode)
+         if mode == 'CNA':
+             act_type = None
+         if mode == 'CNAC':  # Residual path: |-CNAC-|
+             act_type = None
+             norm_type = None
+         conv1 = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
+             norm_type, act_type, mode)
+         # if in_nc != out_nc:
+         #     self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
+         #         None, None)
+         #     print('Need a projecter in ResNetBlock.')
+         # else:
+         #     self.project = lambda x: x
+         self.res = sequential(conv0, conv1)
+         self.res_scale = res_scale
+
+     def forward(self, x):
+         res = self.res(x).mul(self.res_scale)
+         return x + res
+
+
+ class ResidualDenseBlock_5C(nn.Module):
+     """
+     Residual Dense Block
+     style: 5 convs
+     The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+     """
+
+     def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
+             norm_type=None, act_type='leakyrelu', mode='CNA'):
+         super(ResidualDenseBlock_5C, self).__init__()
+         # gc: growth channel, i.e. intermediate channels
+         self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
+             norm_type=norm_type, act_type=act_type, mode=mode)
+         self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
+             norm_type=norm_type, act_type=act_type, mode=mode)
+         self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
+             norm_type=norm_type, act_type=act_type, mode=mode)
+         self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
+             norm_type=norm_type, act_type=act_type, mode=mode)
+         if mode == 'CNA':
+             last_act = None
+         else:
+             last_act = act_type
+         self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
+             norm_type=norm_type, act_type=last_act, mode=mode)
+
+     def forward(self, x):
+         x1 = self.conv1(x)
+         x2 = self.conv2(torch.cat((x, x1), 1))
+         x3 = self.conv3(torch.cat((x, x1, x2), 1))
+         x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+         return x5.mul(0.2) + x
+
+
+ class RRDB(nn.Module):
+     """
+     Residual in Residual Dense Block
+     """
+
+     def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
+             norm_type=None, act_type='leakyrelu', mode='CNA'):
+         super(RRDB, self).__init__()
+         self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
+             norm_type, act_type, mode)
+         self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
+             norm_type, act_type, mode)
+         self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
+             norm_type, act_type, mode)
+
+     def forward(self, x):
+         out = self.RDB1(x)
+         out = self.RDB2(out)
+         out = self.RDB3(out)
+         return out.mul(0.2) + x
+
+
+ ####################
+ # Upsampler
+ ####################
+
+
+ def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+                        pad_type='zero', norm_type=None, act_type='relu'):
+     """
+     Pixel shuffle layer
+     (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
+     Neural Network, CVPR17)
+     """
+     conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
+                       pad_type=pad_type, norm_type=None, act_type=None)
+     pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+     n = norm(norm_type, out_nc) if norm_type else None
+     a = act(act_type) if act_type else None
+     return sequential(conv, pixel_shuffle, n, a)
+
+
+ def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+                  pad_type='zero', norm_type=None, act_type='relu', mode='nearest'):
+     # Up conv
+     # described in https://distill.pub/2016/deconv-checkerboard/
+     upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
+     conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
+                       pad_type=pad_type, norm_type=norm_type, act_type=act_type)
+     return sequential(upsample, conv)
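For a sense of how these helpers compose, an illustrative snippet (not part of the commit):

```
import torch
import block as B

# Conv -> LeakyReLU head, followed by a x2 nearest-neighbor upsample + conv
head = B.conv_block(3, 64, kernel_size=3, norm_type=None, act_type='leakyrelu')
up = B.upconv_blcok(64, 64, upscale_factor=2, act_type='leakyrelu')

x = torch.rand(1, 3, 16, 16)
print(up(head(x)).shape)  # torch.Size([1, 64, 32, 32])
```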
models/README.md DELETED
@@ -1,9 +0,0 @@
- ## Place pretrained models here.
-
- We provide two pretrained models:
-
- 1. `RRDB_ESRGAN_x4.pth`: the final ESRGAN model we used in our [paper](https://arxiv.org/abs/1809.00219).
- 2. `RRDB_PSNR_x4.pth`: the PSNR-oriented model with **high PSNR performance**.
-
- *Note that* the pretrained models are trained under the `MATLAB bicubic` kernel.
- If the downsampled kernel is different from that, the results may have artifacts.
test.py CHANGED
@@ -1,37 +1,33 @@
- import os.path as osp
+ import sys
+ import os.path
  import glob
  import cv2
  import numpy as np
  import torch
- import RRDBNet_arch as arch
+ import architecture as arch
 
- model_path = 'models/RRDB_ESRGAN_x4.pth'  # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
- device = torch.device('cuda')  # if you want to run on CPU, change 'cuda' -> cpu
- # device = torch.device('cpu')
-
- test_img_folder = 'LR/*'
-
- model = arch.RRDBNet(3, 3, 64, 23, gc=32)
+ model_path = '4x_eula_digimanga_bw_v2_nc1_307k.pth'
+ img_path = sys.argv[1]
+ device = torch.device('cpu')
+
+ model = arch.RRDB_Net(1, 1, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')
  model.load_state_dict(torch.load(model_path), strict=True)
  model.eval()
+
+ for k, v in model.named_parameters():
+     v.requires_grad = False
  model = model.to(device)
 
- print('Model path {:s}. \nTesting...'.format(model_path))
-
- idx = 0
- for path in glob.glob(test_img_folder):
-     idx += 1
-     base = osp.splitext(osp.basename(path))[0]
-     print(idx, base)
-     # read images
-     img = cv2.imread(path, cv2.IMREAD_COLOR)
-     img = img * 1.0 / 255
-     img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
-     img_LR = img.unsqueeze(0)
-     img_LR = img_LR.to(device)
-
-     with torch.no_grad():
-         output = model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
-     output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
-     output = (output * 255.0).round()
-     cv2.imwrite('results/{:s}_rlt.png'.format(base), output)
+ base = os.path.splitext(os.path.basename(img_path))[0]
+
+ # read image
+ img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
+ img = img * 1.0 / 255
+ img = torch.from_numpy(img[np.newaxis, :, :]).float()
+ img_LR = img.unsqueeze(0)
+ img_LR = img_LR.to(device)
+
+ output = model(img_LR).squeeze(dim=0).float().cpu().clamp_(0, 1).numpy()
+ output = np.transpose(output, (1, 2, 0))
+ output = (output * 255.0).round()
+ cv2.imwrite('results/{:s}_rlt.jpg'.format(base), output, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
 
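With this change, `test.py` takes the input image as a command-line argument (e.g. `python test.py LR/comic.png`), runs the hardcoded 1-channel manga model on the CPU, and writes the upscaled result to `results/<name>_rlt.jpg`.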
transer_RRDB_models.py DELETED
@@ -1,55 +0,0 @@
- import os
- import torch
- import RRDBNet_arch as arch
-
- pretrained_net = torch.load('./models/RRDB_ESRGAN_x4.pth')
- save_path = './models/RRDB_ESRGAN_x4.pth'
-
- crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
- crt_net = crt_model.state_dict()
-
- load_net_clean = {}
- for k, v in pretrained_net.items():
-     if k.startswith('module.'):
-         load_net_clean[k[7:]] = v
-     else:
-         load_net_clean[k] = v
- pretrained_net = load_net_clean
-
- print('###################################\n')
- tbd = []
- for k, v in crt_net.items():
-     tbd.append(k)
-
- # directly copy
- for k, v in crt_net.items():
-     if k in pretrained_net and pretrained_net[k].size() == v.size():
-         crt_net[k] = pretrained_net[k]
-         tbd.remove(k)
-
- crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
- crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
-
- for k in tbd.copy():
-     if 'RDB' in k:
-         ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
-         if '.weight' in k:
-             ori_k = ori_k.replace('.weight', '.0.weight')
-         elif '.bias' in k:
-             ori_k = ori_k.replace('.bias', '.0.bias')
-         crt_net[k] = pretrained_net[ori_k]
-         tbd.remove(k)
-
- crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
- crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
- crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
- crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
- crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
- crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
- crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
- crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
- crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
- crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
-
- torch.save(crt_net, save_path)
- print('Saving to ', save_path)