Spaces:
Runtime error
Runtime error
Upscaling works now
Browse files- README.md +174 -10
- RRDBNet_arch.py +0 -78
- architecture.py +38 -0
- block.py +261 -0
- models/README.md +0 -9
- test.py +21 -25
- transer_RRDB_models.py +0 -55
README.md
CHANGED
@@ -1,10 +1,174 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ESRGAN (Enhanced SRGAN) [[Paper]](https://arxiv.org/abs/1809.00219) [[BasicSR]](https://github.com/xinntao/BasicSR)
|
2 |
+
## :smiley: Training codes are in [BasicSR](https://github.com/xinntao/BasicSR) repo.
|
3 |
+
### Enhanced Super-Resolution Generative Adversarial Networks
|
4 |
+
By Xintao Wang, [Ke Yu](https://yuke93.github.io/), Shixiang Wu, [Jinjin Gu](http://www.jasongt.com/), Yihao Liu, [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ&hl=en), [Yu Qiao](http://mmlab.siat.ac.cn/yuqiao/), [Chen Change Loy](http://personal.ie.cuhk.edu.hk/~ccloy/)
|
5 |
+
|
6 |
+
This repo only provides simple testing codes, pretrained models and the network strategy demo.
|
7 |
+
|
8 |
+
### **For full training and testing codes, please refer to [BasicSR](https://github.com/xinntao/BasicSR).**
|
9 |
+
|
10 |
+
We won the first place in [PIRM2018-SR competition](https://www.pirm2018.org/PIRM-SR.html) (region 3) and got the best perceptual index.
|
11 |
+
The paper is accepted to [ECCV2018 PIRM Workshop](https://pirm2018.org/).
|
12 |
+
|
13 |
+
:triangular_flag_on_post: Add [Frequently Asked Questions](https://github.com/xinntao/ESRGAN/blob/master/QA.md).
|
14 |
+
|
15 |
+
> For instance,
|
16 |
+
> 1. How to reproduce your results in the PIRM18-SR Challenge (with low perceptual index)?
|
17 |
+
> 2. How do you get the perceptual index in your ESRGAN paper?
|
18 |
+
|
19 |
+
#### BibTeX
|
20 |
+
<!--
|
21 |
+
@article{wang2018esrgan,
|
22 |
+
author={Wang, Xintao and Yu, Ke and Wu, Shixiang and Gu, Jinjin and Liu, Yihao and Dong, Chao and Loy, Chen Change and Qiao, Yu and Tang, Xiaoou},
|
23 |
+
title={ESRGAN: Enhanced super-resolution generative adversarial networks},
|
24 |
+
journal={arXiv preprint arXiv:1809.00219},
|
25 |
+
year={2018}
|
26 |
+
}
|
27 |
+
-->
|
28 |
+
@InProceedings{wang2018esrgan,
|
29 |
+
author = {Wang, Xintao and Yu, Ke and Wu, Shixiang and Gu, Jinjin and Liu, Yihao and Dong, Chao and Qiao, Yu and Loy, Chen Change},
|
30 |
+
title = {ESRGAN: Enhanced super-resolution generative adversarial networks},
|
31 |
+
booktitle = {The European Conference on Computer Vision Workshops (ECCVW)},
|
32 |
+
month = {September},
|
33 |
+
year = {2018}
|
34 |
+
}
|
35 |
+
|
36 |
+
<p align="center">
|
37 |
+
<img src="figures/baboon.jpg">
|
38 |
+
</p>
|
39 |
+
|
40 |
+
The **RRDB_PSNR** PSNR_oriented model trained with DF2K dataset (a merged dataset with [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) and [Flickr2K](http://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) (proposed in [EDSR](https://github.com/LimBee/NTIRE2017))) is also able to achive high PSNR performance.
|
41 |
+
|
42 |
+
| <sub>Method</sub> | <sub>Training dataset</sub> | <sub>Set5</sub> | <sub>Set14</sub> | <sub>BSD100</sub> | <sub>Urban100</sub> | <sub>Manga109</sub> |
|
43 |
+
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|
44 |
+
| <sub>[SRCNN](http://mmlab.ie.cuhk.edu.hk/projects/SRCNN.html)</sub>| <sub>291</sub>| <sub>30.48/0.8628</sub> |<sub>27.50/0.7513</sub>|<sub>26.90/0.7101</sub>|<sub>24.52/0.7221</sub>|<sub>27.58/0.8555</sub>|
|
45 |
+
| <sub>[EDSR](https://github.com/thstkdgus35/EDSR-PyTorch)</sub> | <sub>DIV2K</sub> | <sub>32.46/0.8968</sub> | <sub>28.80/0.7876</sub> | <sub>27.71/0.7420</sub> | <sub>26.64/0.8033</sub> | <sub>31.02/0.9148</sub> |
|
46 |
+
| <sub>[RCAN](https://github.com/yulunzhang/RCAN)</sub> | <sub>DIV2K</sub> | <sub>32.63/0.9002</sub> | <sub>28.87/0.7889</sub> | <sub>27.77/0.7436</sub> | <sub>26.82/ 0.8087</sub>| <sub>31.22/ 0.9173</sub>|
|
47 |
+
|<sub>RRDB(ours)</sub>| <sub>DF2K</sub>| <sub>**32.73/0.9011**</sub> |<sub>**28.99/0.7917**</sub> |<sub>**27.85/0.7455**</sub> |<sub>**27.03/0.8153**</sub> |<sub>**31.66/0.9196**</sub>|
|
48 |
+
|
49 |
+
## Quick Test
|
50 |
+
#### Dependencies
|
51 |
+
- Python 3
|
52 |
+
- [PyTorch >= 0.4](https://pytorch.org/) (CUDA version >= 7.5 if installing with CUDA. [More details](https://pytorch.org/get-started/previous-versions/))
|
53 |
+
- Python packages: `pip install numpy opencv-python`
|
54 |
+
|
55 |
+
### Test models
|
56 |
+
1. Clone this github repo.
|
57 |
+
```
|
58 |
+
git clone https://github.com/xinntao/ESRGAN
|
59 |
+
cd ESRGAN
|
60 |
+
```
|
61 |
+
2. Place your own **low-resolution images** in `./LR` folder. (There are two sample images - baboon and comic).
|
62 |
+
3. Download pretrained models from [Google Drive](https://drive.google.com/drive/u/0/folders/17VYV_SoZZesU6mbxz2dMAIccSSlqLecY) or [Baidu Drive](https://pan.baidu.com/s/1-Lh6ma-wXzfH8NqeBtPaFQ). Place the models in `./models`. We provide two models with high perceptual quality and high PSNR performance (see [model list](https://github.com/xinntao/ESRGAN/tree/master/models)).
|
63 |
+
4. Run test. We provide ESRGAN model and RRDB_PSNR model.
|
64 |
+
```
|
65 |
+
python test.py models/RRDB_ESRGAN_x4.pth
|
66 |
+
python test.py models/RRDB_PSNR_x4.pth
|
67 |
+
```
|
68 |
+
5. The results are in `./results` folder.
|
69 |
+
### Network interpolation demo
|
70 |
+
You can interpolate the RRDB_ESRGAN and RRDB_PSNR models with alpha in [0, 1].
|
71 |
+
|
72 |
+
1. Run `python net_interp.py 0.8`, where *0.8* is the interpolation parameter and you can change it to any value in [0,1].
|
73 |
+
2. Run `python test.py models/interp_08.pth`, where *models/interp_08.pth* is the model path.
|
74 |
+
|
75 |
+
<p align="center">
|
76 |
+
<img height="400" src="figures/43074.gif">
|
77 |
+
</p>
|
78 |
+
|
79 |
+
## Perceptual-driven SR Results
|
80 |
+
|
81 |
+
You can download all the resutls from [Google Drive](https://drive.google.com/drive/folders/1iaM-c6EgT1FNoJAOKmDrK7YhEhtlKcLx?usp=sharing). (:heavy_check_mark: included; :heavy_minus_sign: not included; :o: TODO)
|
82 |
+
|
83 |
+
HR images can be downloaed from [BasicSR-Datasets](https://github.com/xinntao/BasicSR#datasets).
|
84 |
+
|
85 |
+
| Datasets |LR | [*ESRGAN*](https://arxiv.org/abs/1809.00219) | [SRGAN](https://arxiv.org/abs/1609.04802) | [EnhanceNet](http://openaccess.thecvf.com/content_ICCV_2017/papers/Sajjadi_EnhanceNet_Single_Image_ICCV_2017_paper.pdf) | [CX](https://arxiv.org/abs/1803.04626) |
|
86 |
+
|:---:|:---:|:---:|:---:|:---:|:---:|
|
87 |
+
| Set5 |:heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:| :o: |
|
88 |
+
| Set14 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:| :o: |
|
89 |
+
| BSDS100 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:| :o: |
|
90 |
+
| [PIRM](https://pirm.github.io/) <br><sup>(val, test)</sup> | :heavy_check_mark: | :heavy_check_mark: | :heavy_minus_sign: | :heavy_check_mark:| :heavy_check_mark: |
|
91 |
+
| [OST300](https://arxiv.org/pdf/1804.02815.pdf) |:heavy_check_mark: | :heavy_check_mark: | :heavy_minus_sign: | :heavy_check_mark:| :o: |
|
92 |
+
| urban100 | :heavy_check_mark: | :heavy_check_mark: | :heavy_minus_sign: | :heavy_check_mark:| :o: |
|
93 |
+
| [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) <br><sup>(val, test)</sup> | :heavy_check_mark: | :heavy_check_mark: | :heavy_minus_sign: | :heavy_check_mark:| :o: |
|
94 |
+
|
95 |
+
## ESRGAN
|
96 |
+
We improve the [SRGAN](https://arxiv.org/abs/1609.04802) from three aspects:
|
97 |
+
1. adopt a deeper model using Residual-in-Residual Dense Block (RRDB) without batch normalization layers.
|
98 |
+
2. employ [Relativistic average GAN](https://ajolicoeur.wordpress.com/relativisticgan/) instead of the vanilla GAN.
|
99 |
+
3. improve the perceptual loss by using the features before activation.
|
100 |
+
|
101 |
+
In contrast to SRGAN, which claimed that **deeper models are increasingly difficult to train**, our deeper ESRGAN model shows its superior performance with easy training.
|
102 |
+
|
103 |
+
<p align="center">
|
104 |
+
<img height="120" src="figures/architecture.jpg">
|
105 |
+
</p>
|
106 |
+
<p align="center">
|
107 |
+
<img height="180" src="figures/RRDB.png">
|
108 |
+
</p>
|
109 |
+
|
110 |
+
## Network Interpolation
|
111 |
+
We propose the **network interpolation strategy** to balance the visual quality and PSNR.
|
112 |
+
|
113 |
+
<p align="center">
|
114 |
+
<img height="500" src="figures/net_interp.jpg">
|
115 |
+
</p>
|
116 |
+
|
117 |
+
We show the smooth animation with the interpolation parameters changing from 0 to 1.
|
118 |
+
Interestingly, it is observed that the network interpolation strategy provides a smooth control of the RRDB_PSNR model and the fine-tuned ESRGAN model.
|
119 |
+
|
120 |
+
<p align="center">
|
121 |
+
<img height="480" src="figures/81.gif">
|
122 |
+
   
|
123 |
+
<img height="480" src="figures/102061.gif">
|
124 |
+
</p>
|
125 |
+
|
126 |
+
## Qualitative Results
|
127 |
+
PSNR (evaluated on the Y channel) and the perceptual index used in the PIRM-SR challenge are also provided for reference.
|
128 |
+
|
129 |
+
<p align="center">
|
130 |
+
<img src="figures/qualitative_cmp_01.jpg">
|
131 |
+
</p>
|
132 |
+
<p align="center">
|
133 |
+
<img src="figures/qualitative_cmp_02.jpg">
|
134 |
+
</p>
|
135 |
+
<p align="center">
|
136 |
+
<img src="figures/qualitative_cmp_03.jpg">
|
137 |
+
</p>
|
138 |
+
<p align="center">
|
139 |
+
<img src="figures/qualitative_cmp_04.jpg">
|
140 |
+
</p>
|
141 |
+
|
142 |
+
## Ablation Study
|
143 |
+
Overall visual comparisons for showing the effects of each component in
|
144 |
+
ESRGAN. Each column represents a model with its configurations in the top.
|
145 |
+
The red sign indicates the main improvement compared with the previous model.
|
146 |
+
<p align="center">
|
147 |
+
<img src="figures/abalation_study.png">
|
148 |
+
</p>
|
149 |
+
|
150 |
+
## BN artifacts
|
151 |
+
We empirically observe that BN layers tend to bring artifacts. These artifacts,
|
152 |
+
namely BN artifacts, occasionally appear among iterations and different settings,
|
153 |
+
violating the needs for a stable performance over training. We find that
|
154 |
+
the network depth, BN position, training dataset and training loss
|
155 |
+
have impact on the occurrence of BN artifacts.
|
156 |
+
<p align="center">
|
157 |
+
<img src="figures/BN_artifacts.jpg">
|
158 |
+
</p>
|
159 |
+
|
160 |
+
## Useful techniques to train a very deep network
|
161 |
+
We find that residual scaling and smaller initialization can help to train a very deep network. More details are in the Supplementary File attached in our [paper](https://arxiv.org/abs/1809.00219).
|
162 |
+
|
163 |
+
<p align="center">
|
164 |
+
<img height="250" src="figures/train_deeper_neta.png">
|
165 |
+
<img height="250" src="figures/train_deeper_netb.png">
|
166 |
+
</p>
|
167 |
+
|
168 |
+
## The influence of training patch size
|
169 |
+
We observe that training a deeper network benefits from a larger patch size. Moreover, the deeper model achieves more improvement (∼0.12dB) than the shallower one (∼0.04dB) since larger model capacity is capable of taking full advantage of
|
170 |
+
larger training patch size. (Evaluated on Set5 dataset with RGB channels.)
|
171 |
+
<p align="center">
|
172 |
+
<img height="250" src="figures/patch_a.png">
|
173 |
+
<img height="250" src="figures/patch_b.png">
|
174 |
+
</p>
|
RRDBNet_arch.py
DELETED
@@ -1,78 +0,0 @@
|
|
1 |
-
import functools
|
2 |
-
import torch
|
3 |
-
import torch.nn as nn
|
4 |
-
import torch.nn.functional as F
|
5 |
-
|
6 |
-
|
7 |
-
def make_layer(block, n_layers):
|
8 |
-
layers = []
|
9 |
-
for _ in range(n_layers):
|
10 |
-
layers.append(block())
|
11 |
-
return nn.Sequential(*layers)
|
12 |
-
|
13 |
-
|
14 |
-
class ResidualDenseBlock_5C(nn.Module):
|
15 |
-
def __init__(self, nf=64, gc=32, bias=True):
|
16 |
-
super(ResidualDenseBlock_5C, self).__init__()
|
17 |
-
# gc: growth channel, i.e. intermediate channels
|
18 |
-
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
19 |
-
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
20 |
-
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
21 |
-
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
22 |
-
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
23 |
-
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
24 |
-
|
25 |
-
# initialization
|
26 |
-
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
27 |
-
|
28 |
-
def forward(self, x):
|
29 |
-
x1 = self.lrelu(self.conv1(x))
|
30 |
-
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
31 |
-
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
32 |
-
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
33 |
-
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
34 |
-
return x5 * 0.2 + x
|
35 |
-
|
36 |
-
|
37 |
-
class RRDB(nn.Module):
|
38 |
-
'''Residual in Residual Dense Block'''
|
39 |
-
|
40 |
-
def __init__(self, nf, gc=32):
|
41 |
-
super(RRDB, self).__init__()
|
42 |
-
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
43 |
-
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
44 |
-
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
45 |
-
|
46 |
-
def forward(self, x):
|
47 |
-
out = self.RDB1(x)
|
48 |
-
out = self.RDB2(out)
|
49 |
-
out = self.RDB3(out)
|
50 |
-
return out * 0.2 + x
|
51 |
-
|
52 |
-
|
53 |
-
class RRDBNet(nn.Module):
|
54 |
-
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
|
55 |
-
super(RRDBNet, self).__init__()
|
56 |
-
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
57 |
-
|
58 |
-
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
59 |
-
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
60 |
-
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
61 |
-
#### upsampling
|
62 |
-
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
63 |
-
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
64 |
-
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
65 |
-
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
66 |
-
|
67 |
-
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
68 |
-
|
69 |
-
def forward(self, x):
|
70 |
-
fea = self.conv_first(x)
|
71 |
-
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
72 |
-
fea = fea + trunk
|
73 |
-
|
74 |
-
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
75 |
-
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
76 |
-
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
77 |
-
|
78 |
-
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
architecture.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import block as B
|
5 |
+
|
6 |
+
|
7 |
+
class RRDB_Net(nn.Module):
|
8 |
+
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \
|
9 |
+
mode='CNA', res_scale=1, upsample_mode='upconv'):
|
10 |
+
super(RRDB_Net, self).__init__()
|
11 |
+
n_upscale = int(math.log(upscale, 2))
|
12 |
+
if upscale == 3:
|
13 |
+
n_upscale = 1
|
14 |
+
|
15 |
+
fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
|
16 |
+
rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
|
17 |
+
norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]
|
18 |
+
LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
|
19 |
+
|
20 |
+
if upsample_mode == 'upconv':
|
21 |
+
upsample_block = B.upconv_blcok
|
22 |
+
elif upsample_mode == 'pixelshuffle':
|
23 |
+
upsample_block = B.pixelshuffle_block
|
24 |
+
else:
|
25 |
+
raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
|
26 |
+
if upscale == 3:
|
27 |
+
upsampler = upsample_block(nf, nf, 3, act_type=act_type)
|
28 |
+
else:
|
29 |
+
upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
|
30 |
+
HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
|
31 |
+
HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
|
32 |
+
|
33 |
+
self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),\
|
34 |
+
*upsampler, HR_conv0, HR_conv1)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
x = self.model(x)
|
38 |
+
return x
|
block.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
####################
|
6 |
+
# Basic blocks
|
7 |
+
####################
|
8 |
+
|
9 |
+
|
10 |
+
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
|
11 |
+
# helper selecting activation
|
12 |
+
# neg_slope: for leakyrelu and init of prelu
|
13 |
+
# n_prelu: for p_relu num_parameters
|
14 |
+
act_type = act_type.lower()
|
15 |
+
if act_type == 'relu':
|
16 |
+
layer = nn.ReLU(inplace)
|
17 |
+
elif act_type == 'leakyrelu':
|
18 |
+
layer = nn.LeakyReLU(neg_slope, inplace)
|
19 |
+
elif act_type == 'prelu':
|
20 |
+
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
21 |
+
else:
|
22 |
+
raise NotImplementedError('activation layer [%s] is not found' % act_type)
|
23 |
+
return layer
|
24 |
+
|
25 |
+
|
26 |
+
def norm(norm_type, nc):
|
27 |
+
# helper selecting normalization layer
|
28 |
+
norm_type = norm_type.lower()
|
29 |
+
if norm_type == 'batch':
|
30 |
+
layer = nn.BatchNorm2d(nc, affine=True)
|
31 |
+
elif norm_type == 'instance':
|
32 |
+
layer = nn.InstanceNorm2d(nc, affine=False)
|
33 |
+
else:
|
34 |
+
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
35 |
+
return layer
|
36 |
+
|
37 |
+
|
38 |
+
def pad(pad_type, padding):
|
39 |
+
# helper selecting padding layer
|
40 |
+
# if padding is 'zero', do by conv layers
|
41 |
+
pad_type = pad_type.lower()
|
42 |
+
if padding == 0:
|
43 |
+
return None
|
44 |
+
if pad_type == 'reflect':
|
45 |
+
layer = nn.ReflectionPad2d(padding)
|
46 |
+
elif pad_type == 'replicate':
|
47 |
+
layer = nn.ReplicationPad2d(padding)
|
48 |
+
else:
|
49 |
+
raise NotImplementedError('padding layer [%s] is not implemented' % pad_type)
|
50 |
+
return layer
|
51 |
+
|
52 |
+
|
53 |
+
def get_valid_padding(kernel_size, dilation):
|
54 |
+
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
|
55 |
+
padding = (kernel_size - 1) // 2
|
56 |
+
return padding
|
57 |
+
|
58 |
+
|
59 |
+
class ConcatBlock(nn.Module):
|
60 |
+
# Concat the output of a submodule to its input
|
61 |
+
def __init__(self, submodule):
|
62 |
+
super(ConcatBlock, self).__init__()
|
63 |
+
self.sub = submodule
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
output = torch.cat((x, self.sub(x)), dim=1)
|
67 |
+
return output
|
68 |
+
|
69 |
+
def __repr__(self):
|
70 |
+
tmpstr = 'Identity .. \n|'
|
71 |
+
modstr = self.sub.__repr__().replace('\n', '\n|')
|
72 |
+
tmpstr = tmpstr + modstr
|
73 |
+
return tmpstr
|
74 |
+
|
75 |
+
|
76 |
+
class ShortcutBlock(nn.Module):
|
77 |
+
#Elementwise sum the output of a submodule to its input
|
78 |
+
def __init__(self, submodule):
|
79 |
+
super(ShortcutBlock, self).__init__()
|
80 |
+
self.sub = submodule
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
output = x + self.sub(x)
|
84 |
+
return output
|
85 |
+
|
86 |
+
def __repr__(self):
|
87 |
+
tmpstr = 'Identity + \n|'
|
88 |
+
modstr = self.sub.__repr__().replace('\n', '\n|')
|
89 |
+
tmpstr = tmpstr + modstr
|
90 |
+
return tmpstr
|
91 |
+
|
92 |
+
|
93 |
+
def sequential(*args):
|
94 |
+
# Flatten Sequential. It unwraps nn.Sequential.
|
95 |
+
if len(args) == 1:
|
96 |
+
if isinstance(args[0], OrderedDict):
|
97 |
+
raise NotImplementedError('sequential does not support OrderedDict input.')
|
98 |
+
return args[0] # No sequential is needed.
|
99 |
+
modules = []
|
100 |
+
for module in args:
|
101 |
+
if isinstance(module, nn.Sequential):
|
102 |
+
for submodule in module.children():
|
103 |
+
modules.append(submodule)
|
104 |
+
elif isinstance(module, nn.Module):
|
105 |
+
modules.append(module)
|
106 |
+
return nn.Sequential(*modules)
|
107 |
+
|
108 |
+
|
109 |
+
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
|
110 |
+
pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
|
111 |
+
"""
|
112 |
+
Conv layer with padding, normalization, activation
|
113 |
+
mode: CNA --> Conv -> Norm -> Act
|
114 |
+
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
|
115 |
+
"""
|
116 |
+
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [%s]' % mode
|
117 |
+
padding = get_valid_padding(kernel_size, dilation)
|
118 |
+
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
119 |
+
padding = padding if pad_type == 'zero' else 0
|
120 |
+
|
121 |
+
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
|
122 |
+
dilation=dilation, bias=bias, groups=groups)
|
123 |
+
a = act(act_type) if act_type else None
|
124 |
+
if 'CNA' in mode:
|
125 |
+
n = norm(norm_type, out_nc) if norm_type else None
|
126 |
+
return sequential(p, c, n, a)
|
127 |
+
elif mode == 'NAC':
|
128 |
+
if norm_type is None and act_type is not None:
|
129 |
+
a = act(act_type, inplace=False)
|
130 |
+
# Important!
|
131 |
+
# input----ReLU(inplace)----Conv--+----output
|
132 |
+
# |________________________|
|
133 |
+
# inplace ReLU will modify the input, therefore wrong output
|
134 |
+
n = norm(norm_type, in_nc) if norm_type else None
|
135 |
+
return sequential(n, a, p, c)
|
136 |
+
|
137 |
+
|
138 |
+
####################
|
139 |
+
# Useful blocks
|
140 |
+
####################
|
141 |
+
|
142 |
+
|
143 |
+
class ResNetBlock(nn.Module):
|
144 |
+
"""
|
145 |
+
ResNet Block, 3-3 style
|
146 |
+
with extra residual scaling used in EDSR
|
147 |
+
(Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
|
148 |
+
"""
|
149 |
+
|
150 |
+
def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1, \
|
151 |
+
bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA', res_scale=1):
|
152 |
+
super(ResNetBlock, self).__init__()
|
153 |
+
conv0 = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
|
154 |
+
norm_type, act_type, mode)
|
155 |
+
if mode == 'CNA':
|
156 |
+
act_type = None
|
157 |
+
if mode == 'CNAC': # Residual path: |-CNAC-|
|
158 |
+
act_type = None
|
159 |
+
norm_type = None
|
160 |
+
conv1 = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
|
161 |
+
norm_type, act_type, mode)
|
162 |
+
# if in_nc != out_nc:
|
163 |
+
# self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
|
164 |
+
# None, None)
|
165 |
+
# print('Need a projecter in ResNetBlock.')
|
166 |
+
# else:
|
167 |
+
# self.project = lambda x:x
|
168 |
+
self.res = sequential(conv0, conv1)
|
169 |
+
self.res_scale = res_scale
|
170 |
+
|
171 |
+
def forward(self, x):
|
172 |
+
res = self.res(x).mul(self.res_scale)
|
173 |
+
return x + res
|
174 |
+
|
175 |
+
|
176 |
+
class ResidualDenseBlock_5C(nn.Module):
|
177 |
+
"""
|
178 |
+
Residual Dense Block
|
179 |
+
style: 5 convs
|
180 |
+
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
|
184 |
+
norm_type=None, act_type='leakyrelu', mode='CNA'):
|
185 |
+
super(ResidualDenseBlock_5C, self).__init__()
|
186 |
+
# gc: growth channel, i.e. intermediate channels
|
187 |
+
self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
|
188 |
+
norm_type=norm_type, act_type=act_type, mode=mode)
|
189 |
+
self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
|
190 |
+
norm_type=norm_type, act_type=act_type, mode=mode)
|
191 |
+
self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
|
192 |
+
norm_type=norm_type, act_type=act_type, mode=mode)
|
193 |
+
self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
|
194 |
+
norm_type=norm_type, act_type=act_type, mode=mode)
|
195 |
+
if mode == 'CNA':
|
196 |
+
last_act = None
|
197 |
+
else:
|
198 |
+
last_act = act_type
|
199 |
+
self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
|
200 |
+
norm_type=norm_type, act_type=last_act, mode=mode)
|
201 |
+
|
202 |
+
def forward(self, x):
|
203 |
+
x1 = self.conv1(x)
|
204 |
+
x2 = self.conv2(torch.cat((x, x1), 1))
|
205 |
+
x3 = self.conv3(torch.cat((x, x1, x2), 1))
|
206 |
+
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
|
207 |
+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
208 |
+
return x5.mul(0.2) + x
|
209 |
+
|
210 |
+
|
211 |
+
class RRDB(nn.Module):
|
212 |
+
"""
|
213 |
+
Residual in Residual Dense Block
|
214 |
+
"""
|
215 |
+
|
216 |
+
def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
|
217 |
+
norm_type=None, act_type='leakyrelu', mode='CNA'):
|
218 |
+
super(RRDB, self).__init__()
|
219 |
+
self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
|
220 |
+
norm_type, act_type, mode)
|
221 |
+
self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
|
222 |
+
norm_type, act_type, mode)
|
223 |
+
self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
|
224 |
+
norm_type, act_type, mode)
|
225 |
+
|
226 |
+
def forward(self, x):
|
227 |
+
out = self.RDB1(x)
|
228 |
+
out = self.RDB2(out)
|
229 |
+
out = self.RDB3(out)
|
230 |
+
return out.mul(0.2) + x
|
231 |
+
|
232 |
+
|
233 |
+
####################
|
234 |
+
# Upsampler
|
235 |
+
####################
|
236 |
+
|
237 |
+
|
238 |
+
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
239 |
+
pad_type='zero', norm_type=None, act_type='relu'):
|
240 |
+
"""
|
241 |
+
Pixel shuffle layer
|
242 |
+
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
|
243 |
+
Neural Network, CVPR17)
|
244 |
+
"""
|
245 |
+
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
|
246 |
+
pad_type=pad_type, norm_type=None, act_type=None)
|
247 |
+
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
248 |
+
|
249 |
+
n = norm(norm_type, out_nc) if norm_type else None
|
250 |
+
a = act(act_type) if act_type else None
|
251 |
+
return sequential(conv, pixel_shuffle, n, a)
|
252 |
+
|
253 |
+
|
254 |
+
def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
255 |
+
pad_type='zero', norm_type=None, act_type='relu', mode='nearest'):
|
256 |
+
# Up conv
|
257 |
+
# described in https://distill.pub/2016/deconv-checkerboard/
|
258 |
+
upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
|
259 |
+
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
|
260 |
+
pad_type=pad_type, norm_type=norm_type, act_type=act_type)
|
261 |
+
return sequential(upsample, conv)
|
models/README.md
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
## Place pretrained models here.
|
2 |
-
|
3 |
-
We provide two pretrained models:
|
4 |
-
|
5 |
-
1. `RRDB_ESRGAN_x4.pth`: the final ESRGAN model we used in our [paper](https://arxiv.org/abs/1809.00219).
|
6 |
-
2. `RRDB_PSNR_x4.pth`: the PSNR-oriented model with **high PSNR performance**.
|
7 |
-
|
8 |
-
*Note that* the pretrained models are trained under the `MATLAB bicubic` kernel.
|
9 |
-
If the downsampled kernel is different from that, the results may have artifacts.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test.py
CHANGED
@@ -1,37 +1,33 @@
|
|
1 |
-
import
|
|
|
2 |
import glob
|
3 |
import cv2
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
-
import
|
7 |
|
8 |
-
model_path = '
|
9 |
-
|
10 |
-
|
11 |
|
12 |
-
|
13 |
-
|
14 |
-
model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
15 |
model.load_state_dict(torch.load(model_path), strict=True)
|
16 |
model.eval()
|
|
|
|
|
|
|
17 |
model = model.to(device)
|
18 |
|
19 |
-
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
img = cv2.imread(path, cv2.IMREAD_COLOR)
|
28 |
-
img = img * 1.0 / 255
|
29 |
-
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
|
30 |
-
img_LR = img.unsqueeze(0)
|
31 |
-
img_LR = img_LR.to(device)
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
cv2.imwrite('results/{:s}_rlt.png'.format(base), output)
|
|
|
1 |
+
import sys
|
2 |
+
import os.path
|
3 |
import glob
|
4 |
import cv2
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
+
import architecture as arch
|
8 |
|
9 |
+
model_path = '4x_eula_digimanga_bw_v2_nc1_307k.pth'
|
10 |
+
img_path = sys.argv[1]
|
11 |
+
device = torch.device('cpu')
|
12 |
|
13 |
+
model = arch.RRDB_Net(1, 1, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')
|
|
|
|
|
14 |
model.load_state_dict(torch.load(model_path), strict=True)
|
15 |
model.eval()
|
16 |
+
|
17 |
+
for k, v in model.named_parameters():
|
18 |
+
v.requires_grad = False
|
19 |
model = model.to(device)
|
20 |
|
21 |
+
base = os.path.splitext(os.path.basename(img_path))[0]
|
22 |
|
23 |
+
# read image
|
24 |
+
img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
|
25 |
+
img = img * 1.0 / 255
|
26 |
+
img = torch.from_numpy(img[np.newaxis, :, :]).float()
|
27 |
+
img_LR = img.unsqueeze(0)
|
28 |
+
img_LR = img_LR.to(device)
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
+
output = model(img_LR).squeeze(dim=0).float().cpu().clamp_(0, 1).numpy()
|
31 |
+
output = np.transpose(output, (1, 2, 0))
|
32 |
+
output = (output * 255.0).round()
|
33 |
+
cv2.imwrite('results/{:s}_rlt.jpg'.format(base), output, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
|
|
transer_RRDB_models.py
DELETED
@@ -1,55 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
-
import RRDBNet_arch as arch
|
4 |
-
|
5 |
-
pretrained_net = torch.load('./models/RRDB_ESRGAN_x4.pth')
|
6 |
-
save_path = './models/RRDB_ESRGAN_x4.pth'
|
7 |
-
|
8 |
-
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
9 |
-
crt_net = crt_model.state_dict()
|
10 |
-
|
11 |
-
load_net_clean = {}
|
12 |
-
for k, v in pretrained_net.items():
|
13 |
-
if k.startswith('module.'):
|
14 |
-
load_net_clean[k[7:]] = v
|
15 |
-
else:
|
16 |
-
load_net_clean[k] = v
|
17 |
-
pretrained_net = load_net_clean
|
18 |
-
|
19 |
-
print('###################################\n')
|
20 |
-
tbd = []
|
21 |
-
for k, v in crt_net.items():
|
22 |
-
tbd.append(k)
|
23 |
-
|
24 |
-
# directly copy
|
25 |
-
for k, v in crt_net.items():
|
26 |
-
if k in pretrained_net and pretrained_net[k].size() == v.size():
|
27 |
-
crt_net[k] = pretrained_net[k]
|
28 |
-
tbd.remove(k)
|
29 |
-
|
30 |
-
crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
|
31 |
-
crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
|
32 |
-
|
33 |
-
for k in tbd.copy():
|
34 |
-
if 'RDB' in k:
|
35 |
-
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
|
36 |
-
if '.weight' in k:
|
37 |
-
ori_k = ori_k.replace('.weight', '.0.weight')
|
38 |
-
elif '.bias' in k:
|
39 |
-
ori_k = ori_k.replace('.bias', '.0.bias')
|
40 |
-
crt_net[k] = pretrained_net[ori_k]
|
41 |
-
tbd.remove(k)
|
42 |
-
|
43 |
-
crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
|
44 |
-
crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
|
45 |
-
crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
|
46 |
-
crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
|
47 |
-
crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
|
48 |
-
crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
|
49 |
-
crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
|
50 |
-
crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
|
51 |
-
crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
|
52 |
-
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
|
53 |
-
|
54 |
-
torch.save(crt_net, save_path)
|
55 |
-
print('Saving to ', save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|