artelabsuper commited on
Commit
eba1c6b
·
1 Parent(s): 1513566

track and test model

Browse files
.gitattributes CHANGED
@@ -1,27 +1,5 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ftz filter=lfs diff=lfs merge=lfs -text
6
- *.gz filter=lfs diff=lfs merge=lfs -text
7
- *.h5 filter=lfs diff=lfs merge=lfs -text
8
- *.joblib filter=lfs diff=lfs merge=lfs -text
9
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
- *.model filter=lfs diff=lfs merge=lfs -text
11
- *.msgpack filter=lfs diff=lfs merge=lfs -text
12
- *.onnx filter=lfs diff=lfs merge=lfs -text
13
- *.ot filter=lfs diff=lfs merge=lfs -text
14
- *.parquet filter=lfs diff=lfs merge=lfs -text
15
- *.pb filter=lfs diff=lfs merge=lfs -text
16
- *.pt filter=lfs diff=lfs merge=lfs -text
17
- *.pth filter=lfs diff=lfs merge=lfs -text
18
- *.rar filter=lfs diff=lfs merge=lfs -text
19
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
- *.tar.* filter=lfs diff=lfs merge=lfs -text
21
- *.tflite filter=lfs diff=lfs merge=lfs -text
22
- *.tgz filter=lfs diff=lfs merge=lfs -text
23
- *.wasm filter=lfs diff=lfs merge=lfs -text
24
- *.xz filter=lfs diff=lfs merge=lfs -text
25
- *.zip filter=lfs diff=lfs merge=lfs -text
26
- *.zstandard filter=lfs diff=lfs merge=lfs -text
27
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ DTM_exp_train10%_model_a/d-best.pth filter=lfs diff=lfs merge=lfs -text
2
+ DTM_exp_train10%_model_a/g-best.pth filter=lfs diff=lfs merge=lfs -text
3
+ DTM_exp_train10%_model_c/d-best.pth filter=lfs diff=lfs merge=lfs -text
4
+ DTM_exp_train10%_model_c/g-best.pth filter=lfs diff=lfs merge=lfs -text
5
+ DTM_exp_train10%_model_b/g-best.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore CHANGED
@@ -1,3 +1,5 @@
1
  venv
2
  *.pyc
3
  __pycache__
 
 
 
1
  venv
2
  *.pyc
3
  __pycache__
4
+ sr.png
5
+ test.png
DTM_exp_train10%_model_a/d-best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ecac9857f6f35d88afec8cea345c81ab1c7a8761ddda8225de820c3d3b48f4c
3
+ size 27401785
DTM_exp_train10%_model_a/g-best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0dab4fd03de2f0189e87db212a8fce5f1c7771aea0a23b804a146d9c00097df
3
+ size 61648584
DTM_exp_train10%_model_b/g-best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:939adfbd0619db38fc5720ad1c2cb558a605e8e8aefac4703cfa98bbb651ec93
3
+ size 61648714
DTM_exp_train10%_model_c/d-best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65f96e37fdd9b1bc8c40b7e257a92c9b706f585b791586a50e2ec2463991e059
3
+ size 27401786
DTM_exp_train10%_model_c/g-best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f9652227743f248f3e8aa514c236b8738e11e062c2f270fceec385daedad7c3
3
+ size 49787998
models/modelNetA.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+
15
+ # ==============================================================================
16
+ # File description: Realize the model definition function.
17
+ # ==============================================================================
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torchvision.models as models
22
+ from torch import Tensor
23
+
24
+ __all__ = [
25
+ "ResidualDenseBlock", "ResidualResidualDenseBlock",
26
+ "Discriminator", "Generator",
27
+ "DownSamplingNetwork"
28
+ ]
29
+
30
+
31
+ class ResidualDenseBlock(nn.Module):
32
+ """Achieves densely connected convolutional layers.
33
+ `Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993v5.pdf>` paper.
34
+
35
+ Args:
36
+ channels (int): The number of channels in the input image.
37
+ growths (int): The number of channels that increase in each layer of convolution.
38
+ """
39
+
40
+ def __init__(self, channels: int, growths: int) -> None:
41
+ super(ResidualDenseBlock, self).__init__()
42
+ self.conv1 = nn.Conv2d(channels + growths * 0, growths, (3, 3), (1, 1), (1, 1))
43
+ self.conv2 = nn.Conv2d(channels + growths * 1, growths, (3, 3), (1, 1), (1, 1))
44
+ self.conv3 = nn.Conv2d(channels + growths * 2, growths, (3, 3), (1, 1), (1, 1))
45
+ self.conv4 = nn.Conv2d(channels + growths * 3, growths, (3, 3), (1, 1), (1, 1))
46
+ self.conv5 = nn.Conv2d(channels + growths * 4, channels, (3, 3), (1, 1), (1, 1))
47
+
48
+ self.leaky_relu = nn.LeakyReLU(0.2, True)
49
+ self.identity = nn.Identity()
50
+
51
+ def forward(self, x: Tensor) -> Tensor:
52
+ identity = x
53
+
54
+ out1 = self.leaky_relu(self.conv1(x))
55
+ out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1)))
56
+ out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1)))
57
+ out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1)))
58
+ out5 = self.identity(self.conv5(torch.cat([x, out1, out2, out3, out4], 1)))
59
+ out = out5 * 0.2 + identity
60
+
61
+ return out
62
+
63
+
64
+
65
+ class ResidualDenseBlock(nn.Module):
66
+ """Achieves densely connected convolutional layers.
67
+ `Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993v5.pdf>` paper.
68
+
69
+ Args:
70
+ channels (int): The number of channels in the input image.
71
+ growths (int): The number of channels that increase in each layer of convolution.
72
+ """
73
+
74
+ def __init__(self, channels: int, growths: int) -> None:
75
+ super(ResidualDenseBlock, self).__init__()
76
+ self.conv1 = nn.Conv2d(channels + growths * 0, growths, (3, 3), (1, 1), (1, 1))
77
+ self.conv2 = nn.Conv2d(channels + growths * 1, growths, (3, 3), (1, 1), (1, 1))
78
+ self.conv3 = nn.Conv2d(channels + growths * 2, growths, (3, 3), (1, 1), (1, 1))
79
+ self.conv4 = nn.Conv2d(channels + growths * 3, growths, (3, 3), (1, 1), (1, 1))
80
+ self.conv5 = nn.Conv2d(channels + growths * 4, channels, (3, 3), (1, 1), (1, 1))
81
+
82
+ self.leaky_relu = nn.LeakyReLU(0.2, True)
83
+ self.identity = nn.Identity()
84
+
85
+ def forward(self, x: Tensor) -> Tensor:
86
+ identity = x
87
+
88
+ out1 = self.leaky_relu(self.conv1(x))
89
+ out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1)))
90
+ out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1)))
91
+ out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1)))
92
+ out5 = self.identity(self.conv5(torch.cat([x, out1, out2, out3, out4], 1)))
93
+ out = out5 * 0.2 + identity
94
+
95
+ return out
96
+
97
+
98
+
99
+ class MiniResidualDenseBlock(nn.Module):
100
+ """Achieves densely connected convolutional layers.
101
+ `Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993v5.pdf>` paper.
102
+
103
+ Args:
104
+ channels (int): The number of channels in the input image.
105
+ growths (int): The number of channels that increase in each layer of convolution.
106
+ """
107
+
108
+ def __init__(self, channels: int, growths: int) -> None:
109
+ super(MiniResidualDenseBlock, self).__init__()
110
+ self.conv1 = nn.Conv2d(channels + growths * 0, growths, (3, 3), (1, 1), (1, 1))
111
+ self.conv2 = nn.Conv2d(channels + growths * 1, growths, (3, 3), (1, 1), (1, 1))
112
+ self.conv3 = nn.Conv2d(channels + growths * 2, growths, (3, 3), (1, 1), (1, 1))
113
+ self.conv4 = nn.Conv2d(channels + growths * 3, growths, (3, 3), (1, 1), (1, 1))
114
+ self.conv5 = nn.Conv2d(channels + growths * 4, channels, (3, 3), (1, 1), (1, 1))
115
+
116
+ self.leaky_relu = nn.LeakyReLU(0.2, True)
117
+
118
+ def forward(self, x: Tensor) -> Tensor:
119
+ identity = x
120
+
121
+ out1 = self.leaky_relu(self.conv1(x))
122
+ out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1)))
123
+ out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1)))
124
+ out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1)))
125
+ out5 = self.leaky_relu(self.conv5(torch.cat([x, out1, out2, out3, out4], 1)))
126
+ out = out5 * 0.2 + identity
127
+
128
+ return out
129
+
130
+
131
+
132
+ class ResidualResidualDenseBlock(nn.Module):
133
+ """Multi-layer residual dense convolution block.
134
+
135
+ Args:
136
+ channels (int): The number of channels in the input image.
137
+ growths (int): The number of channels that increase in each layer of convolution.
138
+ """
139
+
140
+ def __init__(self, channels: int, growths: int) -> None:
141
+ super(ResidualResidualDenseBlock, self).__init__()
142
+ self.rdb1 = ResidualDenseBlock(channels, growths)
143
+ self.rdb2 = ResidualDenseBlock(channels, growths)
144
+ self.rdb3 = ResidualDenseBlock(channels, growths)
145
+
146
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
147
+ identity = x
148
+
149
+ out = self.rdb1(x)
150
+ out = self.rdb2(out)
151
+ out = self.rdb3(out)
152
+ out = out * 0.2 + identity
153
+
154
+ return out
155
+
156
+
157
+ class MiniResidualResidualDenseBlock(nn.Module):
158
+ """Multi-layer residual dense convolution block.
159
+
160
+ Args:
161
+ channels (int): The number of channels in the input image.
162
+ growths (int): The number of channels that increase in each layer of convolution.
163
+ """
164
+
165
+ def __init__(self, channels: int, growths: int) -> None:
166
+ super(MiniResidualResidualDenseBlock, self).__init__()
167
+ self.M_rdb1 = MiniResidualDenseBlock(channels, growths)
168
+ self.M_rdb2 = MiniResidualDenseBlock(channels, growths)
169
+ self.M_rdb3 = MiniResidualDenseBlock(channels, growths)
170
+
171
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
172
+ identity = x
173
+ out = self.M_rdb1(x)
174
+ out = self.M_rdb2(out)
175
+ out = self.M_rdb3(out)
176
+ out = out * 0.2 + identity
177
+ return out
178
+
179
+
180
+
181
+ class Discriminator(nn.Module):
182
+ def __init__(self) -> None:
183
+ super(Discriminator, self).__init__()
184
+ self.features = nn.Sequential(
185
+ # input size. (3) x 512 x 512
186
+ nn.Conv2d(2, 32, (3, 3), (1, 1), (1, 1), bias=True),
187
+ nn.LeakyReLU(0.2, True),
188
+ nn.Conv2d(32, 64, (4, 4), (2, 2), (1, 1), bias=False),
189
+ nn.BatchNorm2d(64),
190
+ nn.LeakyReLU(0.2, True),
191
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), bias=False),
192
+ nn.BatchNorm2d(64),
193
+ nn.LeakyReLU(0.2, True),
194
+ # state size. (128) x 256 x 256
195
+ nn.Conv2d(64, 128, (4, 4), (2, 2), (1, 1), bias=False),
196
+ nn.BatchNorm2d(128),
197
+ nn.LeakyReLU(0.2, True),
198
+ nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), bias=False),
199
+ nn.BatchNorm2d(128),
200
+ nn.LeakyReLU(0.2, True),
201
+ # state size. (256) x 64 x 64
202
+ nn.Conv2d(128, 256, (4, 4), (2, 2), (1, 1), bias=False),
203
+ nn.BatchNorm2d(256),
204
+ nn.LeakyReLU(0.2, True),
205
+ nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), bias=False),
206
+ nn.BatchNorm2d(256),
207
+ nn.LeakyReLU(0.2, True),
208
+ nn.Conv2d(256, 256, (4, 4), (2, 2), (1, 1), bias=False),
209
+ nn.BatchNorm2d(256),
210
+ nn.LeakyReLU(0.2, True),
211
+ nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), bias=False),
212
+ nn.BatchNorm2d(256),
213
+ nn.LeakyReLU(0.2, True),
214
+ # state size. (512) x 16 x 16
215
+ nn.Conv2d(256, 256, (4, 4), (2, 2), (1, 1), bias=False),
216
+ nn.BatchNorm2d(256),
217
+ nn.LeakyReLU(0.2, True),
218
+
219
+ nn.Conv2d(256, 256, (4, 4), (2, 2), (1, 1), bias=False),
220
+ nn.BatchNorm2d(256),
221
+ nn.LeakyReLU(0.2, True),
222
+ # state size. (512) x 8 x 8
223
+ )
224
+
225
+ self.classifier = nn.Sequential(
226
+ nn.Linear(256 * 8 * 8, 100),
227
+ nn.LeakyReLU(0.2, True),
228
+ nn.Linear(100, 1),
229
+ )
230
+
231
+ def forward(self, x: Tensor) -> Tensor:
232
+ out = self.features(x)
233
+ out = torch.flatten(out, 1)
234
+ out = self.classifier(out)
235
+ return out
236
+
237
+ class Generator(nn.Module):
238
+ def __init__(self) -> None:
239
+ super(Generator, self).__init__()
240
+ #RLNet
241
+ self.RLNetconv_block1 = nn.Conv2d(1, 64, (3, 3), (1, 1), (1, 1))
242
+ RLNettrunk = []
243
+ for _ in range(4):
244
+ RLNettrunk += [ResidualResidualDenseBlock(64, 32)]
245
+ self.RLNettrunk = nn.Sequential(*RLNettrunk)
246
+ self.RLNetconv_block2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
247
+ self.RLNetconv_block3 = nn.Sequential(
248
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
249
+ nn.LeakyReLU(0.2, True)
250
+ )
251
+ self.RLNetconv_block4 = nn.Sequential(
252
+ nn.Conv2d(64, 1, (3, 3), (1, 1), (1, 1)),
253
+ nn.Tanh()
254
+ )
255
+
256
+ #############################################################################
257
+ #Generator
258
+ self.conv_block1 = nn.Conv2d(1, 64, (3, 3), (1, 1), (1, 1))
259
+
260
+ trunk = []
261
+ for _ in range(16):
262
+ trunk += [ResidualResidualDenseBlock(64, 32)]
263
+ self.trunk = nn.Sequential(*trunk)
264
+
265
+ # After the feature extraction network, reconnect a layer of convolutional blocks.
266
+ self.conv_block2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
267
+
268
+
269
+ # Upsampling convolutional layer.
270
+ self.upsampling = nn.Sequential(
271
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
272
+ nn.LeakyReLU(0.2, True)
273
+ )
274
+
275
+ # Reconnect a layer of convolution block after upsampling.
276
+ self.conv_block3 = nn.Sequential(
277
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
278
+ nn.LeakyReLU(0.2, True)
279
+ )
280
+
281
+ self.conv_block4 = nn.Sequential(
282
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
283
+ #nn.Sigmoid()
284
+ )
285
+
286
+ self.conv_block0_branch0 = nn.Sequential(
287
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
288
+ nn.LeakyReLU(0.2, True),
289
+ nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1)),
290
+ nn.LeakyReLU(0.2, True),
291
+ nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)),
292
+ nn.LeakyReLU(0.2, True),
293
+ nn.Conv2d(128, 64, (3, 3), (1, 1), (1, 1)),
294
+ nn.Tanh()
295
+ )
296
+
297
+ self.conv_block0_branch1 = nn.Sequential(
298
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
299
+ nn.LeakyReLU(0.2, True),
300
+ nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1)),
301
+ nn.LeakyReLU(0.2, True),
302
+ nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)),
303
+ nn.LeakyReLU(0.2, True),
304
+ nn.Conv2d(128, 64, (3, 3), (1, 1), (1, 1)),
305
+ nn.Tanh()
306
+ )
307
+
308
+ self.conv_block1_branch0 = nn.Sequential(
309
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
310
+ nn.LeakyReLU(0.2, True),
311
+ nn.Conv2d(64, 1, (3, 3), (1, 1), (1, 1)),
312
+ #nn.LeakyReLU(0.2, True),
313
+ #nn.Conv2d(32, 1, (3, 3), (1, 1), (1, 1)),
314
+ nn.Sigmoid()
315
+ )
316
+
317
+
318
+
319
+ self.conv_block1_branch1 = nn.Sequential(
320
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
321
+ nn.LeakyReLU(0.2, True),
322
+ nn.Conv2d(64, 1, (3, 3), (1, 1), (1, 1)),
323
+ nn.Sigmoid())
324
+
325
+
326
+
327
+
328
+ def _forward_impl(self, x: Tensor) -> Tensor:
329
+ #RLNet
330
+ out1 = self.RLNetconv_block1(x)
331
+ out = self.RLNettrunk(out1)
332
+ out2 = self.RLNetconv_block2(out)
333
+ out = out1 + out2
334
+ out = self.RLNetconv_block3(out)
335
+ out = self.RLNetconv_block4(out)
336
+ rlNet_out = out + x
337
+
338
+ #Generator
339
+ out1 = self.conv_block1(rlNet_out)
340
+ out = self.trunk(out1)
341
+ out2 = self.conv_block2(out)
342
+ out = out1 + out2
343
+ out = self.upsampling(F.interpolate(out, scale_factor=2, mode="bicubic"))
344
+ out = self.upsampling(F.interpolate(out, scale_factor=2, mode="bicubic"))
345
+ out = self.conv_block3(out)
346
+ #
347
+ out = self.conv_block4(out)
348
+
349
+ #demResidual = out[:, 1:2, :, :]
350
+ #grayResidual = out[:, 0:1, :, :]
351
+
352
+ # out = self.trunkRGB(out_4)
353
+ #
354
+ # out_dem = out[:, 3:4, :, :] * 0.2 + demResidual # DEM images extracted
355
+ # out_rgb = out[:, 0:3, :, :] * 0.2 + rgbResidual # RGB images extracted
356
+
357
+ #ra0
358
+ #out_rgb= rgbResidual + self.conv_block0_branch0(rgbResidual)
359
+
360
+ out_dem = out + self.conv_block0_branch1(out) #out+ tanh()
361
+ out_gray = out + self.conv_block0_branch0(out) #out+ tanh()
362
+
363
+ out_gray = self.conv_block1_branch0(out_gray) #sigmoid()
364
+ out_dem = self.conv_block1_branch1(out_dem) #sigmoid()
365
+
366
+ return out_gray, out_dem, rlNet_out
367
+
368
+
369
+ def forward(self, x: Tensor) -> Tensor:
370
+ return self._forward_impl(x)
371
+
372
+ def _initialize_weights(self) -> None:
373
+ for m in self.modules():
374
+ if isinstance(m, nn.Conv2d):
375
+ nn.init.kaiming_normal_(m.weight)
376
+ if m.bias is not None:
377
+ nn.init.constant_(m.bias, 0)
378
+ m.weight.data *= 0.1
379
+ elif isinstance(m, nn.BatchNorm2d):
380
+ nn.init.constant_(m.weight, 1)
381
+ m.weight.data *= 0.1
models/modelNetB.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import Tensor
5
+
6
+ __all__ = [
7
+ "ResidualDenseBlock", "ResidualResidualDenseBlock", "Generator",
8
+ "DownSamplingNetwork"
9
+ ]
10
+
11
+
12
+ class ResidualDenseBlock(nn.Module):
13
+ """Achieves densely connected convolutional layers.
14
+ `Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993v5.pdf>` paper.
15
+
16
+ Args:
17
+ channels (int): The number of channels in the input image.
18
+ growths (int): The number of channels that increase in each layer of convolution.
19
+ """
20
+
21
+ def __init__(self, channels: int, growths: int) -> None:
22
+ super(ResidualDenseBlock, self).__init__()
23
+ self.conv1 = nn.Conv2d(channels + growths * 0, growths, (3, 3), (1, 1), (1, 1))
24
+ self.conv2 = nn.Conv2d(channels + growths * 1, growths, (3, 3), (1, 1), (1, 1))
25
+ self.conv3 = nn.Conv2d(channels + growths * 2, growths, (3, 3), (1, 1), (1, 1))
26
+ self.conv4 = nn.Conv2d(channels + growths * 3, growths, (3, 3), (1, 1), (1, 1))
27
+ self.conv5 = nn.Conv2d(channels + growths * 4, channels, (3, 3), (1, 1), (1, 1))
28
+
29
+ self.leaky_relu = nn.LeakyReLU(0.2, True)
30
+ self.identity = nn.Identity()
31
+
32
+ def forward(self, x: Tensor) -> Tensor:
33
+ identity = x
34
+
35
+ out1 = self.leaky_relu(self.conv1(x))
36
+ out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1)))
37
+ out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1)))
38
+ out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1)))
39
+ out5 = self.identity(self.conv5(torch.cat([x, out1, out2, out3, out4], 1)))
40
+ out = out5 * 0.2 + identity
41
+
42
+ return out
43
+
44
+
45
+
46
+ class ResidualDenseBlock(nn.Module):
47
+ """Achieves densely connected convolutional layers.
48
+ `Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993v5.pdf>` paper.
49
+
50
+ Args:
51
+ channels (int): The number of channels in the input image.
52
+ growths (int): The number of channels that increase in each layer of convolution.
53
+ """
54
+
55
+ def __init__(self, channels: int, growths: int) -> None:
56
+ super(ResidualDenseBlock, self).__init__()
57
+ self.conv1 = nn.Conv2d(channels + growths * 0, growths, (3, 3), (1, 1), (1, 1))
58
+ self.conv2 = nn.Conv2d(channels + growths * 1, growths, (3, 3), (1, 1), (1, 1))
59
+ self.conv3 = nn.Conv2d(channels + growths * 2, growths, (3, 3), (1, 1), (1, 1))
60
+ self.conv4 = nn.Conv2d(channels + growths * 3, growths, (3, 3), (1, 1), (1, 1))
61
+ self.conv5 = nn.Conv2d(channels + growths * 4, channels, (3, 3), (1, 1), (1, 1))
62
+
63
+ self.leaky_relu = nn.LeakyReLU(0.2, True)
64
+ self.identity = nn.Identity()
65
+
66
+ def forward(self, x: Tensor) -> Tensor:
67
+ identity = x
68
+
69
+ out1 = self.leaky_relu(self.conv1(x))
70
+ out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1)))
71
+ out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1)))
72
+ out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1)))
73
+ out5 = self.identity(self.conv5(torch.cat([x, out1, out2, out3, out4], 1)))
74
+ out = out5 * 0.2 + identity
75
+
76
+ return out
77
+
78
+
79
+
80
+ class MiniResidualDenseBlock(nn.Module):
81
+ """Achieves densely connected convolutional layers.
82
+ `Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993v5.pdf>` paper.
83
+
84
+ Args:
85
+ channels (int): The number of channels in the input image.
86
+ growths (int): The number of channels that increase in each layer of convolution.
87
+ """
88
+
89
+ def __init__(self, channels: int, growths: int) -> None:
90
+ super(MiniResidualDenseBlock, self).__init__()
91
+ self.conv1 = nn.Conv2d(channels + growths * 0, growths, (3, 3), (1, 1), (1, 1))
92
+ self.conv2 = nn.Conv2d(channels + growths * 1, growths, (3, 3), (1, 1), (1, 1))
93
+ self.conv3 = nn.Conv2d(channels + growths * 2, growths, (3, 3), (1, 1), (1, 1))
94
+ self.conv4 = nn.Conv2d(channels + growths * 3, growths, (3, 3), (1, 1), (1, 1))
95
+ self.conv5 = nn.Conv2d(channels + growths * 4, channels, (3, 3), (1, 1), (1, 1))
96
+
97
+ self.leaky_relu = nn.LeakyReLU(0.2, True)
98
+
99
+ def forward(self, x: Tensor) -> Tensor:
100
+ identity = x
101
+
102
+ out1 = self.leaky_relu(self.conv1(x))
103
+ out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1)))
104
+ out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1)))
105
+ out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1)))
106
+ out5 = self.leaky_relu(self.conv5(torch.cat([x, out1, out2, out3, out4], 1)))
107
+ out = out5 * 0.2 + identity
108
+
109
+ return out
110
+
111
+
112
+
113
+ class ResidualResidualDenseBlock(nn.Module):
114
+ """Multi-layer residual dense convolution block.
115
+
116
+ Args:
117
+ channels (int): The number of channels in the input image.
118
+ growths (int): The number of channels that increase in each layer of convolution.
119
+ """
120
+
121
+ def __init__(self, channels: int, growths: int) -> None:
122
+ super(ResidualResidualDenseBlock, self).__init__()
123
+ self.rdb1 = ResidualDenseBlock(channels, growths)
124
+ self.rdb2 = ResidualDenseBlock(channels, growths)
125
+ self.rdb3 = ResidualDenseBlock(channels, growths)
126
+
127
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
128
+ identity = x
129
+
130
+ out = self.rdb1(x)
131
+ out = self.rdb2(out)
132
+ out = self.rdb3(out)
133
+ out = out * 0.2 + identity
134
+
135
+ return out
136
+
137
+
138
+ class MiniResidualResidualDenseBlock(nn.Module):
139
+ """Multi-layer residual dense convolution block.
140
+
141
+ Args:
142
+ channels (int): The number of channels in the input image.
143
+ growths (int): The number of channels that increase in each layer of convolution.
144
+ """
145
+
146
+ def __init__(self, channels: int, growths: int) -> None:
147
+ super(MiniResidualResidualDenseBlock, self).__init__()
148
+ self.M_rdb1 = MiniResidualDenseBlock(channels, growths)
149
+ self.M_rdb2 = MiniResidualDenseBlock(channels, growths)
150
+ self.M_rdb3 = MiniResidualDenseBlock(channels, growths)
151
+
152
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
153
+ identity = x
154
+ out = self.M_rdb1(x)
155
+ out = self.M_rdb2(out)
156
+ out = self.M_rdb3(out)
157
+ out = out * 0.2 + identity
158
+ return out
159
+
160
+
161
+ class Generator(nn.Module):
162
+ def __init__(self) -> None:
163
+ super(Generator, self).__init__()
164
+
165
+ #RLNet
166
+ self.RLNetconv_block1 = nn.Conv2d(1, 64, (3, 3), (1, 1), (1, 1))
167
+ RLNettrunk = []
168
+ for _ in range(4):
169
+ RLNettrunk += [ResidualResidualDenseBlock(64, 32)]
170
+ self.RLNettrunk = nn.Sequential(*RLNettrunk)
171
+ self.RLNetconv_block2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
172
+ self.RLNetconv_block3 = nn.Sequential(
173
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
174
+ nn.LeakyReLU(0.2, True)
175
+ )
176
+ self.RLNetconv_block4 = nn.Sequential(
177
+ nn.Conv2d(64, 1, (3, 3), (1, 1), (1, 1)),
178
+ nn.Tanh()
179
+ )
180
+
181
+ #############################################################################
182
+ # Generator
183
+ self.conv_block1 = nn.Conv2d(1, 64, (3, 3), (1, 1), (1, 1))
184
+ trunk = []
185
+ for _ in range(16):
186
+ trunk += [ResidualResidualDenseBlock(64, 32)]
187
+ self.trunk = nn.Sequential(*trunk)
188
+
189
+ # After the feature extraction network, reconnect a layer of convolutional blocks.
190
+ self.conv_block2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
191
+
192
+
193
+ # Upsampling convolutional layer.
194
+ self.upsampling = nn.Sequential(
195
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
196
+ nn.LeakyReLU(0.2, True)
197
+ )
198
+
199
+ # Reconnect a layer of convolution block after upsampling.
200
+ self.conv_block3 = nn.Sequential(
201
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
202
+ nn.LeakyReLU(0.2, True)
203
+ )
204
+
205
+ self.conv_block4 = nn.Sequential(
206
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
207
+ #nn.Sigmoid()
208
+ )
209
+
210
+ self.conv_block0_branch0 = nn.Sequential(
211
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
212
+ nn.LeakyReLU(0.2, True),
213
+ nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1)),
214
+ nn.LeakyReLU(0.2, True),
215
+ nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)),
216
+ nn.LeakyReLU(0.2, True),
217
+ nn.Conv2d(128, 64, (3, 3), (1, 1), (1, 1)),
218
+ nn.Tanh()
219
+ )
220
+
221
+ self.conv_block0_branch1 = nn.Sequential(
222
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
223
+ nn.LeakyReLU(0.2, True),
224
+ nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1)),
225
+ nn.LeakyReLU(0.2, True),
226
+ nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)),
227
+ nn.LeakyReLU(0.2, True),
228
+ nn.Conv2d(128, 64, (3, 3), (1, 1), (1, 1)),
229
+ nn.Tanh()
230
+ )
231
+
232
+ self.conv_block1_branch0 = nn.Sequential(
233
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
234
+ nn.LeakyReLU(0.2, True),
235
+ nn.Conv2d(64, 1, (3, 3), (1, 1), (1, 1)),
236
+ #nn.LeakyReLU(0.2, True),
237
+ #nn.Conv2d(32, 1, (3, 3), (1, 1), (1, 1)),
238
+ nn.Sigmoid()
239
+ )
240
+
241
+
242
+
243
+ self.conv_block1_branch1 = nn.Sequential(
244
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
245
+ nn.LeakyReLU(0.2, True),
246
+ nn.Conv2d(64, 1, (3, 3), (1, 1), (1, 1)),
247
+ nn.Sigmoid())
248
+
249
+
250
+
251
+
252
+ def _forward_impl(self, x: Tensor) -> Tensor:
253
+ #RLNet
254
+ out1 = self.RLNetconv_block1(x)
255
+ out = self.RLNettrunk(out1)
256
+ out2 = self.RLNetconv_block2(out)
257
+ out = out1 + out2
258
+ out = self.RLNetconv_block3(out)
259
+ out = self.RLNetconv_block4(out)
260
+ rlNet_out = out + x
261
+
262
+ #Generator
263
+ out1 = self.conv_block1(rlNet_out)
264
+ out = self.trunk(out1)
265
+ out2 = self.conv_block2(out)
266
+ out = out1 + out2
267
+ out = self.upsampling(F.interpolate(out, scale_factor=2, mode="bicubic"))
268
+ out = self.upsampling(F.interpolate(out, scale_factor=2, mode="bicubic"))
269
+ out = self.conv_block3(out)
270
+ #
271
+ out = self.conv_block4(out)
272
+
273
+ #demResidual = out[:, 1:2, :, :]
274
+ #grayResidual = out[:, 0:1, :, :]
275
+
276
+ # out = self.trunkRGB(out_4)
277
+ #
278
+ # out_dem = out[:, 3:4, :, :] * 0.2 + demResidual # DEM images extracted
279
+ # out_rgb = out[:, 0:3, :, :] * 0.2 + rgbResidual # RGB images extracted
280
+
281
+ #ra0
282
+ #out_rgb= rgbResidual + self.conv_block0_branch0(rgbResidual)
283
+
284
+ out_dem = out + self.conv_block0_branch1(out) #out+ tanh()
285
+ out_gray = out + self.conv_block0_branch0(out) #out+ tanh()
286
+
287
+ out_gray = self.conv_block1_branch0(out_gray) #sigmoid()
288
+ out_dem = self.conv_block1_branch1(out_dem) #sigmoid()
289
+
290
+ return out_gray, out_dem, rlNet_out
291
+
292
+
293
+ def forward(self, x: Tensor) -> Tensor:
294
+ return self._forward_impl(x)
295
+
296
+ def _initialize_weights(self) -> None:
297
+ for m in self.modules():
298
+ if isinstance(m, nn.Conv2d):
299
+ nn.init.kaiming_normal_(m.weight)
300
+ if m.bias is not None:
301
+ nn.init.constant_(m.bias, 0)
302
+ m.weight.data *= 0.1
303
+ elif isinstance(m, nn.BatchNorm2d):
304
+ nn.init.constant_(m.weight, 1)
305
+ m.weight.data *= 0.1
306
+
307
+
models/modelNetC.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import Tensor
5
+
6
+ __all__ = [
7
+ "ResidualDenseBlock", "ResidualResidualDenseBlock", "Generator",
8
+ "DownSamplingNetwork"
9
+ ]
10
+
11
+
12
+ class ResidualDenseBlock(nn.Module):
13
+ """Achieves densely connected convolutional layers.
14
+ `Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993v5.pdf>` paper.
15
+
16
+ Args:
17
+ channels (int): The number of channels in the input image.
18
+ growths (int): The number of channels that increase in each layer of convolution.
19
+ """
20
+
21
+ def __init__(self, channels: int, growths: int) -> None:
22
+ super(ResidualDenseBlock, self).__init__()
23
+ self.conv1 = nn.Conv2d(channels + growths * 0, growths, (3, 3), (1, 1), (1, 1))
24
+ self.conv2 = nn.Conv2d(channels + growths * 1, growths, (3, 3), (1, 1), (1, 1))
25
+ self.conv3 = nn.Conv2d(channels + growths * 2, growths, (3, 3), (1, 1), (1, 1))
26
+ self.conv4 = nn.Conv2d(channels + growths * 3, growths, (3, 3), (1, 1), (1, 1))
27
+ self.conv5 = nn.Conv2d(channels + growths * 4, channels, (3, 3), (1, 1), (1, 1))
28
+
29
+ self.leaky_relu = nn.LeakyReLU(0.2, True)
30
+ self.identity = nn.Identity()
31
+
32
+ def forward(self, x: Tensor) -> Tensor:
33
+ identity = x
34
+
35
+ out1 = self.leaky_relu(self.conv1(x))
36
+ out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1)))
37
+ out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1)))
38
+ out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1)))
39
+ out5 = self.identity(self.conv5(torch.cat([x, out1, out2, out3, out4], 1)))
40
+ out = out5 * 0.2 + identity
41
+
42
+ return out
43
+
44
+
45
+
46
+ class ResidualDenseBlock(nn.Module):
47
+ """Achieves densely connected convolutional layers.
48
+ `Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993v5.pdf>` paper.
49
+
50
+ Args:
51
+ channels (int): The number of channels in the input image.
52
+ growths (int): The number of channels that increase in each layer of convolution.
53
+ """
54
+
55
+ def __init__(self, channels: int, growths: int) -> None:
56
+ super(ResidualDenseBlock, self).__init__()
57
+ self.conv1 = nn.Conv2d(channels + growths * 0, growths, (3, 3), (1, 1), (1, 1))
58
+ self.conv2 = nn.Conv2d(channels + growths * 1, growths, (3, 3), (1, 1), (1, 1))
59
+ self.conv3 = nn.Conv2d(channels + growths * 2, growths, (3, 3), (1, 1), (1, 1))
60
+ self.conv4 = nn.Conv2d(channels + growths * 3, growths, (3, 3), (1, 1), (1, 1))
61
+ self.conv5 = nn.Conv2d(channels + growths * 4, channels, (3, 3), (1, 1), (1, 1))
62
+
63
+ self.leaky_relu = nn.LeakyReLU(0.2, True)
64
+ self.identity = nn.Identity()
65
+
66
+ def forward(self, x: Tensor) -> Tensor:
67
+ identity = x
68
+
69
+ out1 = self.leaky_relu(self.conv1(x))
70
+ out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1)))
71
+ out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1)))
72
+ out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1)))
73
+ out5 = self.identity(self.conv5(torch.cat([x, out1, out2, out3, out4], 1)))
74
+ out = out5 * 0.2 + identity
75
+
76
+ return out
77
+
78
+
79
+
80
+ class MiniResidualDenseBlock(nn.Module):
81
+ """Achieves densely connected convolutional layers.
82
+ `Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993v5.pdf>` paper.
83
+
84
+ Args:
85
+ channels (int): The number of channels in the input image.
86
+ growths (int): The number of channels that increase in each layer of convolution.
87
+ """
88
+
89
+ def __init__(self, channels: int, growths: int) -> None:
90
+ super(MiniResidualDenseBlock, self).__init__()
91
+ self.conv1 = nn.Conv2d(channels + growths * 0, growths, (3, 3), (1, 1), (1, 1))
92
+ self.conv2 = nn.Conv2d(channels + growths * 1, growths, (3, 3), (1, 1), (1, 1))
93
+ self.conv3 = nn.Conv2d(channels + growths * 2, growths, (3, 3), (1, 1), (1, 1))
94
+ self.conv4 = nn.Conv2d(channels + growths * 3, growths, (3, 3), (1, 1), (1, 1))
95
+ self.conv5 = nn.Conv2d(channels + growths * 4, channels, (3, 3), (1, 1), (1, 1))
96
+
97
+ self.leaky_relu = nn.LeakyReLU(0.2, True)
98
+
99
+ def forward(self, x: Tensor) -> Tensor:
100
+ identity = x
101
+
102
+ out1 = self.leaky_relu(self.conv1(x))
103
+ out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1)))
104
+ out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1)))
105
+ out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1)))
106
+ out5 = self.leaky_relu(self.conv5(torch.cat([x, out1, out2, out3, out4], 1)))
107
+ out = out5 * 0.2 + identity
108
+
109
+ return out
110
+
111
+
112
+
113
+ class ResidualResidualDenseBlock(nn.Module):
114
+ """Multi-layer residual dense convolution block.
115
+
116
+ Args:
117
+ channels (int): The number of channels in the input image.
118
+ growths (int): The number of channels that increase in each layer of convolution.
119
+ """
120
+
121
+ def __init__(self, channels: int, growths: int) -> None:
122
+ super(ResidualResidualDenseBlock, self).__init__()
123
+ self.rdb1 = ResidualDenseBlock(channels, growths)
124
+ self.rdb2 = ResidualDenseBlock(channels, growths)
125
+ self.rdb3 = ResidualDenseBlock(channels, growths)
126
+
127
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
128
+ identity = x
129
+
130
+ out = self.rdb1(x)
131
+ out = self.rdb2(out)
132
+ out = self.rdb3(out)
133
+ out = out * 0.2 + identity
134
+
135
+ return out
136
+
137
+
138
+ class MiniResidualResidualDenseBlock(nn.Module):
139
+ """Multi-layer residual dense convolution block.
140
+
141
+ Args:
142
+ channels (int): The number of channels in the input image.
143
+ growths (int): The number of channels that increase in each layer of convolution.
144
+ """
145
+
146
+ def __init__(self, channels: int, growths: int) -> None:
147
+ super(MiniResidualResidualDenseBlock, self).__init__()
148
+ self.M_rdb1 = MiniResidualDenseBlock(channels, growths)
149
+ self.M_rdb2 = MiniResidualDenseBlock(channels, growths)
150
+ self.M_rdb3 = MiniResidualDenseBlock(channels, growths)
151
+
152
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
153
+ identity = x
154
+ out = self.M_rdb1(x)
155
+ out = self.M_rdb2(out)
156
+ out = self.M_rdb3(out)
157
+ out = out * 0.2 + identity
158
+ return out
159
+
160
+
161
+ class Generator(nn.Module):
162
+ def __init__(self) -> None:
163
+ super(Generator, self).__init__()
164
+ # Generator
165
+ self.conv_block1 = nn.Conv2d(1, 64, (3, 3), (1, 1), (1, 1))
166
+ trunk = []
167
+ for _ in range(16):
168
+ trunk += [ResidualResidualDenseBlock(64, 32)]
169
+ self.trunk = nn.Sequential(*trunk)
170
+
171
+ # After the feature extraction network, reconnect a layer of convolutional blocks.
172
+ self.conv_block2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
173
+
174
+
175
+ # Upsampling convolutional layer.
176
+ self.upsampling = nn.Sequential(
177
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
178
+ nn.LeakyReLU(0.2, True)
179
+ )
180
+
181
+ # Reconnect a layer of convolution block after upsampling.
182
+ self.conv_block3 = nn.Sequential(
183
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
184
+ nn.LeakyReLU(0.2, True)
185
+ )
186
+
187
+ self.conv_block4 = nn.Sequential(
188
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
189
+ #nn.Sigmoid()
190
+ )
191
+
192
+ self.conv_block0_branch0 = nn.Sequential(
193
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
194
+ nn.LeakyReLU(0.2, True),
195
+ nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1)),
196
+ nn.LeakyReLU(0.2, True),
197
+ nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)),
198
+ nn.LeakyReLU(0.2, True),
199
+ nn.Conv2d(128, 64, (3, 3), (1, 1), (1, 1)),
200
+ nn.Tanh()
201
+ )
202
+
203
+ self.conv_block0_branch1 = nn.Sequential(
204
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
205
+ nn.LeakyReLU(0.2, True),
206
+ nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1)),
207
+ nn.LeakyReLU(0.2, True),
208
+ nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)),
209
+ nn.LeakyReLU(0.2, True),
210
+ nn.Conv2d(128, 64, (3, 3), (1, 1), (1, 1)),
211
+ nn.Tanh()
212
+ )
213
+
214
+ self.conv_block1_branch0 = nn.Sequential(
215
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
216
+ nn.LeakyReLU(0.2, True),
217
+ nn.Conv2d(64, 1, (3, 3), (1, 1), (1, 1)),
218
+ #nn.LeakyReLU(0.2, True),
219
+ #nn.Conv2d(32, 1, (3, 3), (1, 1), (1, 1)),
220
+ nn.Sigmoid()
221
+ )
222
+
223
+
224
+
225
+ self.conv_block1_branch1 = nn.Sequential(
226
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
227
+ nn.LeakyReLU(0.2, True),
228
+ nn.Conv2d(64, 1, (3, 3), (1, 1), (1, 1)),
229
+ nn.Sigmoid())
230
+
231
+
232
+
233
+
234
+ def _forward_impl(self, x: Tensor) -> Tensor:
235
+ #Generator
236
+ out1 = self.conv_block1(x)
237
+ out = self.trunk(out1)
238
+ out2 = self.conv_block2(out)
239
+ out = out1 + out2
240
+ out = self.upsampling(F.interpolate(out, scale_factor=2, mode="bicubic"))
241
+ out = self.upsampling(F.interpolate(out, scale_factor=2, mode="bicubic"))
242
+ out = self.conv_block3(out)
243
+ #
244
+ out = self.conv_block4(out)
245
+
246
+ #demResidual = out[:, 1:2, :, :]
247
+ #grayResidual = out[:, 0:1, :, :]
248
+
249
+ # out = self.trunkRGB(out_4)
250
+ #
251
+ # out_dem = out[:, 3:4, :, :] * 0.2 + demResidual # DEM images extracted
252
+ # out_rgb = out[:, 0:3, :, :] * 0.2 + rgbResidual # RGB images extracted
253
+
254
+ #ra0
255
+ #out_rgb= rgbResidual + self.conv_block0_branch0(rgbResidual)
256
+
257
+ out_dem = out + self.conv_block0_branch1(out) #out+ tanh()
258
+ out_gray = out + self.conv_block0_branch0(out) #out+ tanh()
259
+
260
+ out_gray = self.conv_block1_branch0(out_gray) #sigmoid()
261
+ out_dem = self.conv_block1_branch1(out_dem) #sigmoid()
262
+
263
+ return out_gray, out_dem
264
+
265
+
266
+ def forward(self, x: Tensor) -> Tensor:
267
+ return self._forward_impl(x)
268
+
269
+ def _initialize_weights(self) -> None:
270
+ for m in self.modules():
271
+ if isinstance(m, nn.Conv2d):
272
+ nn.init.kaiming_normal_(m.weight)
273
+ if m.bias is not None:
274
+ nn.init.constant_(m.bias, 0)
275
+ m.weight.data *= 0.1
276
+ elif isinstance(m, nn.BatchNorm2d):
277
+ nn.init.constant_(m.weight, 1)
278
+ m.weight.data *= 0.1
279
+
280
+ class Discriminator(nn.Module):
281
+ def __init__(self) -> None:
282
+ super(Discriminator, self).__init__()
283
+ self.features = nn.Sequential(
284
+ # input size. (3) x 512 x 512
285
+ nn.Conv2d(2, 32, (3, 3), (1, 1), (1, 1), bias=True),
286
+ nn.LeakyReLU(0.2, True),
287
+ nn.Conv2d(32, 64, (4, 4), (2, 2), (1, 1), bias=False),
288
+ nn.BatchNorm2d(64),
289
+ nn.LeakyReLU(0.2, True),
290
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), bias=False),
291
+ nn.BatchNorm2d(64),
292
+ nn.LeakyReLU(0.2, True),
293
+ # state size. (128) x 256 x 256
294
+ nn.Conv2d(64, 128, (4, 4), (2, 2), (1, 1), bias=False),
295
+ nn.BatchNorm2d(128),
296
+ nn.LeakyReLU(0.2, True),
297
+ nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), bias=False),
298
+ nn.BatchNorm2d(128),
299
+ nn.LeakyReLU(0.2, True),
300
+ # state size. (256) x 64 x 64
301
+ nn.Conv2d(128, 256, (4, 4), (2, 2), (1, 1), bias=False),
302
+ nn.BatchNorm2d(256),
303
+ nn.LeakyReLU(0.2, True),
304
+ nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), bias=False),
305
+ nn.BatchNorm2d(256),
306
+ nn.LeakyReLU(0.2, True),
307
+ nn.Conv2d(256, 256, (4, 4), (2, 2), (1, 1), bias=False),
308
+ nn.BatchNorm2d(256),
309
+ nn.LeakyReLU(0.2, True),
310
+ nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), bias=False),
311
+ nn.BatchNorm2d(256),
312
+ nn.LeakyReLU(0.2, True),
313
+ # state size. (512) x 16 x 16
314
+ nn.Conv2d(256, 256, (4, 4), (2, 2), (1, 1), bias=False),
315
+ nn.BatchNorm2d(256),
316
+ nn.LeakyReLU(0.2, True),
317
+
318
+ nn.Conv2d(256, 256, (4, 4), (2, 2), (1, 1), bias=False),
319
+ nn.BatchNorm2d(256),
320
+ nn.LeakyReLU(0.2, True),
321
+ # state size. (512) x 8 x 8
322
+ )
323
+
324
+ self.classifier = nn.Sequential(
325
+ nn.Linear(256 * 8 * 8, 100),
326
+ nn.LeakyReLU(0.2, True),
327
+ nn.Linear(100, 1),
328
+ )
329
+
330
+ def forward(self, x: Tensor) -> Tensor:
331
+ out = self.features(x)
332
+ out = torch.flatten(out, 1)
333
+ out = self.classifier(out)
334
+ return out
335
+
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  gradio
2
  torch
3
- torchvision
 
1
+ matplotlib
2
  gradio
3
  torch
4
+ torchvision
test.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from models.modelNetA import Generator as GA
8
+ from models.modelNetB import Generator as GB
9
+ from models.modelNetC import Generator as GC
10
+
11
+
12
+
13
+ DEVICE='cpu'
14
+ model_type = 'model_b'
15
+
16
+ modeltype2path = {
17
+ 'model_a': 'DTM_exp_train10%_model_a/g-best.pth',
18
+ 'model_b': 'DTM_exp_train10%_model_b/g-best.pth',
19
+ 'model_c': 'DTM_exp_train10%_model_c/g-best.pth',
20
+ }
21
+
22
+ if model_type == 'model_a':
23
+ generator = GA()
24
+ if model_type == 'model_b':
25
+ generator = GB()
26
+ if model_type == 'model_c':
27
+ generator = GC()
28
+
29
+ generator = torch.nn.DataParallel(generator)
30
+ state_dict_Gen = torch.load(modeltype2path[model_type], map_location=torch.device('cpu'))
31
+ generator.load_state_dict(state_dict_Gen)
32
+ generator = generator.module.to(DEVICE)
33
+ # generator.to(DEVICE)
34
+ generator.eval()
35
+
36
+ preprocess = transforms.Compose([
37
+ transforms.Grayscale(),
38
+ transforms.Resize((512, 512)),
39
+ transforms.ToTensor()
40
+ ])
41
+ input_img = Image.open('demo_imgs/fake.jpg')
42
+ torch_img = preprocess(input_img).to(DEVICE).unsqueeze(0).to(DEVICE)
43
+ with torch.no_grad():
44
+ output = generator(torch_img)
45
+ sr, sr_dem_selected = output[0], output[1]
46
+ sr = sr.squeeze(0).cpu()
47
+
48
+ print(sr.shape)
49
+ torchvision.utils.save_image(sr, 'sr.png')
50
+
51
+ sr_dem_selected = sr_dem_selected.squeeze().cpu().detach().numpy()
52
+ print(sr_dem_selected.shape)
53
+ plt.imshow(sr_dem_selected, cmap='jet', vmin=0, vmax=np.max(sr_dem_selected))
54
+ plt.colorbar()
55
+ plt.savefig('test.png')