5AF1 commited on
Commit
9e7a9f7
·
1 Parent(s): 81244f2

deploy torch model

Browse files
Set5/x2/img_001_SRF_2_HR.png ADDED
Set5/x2/img_001_SRF_2_LR.png ADDED
Set5/x2/img_002_SRF_2_HR.png ADDED
Set5/x2/img_002_SRF_2_LR.png ADDED
Set5/x2/img_003_SRF_2_HR.png ADDED
Set5/x2/img_003_SRF_2_LR.png ADDED
Set5/x2/img_004_SRF_2_HR.png ADDED
Set5/x2/img_004_SRF_2_LR.png ADDED
Set5/x2/img_005_SRF_2_HR.png ADDED
Set5/x2/img_005_SRF_2_LR.png ADDED
Set5/x3/img_001_SRF_3_HR.png ADDED
Set5/x3/img_001_SRF_3_LR.png ADDED
Set5/x3/img_002_SRF_3_HR.png ADDED
Set5/x3/img_002_SRF_3_LR.png ADDED
Set5/x3/img_003_SRF_3_HR.png ADDED
Set5/x3/img_003_SRF_3_LR.png ADDED
Set5/x3/img_004_SRF_3_HR.png ADDED
Set5/x3/img_004_SRF_3_LR.png ADDED
Set5/x3/img_005_SRF_3_HR.png ADDED
Set5/x3/img_005_SRF_3_LR.png ADDED
Set5/x4/img_001_SRF_4_HR.png ADDED
Set5/x4/img_001_SRF_4_LR.png ADDED
Set5/x4/img_002_SRF_4_HR.png ADDED
Set5/x4/img_002_SRF_4_LR.png ADDED
Set5/x4/img_003_SRF_4_HR.png ADDED
Set5/x4/img_003_SRF_4_LR.png ADDED
Set5/x4/img_004_SRF_4_HR.png ADDED
Set5/x4/img_004_SRF_4_LR.png ADDED
Set5/x4/img_005_SRF_4_HR.png ADDED
Set5/x4/img_005_SRF_4_LR.png ADDED
app.py ADDED
File without changes
checkpoint/HGSRCNN.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57a33668876d2667b436b74b975faa5789dcd7ba1934e96be0dca28f1c7dc45b
3
+ size 11026482
checkpoint/LSGSRCNN.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57a33668876d2667b436b74b975faa5789dcd7ba1934e96be0dca28f1c7dc45b
3
+ size 11026482
model/HGSRCNN.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import model.ops as ops
4
+
5
+ '''
6
+ class Block(nn.Module):
7
+ def __init__(self,
8
+ in_channels, out_channels,
9
+ group=1):
10
+ super(Block, self).__init__()
11
+
12
+ self.b1 = ops.EResidualBlock(64, 64, group=group)
13
+ self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
14
+ self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
15
+ self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)
16
+
17
+ def forward(self, x):
18
+ c0 = o0 = x
19
+
20
+ b1 = self.b1(o0)
21
+ c1 = torch.cat([c0, b1], dim=1)
22
+ o1 = self.c1(c1)
23
+
24
+ b2 = self.b1(o1)
25
+ c2 = torch.cat([c1, b2], dim=1)
26
+ o2 = self.c2(c2)
27
+
28
+ b3 = self.b1(o2)
29
+ c3 = torch.cat([c2, b3], dim=1)
30
+ o3 = self.c3(c3)
31
+
32
+ return o3
33
+ '''
34
+
35
+ class MFCModule(nn.Module):
36
+ def __init__(self,in_channels,out_channels,gropus=1):
37
+ super(MFCModule,self).__init__()
38
+ kernel_size =3
39
+ padding = 1
40
+ features = 64
41
+ features1 = 32
42
+ distill_rate = 0.5
43
+ self.distilled_channels = int(features*distill_rate)
44
+ self.remaining_channels = int(features-self.distilled_channels)
45
+ self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
46
+ self.conv2_1 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
47
+ self.conv3_1 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
48
+ self.conv1_1_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
49
+ self.conv2_1_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
50
+ self.conv3_1_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
51
+ self.conv4_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
52
+ self.conv5_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
53
+ self.conv6_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
54
+ self.conv7_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
55
+ self.conv8_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
56
+ '''
57
+ self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
58
+ self.conv2_1 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
59
+ self.conv2_2 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
60
+ self.conv3_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
61
+ self.conv4_1 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
62
+ self.conv4_2 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
63
+ self.conv5_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
64
+ self.conv6_1 = nn.Sequential(nn.Conv2d(in_channels=2*features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
65
+ '''
66
+ self.ReLU = nn.ReLU(inplace=True)
67
+ def forward(self,input):
68
+ dit1,remain1 = torch.split(input,(self.distilled_channels,self.remaining_channels),dim=1)
69
+ out1_1=self.conv1_1(dit1)
70
+ out1_1_t = self.ReLU(out1_1)
71
+ out2_1=self.conv2_1(out1_1_t)
72
+ out3_1=self.conv3_1(out2_1)
73
+ out1_2=self.conv1_1(remain1)
74
+ out1_2_t = self.ReLU(out1_2)
75
+ out2_2=self.conv2_1(out1_2_t)
76
+ out3_2=self.conv3_1(out2_2)
77
+ #out3 = torch.cat([out1_1,out3_1],dim=1)
78
+ #out3_t = torch.cat([out1_2,out3_2],dim=1)
79
+ out3_t = torch.cat([out3_1,out3_2],dim=1)
80
+ out3 = self.ReLU(out3_t)
81
+ #out3 = input+out3
82
+ out1_1t = self.conv1_1_1(input)
83
+ out1_2t1 = self.conv2_1_1(out1_1t)
84
+ out1_3t1 = self.conv3_1_1(out1_2t1)
85
+ out1_3t1 = out3+out1_3t1
86
+ out4_1=self.conv4_1(out1_3t1)
87
+ out5_1=self.conv5_1(out4_1)
88
+ out6_1=self.conv6_1(out5_1)
89
+ out7_1=self.conv7_1(out6_1)
90
+ out8_1=self.conv8_1(out7_1)
91
+ out8_1=out8_1+input+out4_1
92
+ '''
93
+ out1_c = self.conv1_1(input)
94
+ dit1,remain1 = torch.split(out1_c,(self.distilled_channels,self.remaining_channels),dim=1)
95
+ out1_r = self.ReLU(remain1)
96
+ out1_d = self.ReLU(dit1)
97
+ out2_r = self.conv2_1(out1_r)
98
+ out2_d = self.conv2_2(out1_d)
99
+ out2 = torch.cat([out2_r,out2_d],dim=1)
100
+ out2_r = torch.cat([remain1,out2_r],dim=1)
101
+ out2_d = torch.cat([dit1,out2_d],dim=1)
102
+ out2_1 = out2+out2_r+out2_d
103
+ out2 = self.ReLU(out2_1)
104
+ out3 = self.conv3_1(out2)
105
+ dit3,remain3 = torch.split(out3,(self.distilled_channels,self.remaining_channels),dim=1)
106
+ out3_r = self.ReLU(remain3)
107
+ out3_d = self.ReLU(dit3)
108
+ out4_r = self.conv4_1(out3_r)
109
+ out4_d = self.conv4_2(out3_d)
110
+ out4 = torch.cat([out4_r,out4_d],dim=1)
111
+ out4_r = torch.cat([remain3,out4_r],dim=1)
112
+ out4_d = torch.cat([dit3,out4_d],dim=1)
113
+ out4_1 = out4+out4_r+out4_d
114
+ out4 = self.ReLU(out4_1)
115
+ out5 = self.conv5_1(out4)
116
+ out5_1 = torch.cat([out3,out5],dim=1)
117
+ out5_1 = self.ReLU(out5_1)
118
+ out6_1 = self.conv6_1(out5_1)
119
+ out6_r = input+out6_1
120
+ '''
121
+ return out8_1
122
+
123
+
124
+ class Net(nn.Module):
125
+ def __init__(self, **kwargs):
126
+ super(Net, self).__init__()
127
+
128
+ scale = kwargs.get("scale") #value of scale is scale.
129
+ multi_scale = kwargs.get("multi_scale") # value of multi_scale is multi_scale in args.
130
+ group = kwargs.get("group", 1) #if valule of group isn't given, group is 1.
131
+ kernel_size = 3 #tcw 201904091123
132
+ kernel_size1 = 1 #tcw 201904091123
133
+ padding1 = 0 #tcw 201904091124
134
+ padding = 1 #tcw201904091123
135
+ features = 64 #tcw201904091124
136
+ groups = 1 #tcw201904091124
137
+ channels = 3
138
+ features1 = 64
139
+ self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True)
140
+ self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False)
141
+ '''
142
+ in_channels, out_channels, kernel_size, stride, padding,dialation, groups,
143
+ '''
144
+
145
+ self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels=channels,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
146
+ self.b1 = MFCModule(features,features)
147
+ self.b2 = MFCModule(features,features)
148
+ self.b3 = MFCModule(features,features)
149
+ self.b4 = MFCModule(features,features)
150
+ self.b5 = MFCModule(features,features)
151
+ self.b6 = MFCModule(features,features)
152
+ self.ReLU=nn.ReLU(inplace=True)
153
+ #self.conv2 = nn.Sequential(nn.Conv2d(in_channels=6*features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
154
+ self.conv2 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
155
+ self.conv3 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=3,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
156
+ self.upsample = ops.UpsampleBlock(64, scale=scale, multi_scale=multi_scale,group=1)
157
+ def forward(self, x, scale):
158
+ x = self.sub_mean(x)
159
+ x1 = self.conv1_1(x)
160
+ b1 = self.b1(x1)
161
+ b2 = self.b2(b1)
162
+ b3 = self.b3(b2)
163
+ b4 = self.b4(b3)
164
+ b5 = self.b5(b4)
165
+ b5 = b5+b1
166
+ b6 = self.b6(b5)
167
+ b6 = b6+x1
168
+ #b6 = torch.cat([b1,b2,b3,b4,b5,b6],dim=1)
169
+ #b6 = x1+b1+b2+b3+b4+b5+b6
170
+ #x2 = x1+b1+b2+b3+b4+b5+b6
171
+ x2 = self.conv2(b6)
172
+ temp = self.upsample(x2, scale=scale)
173
+ #temp1 = self.upsample(x1, scale=scale)
174
+ #temp = temp+temp1
175
+ #temp2 = self.ReLU(temp)
176
+ out = self.conv3(temp)
177
+ out = self.add_mean(out)
178
+ return out
model/LSGSRCNN.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import model.ops as ops
4
+
5
+ '''
6
+ class Block(nn.Module):
7
+ def __init__(self,
8
+ in_channels, out_channels,
9
+ group=1):
10
+ super(Block, self).__init__()
11
+
12
+ self.b1 = ops.EResidualBlock(64, 64, group=group)
13
+ self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
14
+ self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
15
+ self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)
16
+
17
+ def forward(self, x):
18
+ c0 = o0 = x
19
+
20
+ b1 = self.b1(o0)
21
+ c1 = torch.cat([c0, b1], dim=1)
22
+ o1 = self.c1(c1)
23
+
24
+ b2 = self.b1(o1)
25
+ c2 = torch.cat([c1, b2], dim=1)
26
+ o2 = self.c2(c2)
27
+
28
+ b3 = self.b1(o2)
29
+ c3 = torch.cat([c2, b3], dim=1)
30
+ o3 = self.c3(c3)
31
+
32
+ return o3
33
+ '''
34
+
35
+ class MFCModule(nn.Module):
36
+ def __init__(self,in_channels,out_channels,gropus=1):
37
+ super(MFCModule,self).__init__()
38
+ kernel_size =3
39
+ padding = 1
40
+ features = 64
41
+ features1 = 32
42
+ distill_rate = 0.5
43
+ self.distilled_channels = int(features*distill_rate)
44
+ self.remaining_channels = int(features-self.distilled_channels)
45
+ self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
46
+ self.conv2_1 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
47
+ self.conv3_1 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
48
+ self.conv1_1_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
49
+ self.conv2_1_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
50
+ self.conv3_1_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
51
+ self.conv4_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
52
+ self.conv5_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
53
+ self.conv6_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
54
+ self.conv7_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
55
+ self.conv8_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
56
+ '''
57
+ self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
58
+ self.conv2_1 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
59
+ self.conv2_2 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
60
+ self.conv3_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
61
+ self.conv4_1 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
62
+ self.conv4_2 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
63
+ self.conv5_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
64
+ self.conv6_1 = nn.Sequential(nn.Conv2d(in_channels=2*features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
65
+ '''
66
+ self.ReLU = nn.ReLU(inplace=True)
67
+ def forward(self,input):
68
+ dit1,remain1 = torch.split(input,(self.distilled_channels,self.remaining_channels),dim=1)
69
+ out1_1=self.conv1_1(dit1)
70
+ out1_1_t = self.ReLU(out1_1)
71
+ out2_1=self.conv2_1(out1_1_t)
72
+ out3_1=self.conv3_1(out2_1)
73
+ out1_2=self.conv1_1(remain1)
74
+ out1_2_t = self.ReLU(out1_2)
75
+ out2_2=self.conv2_1(out1_2_t)
76
+ out3_2=self.conv3_1(out2_2)
77
+ #out3 = torch.cat([out1_1,out3_1],dim=1)
78
+ #out3_t = torch.cat([out1_2,out3_2],dim=1)
79
+ out3_t = torch.cat([out3_1,out3_2],dim=1)
80
+ out3 = self.ReLU(out3_t)
81
+ #out3 = input+out3
82
+ out1_1t = self.conv1_1_1(input)
83
+ out1_2t1 = self.conv2_1_1(out1_1t)
84
+ out1_3t1 = self.conv3_1_1(out1_2t1)
85
+ out1_3t1 = out3+out1_3t1
86
+ out4_1=self.conv4_1(out1_3t1)
87
+ out5_1=self.conv5_1(out4_1)
88
+ out6_1=self.conv6_1(out5_1)
89
+ out7_1=self.conv7_1(out6_1)
90
+ out8_1=self.conv8_1(out7_1)
91
+ out8_1=out8_1+input+out4_1
92
+ '''
93
+ out1_c = self.conv1_1(input)
94
+ dit1,remain1 = torch.split(out1_c,(self.distilled_channels,self.remaining_channels),dim=1)
95
+ out1_r = self.ReLU(remain1)
96
+ out1_d = self.ReLU(dit1)
97
+ out2_r = self.conv2_1(out1_r)
98
+ out2_d = self.conv2_2(out1_d)
99
+ out2 = torch.cat([out2_r,out2_d],dim=1)
100
+ out2_r = torch.cat([remain1,out2_r],dim=1)
101
+ out2_d = torch.cat([dit1,out2_d],dim=1)
102
+ out2_1 = out2+out2_r+out2_d
103
+ out2 = self.ReLU(out2_1)
104
+ out3 = self.conv3_1(out2)
105
+ dit3,remain3 = torch.split(out3,(self.distilled_channels,self.remaining_channels),dim=1)
106
+ out3_r = self.ReLU(remain3)
107
+ out3_d = self.ReLU(dit3)
108
+ out4_r = self.conv4_1(out3_r)
109
+ out4_d = self.conv4_2(out3_d)
110
+ out4 = torch.cat([out4_r,out4_d],dim=1)
111
+ out4_r = torch.cat([remain3,out4_r],dim=1)
112
+ out4_d = torch.cat([dit3,out4_d],dim=1)
113
+ out4_1 = out4+out4_r+out4_d
114
+ out4 = self.ReLU(out4_1)
115
+ out5 = self.conv5_1(out4)
116
+ out5_1 = torch.cat([out3,out5],dim=1)
117
+ out5_1 = self.ReLU(out5_1)
118
+ out6_1 = self.conv6_1(out5_1)
119
+ out6_r = input+out6_1
120
+ '''
121
+ return out8_1
122
+
123
+
124
+ class Net(nn.Module):
125
+ def __init__(self, **kwargs):
126
+ super(Net, self).__init__()
127
+
128
+ scale = kwargs.get("scale") #value of scale is scale.
129
+ multi_scale = kwargs.get("multi_scale") # value of multi_scale is multi_scale in args.
130
+ group = kwargs.get("group", 1) #if valule of group isn't given, group is 1.
131
+ kernel_size = 3 #tcw 201904091123
132
+ kernel_size1 = 1 #tcw 201904091123
133
+ padding1 = 0 #tcw 201904091124
134
+ padding = 1 #tcw201904091123
135
+ features = 64 #tcw201904091124
136
+ groups = 1 #tcw201904091124
137
+ channels = 3
138
+ features1 = 64
139
+ self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True)
140
+ self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False)
141
+ '''
142
+ in_channels, out_channels, kernel_size, stride, padding,dialation, groups,
143
+ '''
144
+
145
+ self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels=channels,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
146
+ self.b1 = MFCModule(features,features)
147
+ self.b2 = MFCModule(features,features)
148
+ self.b3 = MFCModule(features,features)
149
+ self.b4 = MFCModule(features,features)
150
+ self.b5 = MFCModule(features,features)
151
+ self.b6 = MFCModule(features,features)
152
+ self.ReLU=nn.ReLU(inplace=True)
153
+ #self.conv2 = nn.Sequential(nn.Conv2d(in_channels=6*features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
154
+ self.conv2 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
155
+ self.conv3 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=3,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
156
+ self.upsample = ops.UpsampleBlock(64, scale=scale, multi_scale=multi_scale,group=1)
157
+ def forward(self, x, scale):
158
+ x = self.sub_mean(x)
159
+ x1 = self.conv1_1(x)
160
+ b1 = self.b1(x1)
161
+ b2 = self.b2(b1)
162
+ b3 = self.b3(b2)
163
+ b4 = self.b4(b3)
164
+ b5 = self.b5(b4)
165
+ b5 = b5+b1
166
+ b6 = self.b6(b5)
167
+ b6 = b6+x1
168
+ #b6 = torch.cat([b1,b2,b3,b4,b5,b6],dim=1)
169
+ #b6 = x1+b1+b2+b3+b4+b5+b6
170
+ #x2 = x1+b1+b2+b3+b4+b5+b6
171
+ x2 = self.conv2(b6)
172
+ temp = self.upsample(x2, scale=scale)
173
+ #temp1 = self.upsample(x1, scale=scale)
174
+ #temp = temp+temp1
175
+ #temp2 = self.ReLU(temp)
176
+ out = self.conv3(temp)
177
+ out = self.add_mean(out)
178
+ return out
model/LSGSRCNN_b.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import model.ops as ops
4
+
5
+ '''
6
+ class Block(nn.Module):
7
+ def __init__(self,
8
+ in_channels, out_channels,
9
+ group=1):
10
+ super(Block, self).__init__()
11
+
12
+ self.b1 = ops.EResidualBlock(64, 64, group=group)
13
+ self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
14
+ self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
15
+ self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)
16
+
17
+ def forward(self, x):
18
+ c0 = o0 = x
19
+
20
+ b1 = self.b1(o0)
21
+ c1 = torch.cat([c0, b1], dim=1)
22
+ o1 = self.c1(c1)
23
+
24
+ b2 = self.b1(o1)
25
+ c2 = torch.cat([c1, b2], dim=1)
26
+ o2 = self.c2(c2)
27
+
28
+ b3 = self.b1(o2)
29
+ c3 = torch.cat([c2, b3], dim=1)
30
+ o3 = self.c3(c3)
31
+
32
+ return o3
33
+ '''
34
+
35
+ class MFCModule(nn.Module):
36
+ def __init__(self,in_channels,out_channels,gropus=1):
37
+ super(MFCModule,self).__init__()
38
+ kernel_size =3
39
+ padding = 1
40
+ features = 64
41
+ features1 = 32
42
+ distill_rate = 0.5
43
+ self.distilled_channels = int(features*distill_rate)
44
+ self.remaining_channels = int(features-self.distilled_channels)
45
+ self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
46
+ self.conv2_1 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
47
+ self.conv3_1 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
48
+ self.conv1_1_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
49
+ self.conv2_1_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
50
+ self.conv3_1_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
51
+ self.conv4_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
52
+ self.conv5_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
53
+ self.conv6_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
54
+ self.conv7_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
55
+ self.conv8_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
56
+ '''
57
+ self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
58
+ self.conv2_1 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
59
+ self.conv2_2 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
60
+ self.conv3_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
61
+ self.conv4_1 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
62
+ self.conv4_2 = nn.Sequential(nn.Conv2d(in_channels=features1,out_channels=features1,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
63
+ self.conv5_1 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
64
+ self.conv6_1 = nn.Sequential(nn.Conv2d(in_channels=2*features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
65
+ '''
66
+ self.ReLU = nn.ReLU(inplace=True)
67
+ def forward(self,input):
68
+ dit1,remain1 = torch.split(input,(self.distilled_channels,self.remaining_channels),dim=1)
69
+ out1_1=self.conv1_1(dit1)
70
+ out1_1_t = self.ReLU(out1_1)
71
+ out2_1=self.conv2_1(out1_1_t)
72
+ out3_1=self.conv3_1(out2_1)
73
+ out1_2=self.conv1_1(remain1)
74
+ out1_2_t = self.ReLU(out1_2)
75
+ out2_2=self.conv2_1(out1_2_t)
76
+ out3_2=self.conv3_1(out2_2)
77
+ #out3 = torch.cat([out1_1,out3_1],dim=1)
78
+ #out3_t = torch.cat([out1_2,out3_2],dim=1)
79
+ out3_t = torch.cat([out3_1,out3_2],dim=1)
80
+ out3 = self.ReLU(out3_t)
81
+ #out3 = input+out3
82
+ out1_1t = self.conv1_1_1(input)
83
+ out1_2t1 = self.conv2_1_1(out1_1t)
84
+ out1_3t1 = self.conv3_1_1(out1_2t1)
85
+ out1_3t1 = out3+out1_3t1
86
+ out4_1=self.conv4_1(out1_3t1)
87
+ out5_1=self.conv5_1(out4_1)
88
+ out6_1=self.conv6_1(out5_1)
89
+ out7_1=self.conv7_1(out6_1)
90
+ out8_1=self.conv8_1(out7_1)
91
+ out8_1=out8_1+input+out4_1
92
+ '''
93
+ out1_c = self.conv1_1(input)
94
+ dit1,remain1 = torch.split(out1_c,(self.distilled_channels,self.remaining_channels),dim=1)
95
+ out1_r = self.ReLU(remain1)
96
+ out1_d = self.ReLU(dit1)
97
+ out2_r = self.conv2_1(out1_r)
98
+ out2_d = self.conv2_2(out1_d)
99
+ out2 = torch.cat([out2_r,out2_d],dim=1)
100
+ out2_r = torch.cat([remain1,out2_r],dim=1)
101
+ out2_d = torch.cat([dit1,out2_d],dim=1)
102
+ out2_1 = out2+out2_r+out2_d
103
+ out2 = self.ReLU(out2_1)
104
+ out3 = self.conv3_1(out2)
105
+ dit3,remain3 = torch.split(out3,(self.distilled_channels,self.remaining_channels),dim=1)
106
+ out3_r = self.ReLU(remain3)
107
+ out3_d = self.ReLU(dit3)
108
+ out4_r = self.conv4_1(out3_r)
109
+ out4_d = self.conv4_2(out3_d)
110
+ out4 = torch.cat([out4_r,out4_d],dim=1)
111
+ out4_r = torch.cat([remain3,out4_r],dim=1)
112
+ out4_d = torch.cat([dit3,out4_d],dim=1)
113
+ out4_1 = out4+out4_r+out4_d
114
+ out4 = self.ReLU(out4_1)
115
+ out5 = self.conv5_1(out4)
116
+ out5_1 = torch.cat([out3,out5],dim=1)
117
+ out5_1 = self.ReLU(out5_1)
118
+ out6_1 = self.conv6_1(out5_1)
119
+ out6_r = input+out6_1
120
+ '''
121
+ return out8_1
122
+
123
+
124
+ class Net(nn.Module):
125
+ def __init__(self, **kwargs):
126
+ super(Net, self).__init__()
127
+
128
+ scale = kwargs.get("scale") #value of scale is scale.
129
+ multi_scale = kwargs.get("multi_scale") # value of multi_scale is multi_scale in args.
130
+ group = kwargs.get("group", 1) #if valule of group isn't given, group is 1.
131
+ kernel_size = 3 #tcw 201904091123
132
+ kernel_size1 = 1 #tcw 201904091123
133
+ padding1 = 0 #tcw 201904091124
134
+ padding = 1 #tcw201904091123
135
+ features = 64 #tcw201904091124
136
+ groups = 1 #tcw201904091124
137
+ channels = 3
138
+ features1 = 64
139
+ self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True)
140
+ self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False)
141
+ '''
142
+ in_channels, out_channels, kernel_size, stride, padding,dialation, groups,
143
+ '''
144
+
145
+ self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels=channels,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
146
+ self.b1 = MFCModule(features,features)
147
+ self.b2 = MFCModule(features,features)
148
+ self.b3 = MFCModule(features,features)
149
+ self.b4 = MFCModule(features,features)
150
+ self.b5 = MFCModule(features,features)
151
+ self.b6 = MFCModule(features,features)
152
+ self.ReLU=nn.ReLU(inplace=True)
153
+ #self.conv2 = nn.Sequential(nn.Conv2d(in_channels=6*features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
154
+ self.conv2 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=1,bias=False),nn.ReLU(inplace=True))
155
+ self.conv3 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=3,kernel_size=kernel_size,padding=padding,groups=1,bias=False))
156
+ self.upsample = ops.UpsampleBlock(64, scale=scale, multi_scale=multi_scale,group=1)
157
+ def forward(self, x, scale):
158
+ x = self.sub_mean(x)
159
+ x1 = self.conv1_1(x)
160
+ b1 = self.b1(x1)
161
+ b2 = self.b2(b1)
162
+ b3 = self.b3(b2)
163
+ b4 = self.b4(b3)
164
+ b5 = self.b5(b4)
165
+ b5 = b5+b1
166
+ b6 = self.b6(b5)
167
+ b6 = b6+x1
168
+ #b6 = torch.cat([b1,b2,b3,b4,b5,b6],dim=1)
169
+ #b6 = x1+b1+b2+b3+b4+b5+b6
170
+ #x2 = x1+b1+b2+b3+b4+b5+b6
171
+ x2 = self.conv2(b6)
172
+ temp = self.upsample(x2, scale=scale)
173
+ #temp1 = self.upsample(x1, scale=scale)
174
+ #temp = temp+temp1
175
+ #temp2 = self.ReLU(temp)
176
+ out = self.conv3(temp)
177
+ out = self.add_mean(out)
178
+ return out
model/__init__.py ADDED
File without changes
model/ops.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.init as init
5
+ import torch.nn.functional as F
6
+
7
+ def init_weights(modules):
8
+ pass
9
+
10
+
11
+ class MeanShift(nn.Module):
12
+ def __init__(self, mean_rgb, sub):
13
+ super(MeanShift, self).__init__()
14
+
15
+ sign = -1 if sub else 1
16
+ r = mean_rgb[0] * sign
17
+ g = mean_rgb[1] * sign
18
+ b = mean_rgb[2] * sign
19
+
20
+ self.shifter = nn.Conv2d(3, 3, 1, 1, 0) #3 is size of output, 3 is size of input, 1 is kernel 1 is padding, 0 is group
21
+ self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1) # view(3,3,1,1) convert a shape into (3,3,1,1) eye(3) is a 3x3 matrix and diagonal is 1.
22
+ self.shifter.bias.data = torch.Tensor([r, g, b])
23
+ #in_channels, out_channels,ksize=3, stride=1, pad=1
24
+ # Freeze the mean shift layer
25
+ for params in self.shifter.parameters():
26
+ params.requires_grad = False
27
+
28
+ def forward(self, x):
29
+ x = self.shifter(x)
30
+ return x
31
+
32
+
33
+ class BasicBlock(nn.Module):
34
+ def __init__(self,
35
+ in_channels, out_channels,
36
+ ksize=3, stride=1, pad=1):
37
+ super(BasicBlock, self).__init__()
38
+
39
+ self.body = nn.Sequential(
40
+ nn.Conv2d(in_channels, out_channels, ksize, stride, pad),
41
+ nn.ReLU(inplace=True)
42
+ )
43
+
44
+ init_weights(self.modules)
45
+
46
+ def forward(self, x):
47
+ out = self.body(x)
48
+ return out
49
+
50
+
51
+ class ResidualBlock(nn.Module):
52
+ def __init__(self,
53
+ in_channels, out_channels):
54
+ super(ResidualBlock, self).__init__()
55
+
56
+ self.body = nn.Sequential(
57
+ nn.Conv2d(in_channels, out_channels, 3, 1, 1),
58
+ nn.ReLU(inplace=True),
59
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1),
60
+ )
61
+
62
+ init_weights(self.modules)
63
+
64
+ def forward(self, x):
65
+ out = self.body(x)
66
+ out = F.relu(out + x)
67
+ return out
68
+
69
+
70
+ class EResidualBlock(nn.Module):
71
+ def __init__(self,
72
+ in_channels, out_channels,
73
+ group=1):
74
+ super(EResidualBlock, self).__init__()
75
+
76
+ self.body = nn.Sequential(
77
+ nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group),
78
+ nn.ReLU(inplace=True),
79
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group),
80
+ nn.ReLU(inplace=True),
81
+ nn.Conv2d(out_channels, out_channels, 1, 1, 0),
82
+ )
83
+
84
+ init_weights(self.modules)
85
+
86
+ def forward(self, x):
87
+ out = self.body(x)
88
+ out = F.relu(out + x)
89
+ return out
90
+
91
+
92
+ class UpsampleBlock(nn.Module):
93
+ def __init__(self,
94
+ n_channels, scale, multi_scale,
95
+ group=1):
96
+ super(UpsampleBlock, self).__init__()
97
+
98
+ if multi_scale:
99
+ self.up2 = _UpsampleBlock(n_channels, scale=2, group=group)
100
+ self.up3 = _UpsampleBlock(n_channels, scale=3, group=group)
101
+ self.up4 = _UpsampleBlock(n_channels, scale=4, group=group)
102
+ else:
103
+ self.up = _UpsampleBlock(n_channels, scale=scale, group=group)
104
+
105
+ self.multi_scale = multi_scale
106
+
107
+ def forward(self, x, scale):
108
+ if self.multi_scale:
109
+ if scale == 2:
110
+ return self.up2(x)
111
+ elif scale == 3:
112
+ return self.up3(x)
113
+ elif scale == 4:
114
+ return self.up4(x)
115
+ else:
116
+ return self.up(x)
117
+
118
+
119
+ class _UpsampleBlock(nn.Module):
120
+ def __init__(self,
121
+ n_channels, scale,
122
+ group=1):
123
+ super(_UpsampleBlock, self).__init__()
124
+
125
+ modules = []
126
+ if scale == 2 or scale == 4 or scale == 8:
127
+ for _ in range(int(math.log(scale, 2))):
128
+ modules += [nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
129
+ #modules += [nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group)]
130
+ modules += [nn.PixelShuffle(2)]
131
+ elif scale == 3:
132
+ modules += [nn.Conv2d(n_channels, 9*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
133
+ #modules += [nn.Conv2d(n_channels, 9*n_channels, 3, 1, 1, groups=group)]
134
+ modules += [nn.PixelShuffle(3)]
135
+
136
+ self.body = nn.Sequential(*modules)
137
+ init_weights(self.modules)
138
+
139
+ def forward(self, x):
140
+ out = self.body(x)
141
+ return out