codermert committed
Commit 287e679 · verified · 1 Parent(s): bc1d7b8

Create rrdbnet_arch.py

Files changed (1)
  1. RealESRGAN/rrdbnet_arch.py +114 -0
RealESRGAN/rrdbnet_arch.py ADDED
@@ -0,0 +1,114 @@
+ import torch
+ from torch import nn as nn
+ from torch.nn import functional as F
+
+ from .arch_utils import default_init_weights, make_layer, pixel_unshuffle
+
+
+ class ResidualDenseBlock(nn.Module):
+     """Residual Dense Block.
+     Used in RRDB block in ESRGAN.
+     Args:
+         num_feat (int): Channel number of intermediate features.
+         num_grow_ch (int): Channels for each growth.
+     """
+
+     def __init__(self, num_feat=64, num_grow_ch=32):
+         super(ResidualDenseBlock, self).__init__()
+         self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+         self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+         self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+         self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+         self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+         # initialization
+         default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+     def forward(self, x):
+         x1 = self.lrelu(self.conv1(x))
+         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+         # Empirically, we use 0.2 to scale the residual for better performance
+         return x5 * 0.2 + x
+
+
+ class RRDB(nn.Module):
+     """Residual in Residual Dense Block.
+     Used in RRDB-Net in ESRGAN.
+     Args:
+         num_feat (int): Channel number of intermediate features.
+         num_grow_ch (int): Channels for each growth.
+     """
+
+     def __init__(self, num_feat, num_grow_ch=32):
+         super(RRDB, self).__init__()
+         self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+         self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+         self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+     def forward(self, x):
+         out = self.rdb1(x)
+         out = self.rdb2(out)
+         out = self.rdb3(out)
+         # Empirically, we use 0.2 to scale the residual for better performance
+         return out * 0.2 + x
+
+
+ class RRDBNet(nn.Module):
+     """Networks consisting of Residual in Residual Dense Block, which is used
+     in ESRGAN.
+     ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+     We extend ESRGAN for scale x2 and scale x1.
+     Note: This is one option for scale 1, scale 2 in RRDBNet.
+     We first employ the pixel-unshuffle (an inverse operation of pixel shuffle) to reduce the spatial size
+     and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
+     Args:
+         num_in_ch (int): Channel number of inputs.
+         num_out_ch (int): Channel number of outputs.
+         num_feat (int): Channel number of intermediate features.
+             Default: 64
+         num_block (int): Block number in the trunk network. Default: 23
+         num_grow_ch (int): Channels for each growth. Default: 32.
+     """
+
+     def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+         super(RRDBNet, self).__init__()
+         self.scale = scale
+         if scale == 2:
+             num_in_ch = num_in_ch * 4
+         elif scale == 1:
+             num_in_ch = num_in_ch * 16
+         self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+         self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+         self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         # upsample
+         self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         if scale == 8:
+             self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+     def forward(self, x):
+         if self.scale == 2:
+             feat = pixel_unshuffle(x, scale=2)
+         elif self.scale == 1:
+             feat = pixel_unshuffle(x, scale=4)
+         else:
+             feat = x
+         feat = self.conv_first(feat)
+         body_feat = self.conv_body(self.body(feat))
+         feat = feat + body_feat
+         # upsample
+         feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
+         feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
+         if self.scale == 8:
+             feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest')))
+         out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+         return out
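
The scale 1 and scale 2 branches above trade spatial resolution for channels before conv_first, which is why __init__ multiplies num_in_ch by 4 (scale 2) or 16 (scale 1). A minimal sketch of that shape bookkeeping, assuming the pixel_unshuffle helper imported from arch_utils behaves like torch.nn.functional.pixel_unshuffle:

import torch
import torch.nn.functional as F

# For an x2 model the forward pass pixel-unshuffles the input by a factor of 2,
# so a 3-channel image reaches conv_first with 3 * 2**2 = 12 channels.
x = torch.rand(1, 3, 64, 64)
y = F.pixel_unshuffle(x, downscale_factor=2)
print(y.shape)  # torch.Size([1, 12, 32, 32]): half the spatial size, 4x the channels

# For scale 1 the factor is 4, giving 3 * 4**2 = 48 channels at 16x16.
z = F.pixel_unshuffle(x, downscale_factor=4)
print(z.shape)  # torch.Size([1, 48, 16, 16])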
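
For completeness, a quick usage sketch of the new module. The import path assumes this file sits in the RealESRGAN package next to arch_utils, as in this commit; the random input and untrained weights are only for a shape check, not a real inference setup:

import torch
from RealESRGAN.rrdbnet_arch import RRDBNet  # assumed package layout from this commit

# Default x4 generator configuration described in the RRDBNet docstring
model = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=64, num_block=23, num_grow_ch=32)
model.eval()

lr = torch.rand(1, 3, 64, 64)  # dummy low-resolution RGB batch
with torch.no_grad():
    sr = model(lr)
print(sr.shape)  # torch.Size([1, 3, 256, 256]): two nearest-neighbor x2 upsampling stages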