deploy torch model
Browse files
- Set5/x2/img_001_SRF_2_HR.png +0 -0
- Set5/x2/img_001_SRF_2_LR.png +0 -0
- Set5/x2/img_002_SRF_2_HR.png +0 -0
- Set5/x2/img_002_SRF_2_LR.png +0 -0
- Set5/x2/img_003_SRF_2_HR.png +0 -0
- Set5/x2/img_003_SRF_2_LR.png +0 -0
- Set5/x2/img_004_SRF_2_HR.png +0 -0
- Set5/x2/img_004_SRF_2_LR.png +0 -0
- Set5/x2/img_005_SRF_2_HR.png +0 -0
- Set5/x2/img_005_SRF_2_LR.png +0 -0
- Set5/x3/img_001_SRF_3_HR.png +0 -0
- Set5/x3/img_001_SRF_3_LR.png +0 -0
- Set5/x3/img_002_SRF_3_HR.png +0 -0
- Set5/x3/img_002_SRF_3_LR.png +0 -0
- Set5/x3/img_003_SRF_3_HR.png +0 -0
- Set5/x3/img_003_SRF_3_LR.png +0 -0
- Set5/x3/img_004_SRF_3_HR.png +0 -0
- Set5/x3/img_004_SRF_3_LR.png +0 -0
- Set5/x3/img_005_SRF_3_HR.png +0 -0
- Set5/x3/img_005_SRF_3_LR.png +0 -0
- Set5/x4/img_001_SRF_4_HR.png +0 -0
- Set5/x4/img_001_SRF_4_LR.png +0 -0
- Set5/x4/img_002_SRF_4_HR.png +0 -0
- Set5/x4/img_002_SRF_4_LR.png +0 -0
- Set5/x4/img_003_SRF_4_HR.png +0 -0
- Set5/x4/img_003_SRF_4_LR.png +0 -0
- Set5/x4/img_004_SRF_4_HR.png +0 -0
- Set5/x4/img_004_SRF_4_LR.png +0 -0
- Set5/x4/img_005_SRF_4_HR.png +0 -0
- Set5/x4/img_005_SRF_4_LR.png +0 -0
- app.py +0 -0
- checkpoint/HGSRCNN.pth +3 -0
- checkpoint/LSGSRCNN.pth +3 -0
- model/HGSRCNN.py +178 -0
- model/LSGSRCNN.py +178 -0
- model/LSGSRCNN_b.py +178 -0
- model/__init__.py +0 -0
- model/ops.py +141 -0
Set5/x2, Set5/x3, Set5/x4 image files
ADDED
[binary PNGs, previews omitted: HR/LR pairs img_001 through img_005 at scale factors 2, 3, and 4, as listed above]
app.py
ADDED
File without changes
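The diff shows no content for app.py, so only its role can be described: it is the Space's entry point that serves the model. A plausible shape for such a file is sketched below (hypothetical: gradio, the fixed x2 scale, and the pre/post-processing are assumptions, not the Space's actual code).

# Hypothetical sketch of a Space entry point for this model; not the real app.py.
import gradio as gr
import numpy as np
import torch
from model.HGSRCNN import Net

net = Net(scale=2, multi_scale=True, group=1).eval()
net.load_state_dict(torch.load("checkpoint/HGSRCNN.pth", map_location="cpu"))

def upscale(img):
    # HWC uint8 -> NCHW float in [0, 1]
    x = torch.from_numpy(img).permute(2, 0, 1).float().unsqueeze(0) / 255.0
    with torch.no_grad():
        y = net(x, scale=2).clamp(0, 1)
    # back to HWC uint8 for display
    return (y.squeeze(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8)

gr.Interface(fn=upscale, inputs=gr.Image(), outputs=gr.Image()).launch()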
checkpoint/HGSRCNN.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57a33668876d2667b436b74b975faa5789dcd7ba1934e96be0dca28f1c7dc45b
+size 11026482
checkpoint/LSGSRCNN.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57a33668876d2667b436b74b975faa5789dcd7ba1934e96be0dca28f1c7dc45b
+size 11026482
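Both pointers carry the same oid and size, so the two checkpoint files are byte-identical blobs. A .pth path that still holds the LFS pointer rather than the fetched weights is only a few hundred bytes of text, which torch.load() cannot parse; a size check (illustrative Python) tells the two apart:

# Illustrative: distinguish a Git LFS pointer from the fetched ~11 MB weights.
import os

size = os.path.getsize("checkpoint/HGSRCNN.pth")
if size < 1024:
    print("still an LFS pointer: run `git lfs pull` to fetch the weights")
else:
    print(f"real checkpoint: {size} bytes")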
model/HGSRCNN.py
ADDED
@@ -0,0 +1,178 @@
+import torch
+import torch.nn as nn
+import model.ops as ops
+
+'''
+class Block(nn.Module):
+    def __init__(self,
+                 in_channels, out_channels,
+                 group=1):
+        super(Block, self).__init__()
+
+        self.b1 = ops.EResidualBlock(64, 64, group=group)
+        self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
+        self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
+        self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)
+
+    def forward(self, x):
+        c0 = o0 = x
+
+        b1 = self.b1(o0)
+        c1 = torch.cat([c0, b1], dim=1)
+        o1 = self.c1(c1)
+
+        b2 = self.b1(o1)
+        c2 = torch.cat([c1, b2], dim=1)
+        o2 = self.c2(c2)
+
+        b3 = self.b1(o2)
+        c3 = torch.cat([c2, b3], dim=1)
+        o3 = self.c3(c3)
+
+        return o3
+'''
+
+class MFCModule(nn.Module):
+    def __init__(self, in_channels, out_channels, groups=1):
+        super(MFCModule, self).__init__()
+        kernel_size = 3
+        padding = 1
+        features = 64
+        features1 = 32
+        distill_rate = 0.5
+        self.distilled_channels = int(features * distill_rate)
+        self.remaining_channels = int(features - self.distilled_channels)
+        self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels=features1, out_channels=features1, kernel_size=kernel_size, padding=padding, groups=1, bias=False))
+        self.conv2_1 = nn.Sequential(nn.Conv2d(in_channels=features1, out_channels=features1, kernel_size=kernel_size, padding=padding, groups=1, bias=False), nn.ReLU(inplace=True))
+        self.conv3_1 = nn.Sequential(nn.Conv2d(in_channels=features1, out_channels=features1, kernel_size=kernel_size, padding=padding, groups=1, bias=False))
+        self.conv1_1_1 = nn.Sequential(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=1, bias=False), nn.ReLU(inplace=True))
+        self.conv2_1_1 = nn.Sequential(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=1, bias=False), nn.ReLU(inplace=True))
+        self.conv3_1_1 = nn.Sequential(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=1, bias=False), nn.ReLU(inplace=True))
+        self.conv4_1 = nn.Sequential(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=1, bias=False), nn.ReLU(inplace=True))
+        self.conv5_1 = nn.Sequential(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=1, bias=False), nn.ReLU(inplace=True))
+        self.conv6_1 = nn.Sequential(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=1, bias=False), nn.ReLU(inplace=True))
+        self.conv7_1 = nn.Sequential(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=1, bias=False), nn.ReLU(inplace=True))
+        self.conv8_1 = nn.Sequential(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=1, bias=False), nn.ReLU(inplace=True))
+        '''
+        self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=1, bias=False))
+        self.conv2_1 = nn.Sequential(nn.Conv2d(in_channels=features1, out_channels=features1, kernel_size=kernel_size, padding=padding, groups=1, bias=False))
+        self.conv2_2 = nn.Sequential(nn.Conv2d(in_channels=features1, out_channels=features1, kernel_size=kernel_size, padding=padding, groups=1, bias=False))
+        self.conv3_1 = nn.Sequential(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=1, bias=False))
+        self.conv4_1 = nn.Sequential(nn.Conv2d(in_channels=features1, out_channels=features1, kernel_size=kernel_size, padding=padding, groups=1, bias=False))
+        self.conv4_2 = nn.Sequential(nn.Conv2d(in_channels=features1, out_channels=features1, kernel_size=kernel_size, padding=padding, groups=1, bias=False))
+        self.conv5_1 = nn.Sequential(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=1, bias=False))
+        self.conv6_1 = nn.Sequential(nn.Conv2d(in_channels=2*features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=1, bias=False), nn.ReLU(inplace=True))
+        '''
+        self.ReLU = nn.ReLU(inplace=True)
+    def forward(self, input):
+        dit1, remain1 = torch.split(input, (self.distilled_channels, self.remaining_channels), dim=1)
+        out1_1 = self.conv1_1(dit1)
+        out1_1_t = self.ReLU(out1_1)
+        out2_1 = self.conv2_1(out1_1_t)
+        out3_1 = self.conv3_1(out2_1)
+        out1_2 = self.conv1_1(remain1)
+        out1_2_t = self.ReLU(out1_2)
+        out2_2 = self.conv2_1(out1_2_t)
+        out3_2 = self.conv3_1(out2_2)
+        #out3 = torch.cat([out1_1, out3_1], dim=1)
+        #out3_t = torch.cat([out1_2, out3_2], dim=1)
+        out3_t = torch.cat([out3_1, out3_2], dim=1)
+        out3 = self.ReLU(out3_t)
+        #out3 = input + out3
+        out1_1t = self.conv1_1_1(input)
+        out1_2t1 = self.conv2_1_1(out1_1t)
+        out1_3t1 = self.conv3_1_1(out1_2t1)
+        out1_3t1 = out3 + out1_3t1
+        out4_1 = self.conv4_1(out1_3t1)
+        out5_1 = self.conv5_1(out4_1)
+        out6_1 = self.conv6_1(out5_1)
+        out7_1 = self.conv7_1(out6_1)
+        out8_1 = self.conv8_1(out7_1)
+        out8_1 = out8_1 + input + out4_1
+        '''
+        out1_c = self.conv1_1(input)
+        dit1, remain1 = torch.split(out1_c, (self.distilled_channels, self.remaining_channels), dim=1)
+        out1_r = self.ReLU(remain1)
+        out1_d = self.ReLU(dit1)
+        out2_r = self.conv2_1(out1_r)
+        out2_d = self.conv2_2(out1_d)
+        out2 = torch.cat([out2_r, out2_d], dim=1)
+        out2_r = torch.cat([remain1, out2_r], dim=1)
+        out2_d = torch.cat([dit1, out2_d], dim=1)
+        out2_1 = out2 + out2_r + out2_d
+        out2 = self.ReLU(out2_1)
+        out3 = self.conv3_1(out2)
+        dit3, remain3 = torch.split(out3, (self.distilled_channels, self.remaining_channels), dim=1)
+        out3_r = self.ReLU(remain3)
+        out3_d = self.ReLU(dit3)
+        out4_r = self.conv4_1(out3_r)
+        out4_d = self.conv4_2(out3_d)
+        out4 = torch.cat([out4_r, out4_d], dim=1)
+        out4_r = torch.cat([remain3, out4_r], dim=1)
+        out4_d = torch.cat([dit3, out4_d], dim=1)
+        out4_1 = out4 + out4_r + out4_d
+        out4 = self.ReLU(out4_1)
+        out5 = self.conv5_1(out4)
+        out5_1 = torch.cat([out3, out5], dim=1)
+        out5_1 = self.ReLU(out5_1)
+        out6_1 = self.conv6_1(out5_1)
+        out6_r = input + out6_1
+        '''
+        return out8_1
+
+
+class Net(nn.Module):
+    def __init__(self, **kwargs):
+        super(Net, self).__init__()
+
+        scale = kwargs.get("scale")  # target upscaling factor
+        multi_scale = kwargs.get("multi_scale")  # whether one network serves x2/x3/x4
+        group = kwargs.get("group", 1)  # defaults to 1 if no group value is given
+        kernel_size = 3  # tcw 201904091123
+        kernel_size1 = 1  # tcw 201904091123
+        padding1 = 0  # tcw 201904091124
+        padding = 1  # tcw 201904091123
+        features = 64  # tcw 201904091124
+        groups = 1  # tcw 201904091124
+        channels = 3
+        features1 = 64
+        self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True)
+        self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False)
+        '''
+        in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
+        '''
+
+        self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=kernel_size, padding=padding, groups=1, bias=False), nn.ReLU(inplace=True))
+        self.b1 = MFCModule(features, features)
+        self.b2 = MFCModule(features, features)
+        self.b3 = MFCModule(features, features)
+        self.b4 = MFCModule(features, features)
+        self.b5 = MFCModule(features, features)
+        self.b6 = MFCModule(features, features)
+        self.ReLU = nn.ReLU(inplace=True)
+        #self.conv2 = nn.Sequential(nn.Conv2d(in_channels=6*features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=1, bias=False), nn.ReLU(inplace=True))
+        self.conv2 = nn.Sequential(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=1, bias=False), nn.ReLU(inplace=True))
+        self.conv3 = nn.Sequential(nn.Conv2d(in_channels=features, out_channels=3, kernel_size=kernel_size, padding=padding, groups=1, bias=False))
+        self.upsample = ops.UpsampleBlock(64, scale=scale, multi_scale=multi_scale, group=1)
+    def forward(self, x, scale):
+        x = self.sub_mean(x)
+        x1 = self.conv1_1(x)
+        b1 = self.b1(x1)
+        b2 = self.b2(b1)
+        b3 = self.b3(b2)
+        b4 = self.b4(b3)
+        b5 = self.b5(b4)
+        b5 = b5 + b1
+        b6 = self.b6(b5)
+        b6 = b6 + x1
+        #b6 = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)
+        #b6 = x1 + b1 + b2 + b3 + b4 + b5 + b6
+        #x2 = x1 + b1 + b2 + b3 + b4 + b5 + b6
+        x2 = self.conv2(b6)
+        temp = self.upsample(x2, scale=scale)
+        #temp1 = self.upsample(x1, scale=scale)
+        #temp = temp + temp1
+        #temp2 = self.ReLU(temp)
+        out = self.conv3(temp)
+        out = self.add_mean(out)
+        return out
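For reference, running the network above end to end would look roughly like this (a minimal sketch, assuming the checkpoint stores a plain state_dict matching this configuration and that inputs are RGB tensors scaled to [0, 1]):

# Minimal inference sketch for Net; assumptions noted above.
import torch
from model.HGSRCNN import Net

net = Net(scale=2, multi_scale=True, group=1).eval()
state = torch.load("checkpoint/HGSRCNN.pth", map_location="cpu")
net.load_state_dict(state)

lr = torch.rand(1, 3, 64, 64)  # stand-in for a low-resolution RGB image
with torch.no_grad():
    sr = net(lr, scale=2)      # forward() routes scale to the matching upsample branch
print(sr.shape)                # torch.Size([1, 3, 128, 128])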
model/LSGSRCNN.py
ADDED
@@ -0,0 +1,178 @@
(178 added lines, identical line for line to model/HGSRCNN.py above)
model/LSGSRCNN_b.py
ADDED
@@ -0,0 +1,178 @@
(178 added lines, identical line for line to model/HGSRCNN.py above)
model/__init__.py
ADDED
File without changes
model/ops.py
ADDED
@@ -0,0 +1,141 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+import torch.nn.functional as F
+
+def init_weights(modules):
+    pass
+
+
+class MeanShift(nn.Module):
+    def __init__(self, mean_rgb, sub):
+        super(MeanShift, self).__init__()
+
+        sign = -1 if sub else 1
+        r = mean_rgb[0] * sign
+        g = mean_rgb[1] * sign
+        b = mean_rgb[2] * sign
+
+        self.shifter = nn.Conv2d(3, 3, 1, 1, 0)  # args: in_channels=3, out_channels=3, kernel_size=1, stride=1, padding=0
+        self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1)  # identity weights: eye(3) is a 3x3 identity matrix, reshaped to (3, 3, 1, 1)
+        self.shifter.bias.data = torch.Tensor([r, g, b])
+        # in_channels, out_channels, ksize=3, stride=1, pad=1
+        # Freeze the mean shift layer
+        for params in self.shifter.parameters():
+            params.requires_grad = False
+
+    def forward(self, x):
+        x = self.shifter(x)
+        return x
+
+
+class BasicBlock(nn.Module):
+    def __init__(self,
+                 in_channels, out_channels,
+                 ksize=3, stride=1, pad=1):
+        super(BasicBlock, self).__init__()
+
+        self.body = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, ksize, stride, pad),
+            nn.ReLU(inplace=True)
+        )
+
+        init_weights(self.modules)
+
+    def forward(self, x):
+        out = self.body(x)
+        return out
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self,
+                 in_channels, out_channels):
+        super(ResidualBlock, self).__init__()
+
+        self.body = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
+        )
+
+        init_weights(self.modules)
+
+    def forward(self, x):
+        out = self.body(x)
+        out = F.relu(out + x)
+        return out
+
+
+class EResidualBlock(nn.Module):
+    def __init__(self,
+                 in_channels, out_channels,
+                 group=1):
+        super(EResidualBlock, self).__init__()
+
+        self.body = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, 1, 1, 0),
+        )
+
+        init_weights(self.modules)
+
+    def forward(self, x):
+        out = self.body(x)
+        out = F.relu(out + x)
+        return out
+
+
+class UpsampleBlock(nn.Module):
+    def __init__(self,
+                 n_channels, scale, multi_scale,
+                 group=1):
+        super(UpsampleBlock, self).__init__()
+
+        if multi_scale:
+            self.up2 = _UpsampleBlock(n_channels, scale=2, group=group)
+            self.up3 = _UpsampleBlock(n_channels, scale=3, group=group)
+            self.up4 = _UpsampleBlock(n_channels, scale=4, group=group)
+        else:
+            self.up = _UpsampleBlock(n_channels, scale=scale, group=group)
+
+        self.multi_scale = multi_scale
+
+    def forward(self, x, scale):
+        if self.multi_scale:
+            if scale == 2:
+                return self.up2(x)
+            elif scale == 3:
+                return self.up3(x)
+            elif scale == 4:
+                return self.up4(x)
+        else:
+            return self.up(x)
+
+
+class _UpsampleBlock(nn.Module):
+    def __init__(self,
+                 n_channels, scale,
+                 group=1):
+        super(_UpsampleBlock, self).__init__()
+
+        modules = []
+        if scale == 2 or scale == 4 or scale == 8:
+            for _ in range(int(math.log(scale, 2))):
+                modules += [nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
+                #modules += [nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group)]
+                modules += [nn.PixelShuffle(2)]
+        elif scale == 3:
+            modules += [nn.Conv2d(n_channels, 9*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
+            #modules += [nn.Conv2d(n_channels, 9*n_channels, 3, 1, 1, groups=group)]
+            modules += [nn.PixelShuffle(3)]
+
+        self.body = nn.Sequential(*modules)
+        init_weights(self.modules)
+
+    def forward(self, x):
+        out = self.body(x)
+        return out
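As a quick sanity check of the PixelShuffle path above (illustrative; the shapes follow from the conv channel math, e.g. 64 to 9*64 channels, which PixelShuffle(3) trades for a 3x larger spatial grid):

# Illustrative shape check for the x3 upsampling branch.
import torch
from model.ops import UpsampleBlock

up = UpsampleBlock(64, scale=3, multi_scale=False, group=1)
x = torch.rand(1, 64, 48, 48)
y = up(x, scale=3)
print(y.shape)  # torch.Size([1, 64, 144, 144])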