Upload 7 files

- deepfillv2/LICENSE +21 -0
- deepfillv2/__init__.py +1 -0
- deepfillv2/network.py +666 -0
- deepfillv2/network_module.py +596 -0
- deepfillv2/network_utils.py +79 -0
- deepfillv2/test_dataset.py +47 -0
- deepfillv2/utils.py +145 -0
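For context, the sketch below shows one way the uploaded modules could be wired together for a single inference pass. It is a minimal, untested sketch: the option values in `opt` are assumptions, and a `config` module defining `INIMAGE`, `MASKIMAGE`, `RESIZE_TO` and `GPU_DEVICE` (imported by `test_dataset.py` and `utils.py`) must exist but is not part of this upload.

import torch
from argparse import Namespace

from deepfillv2 import utils
from deepfillv2.test_dataset import InpaintDataset

# Hypothetical option values -- the real training/inference config is not included here.
opt = Namespace(
    in_channels=4, out_channels=3, latent_channels=64,
    pad_type="zero", activation="lrelu", norm="in",
    init_type="kaiming", init_gain=0.02, use_cuda=False,
)

generator = utils.create_generator(opt).eval()  # randomly initialized; load a checkpoint for real results

dataset = InpaintDataset()                      # reads INIMAGE / MASKIMAGE paths from config
img, mask = dataset[0]                          # img: [3, H, W] in [0, 1], mask: [1, H, W]
img, mask = img.unsqueeze(0), mask.unsqueeze(0)

with torch.no_grad():
    coarse_out, refined_out = generator(img, mask)

# Keep the known pixels and paste the refined prediction into the hole.
result = img * (1 - mask) + refined_out * mask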
deepfillv2/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Qiang Wen
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
deepfillv2/__init__.py
ADDED
@@ -0,0 +1 @@
+
deepfillv2/network.py
ADDED
@@ -0,0 +1,666 @@
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+import torchvision
+
+from deepfillv2.network_module import *
+
+
+def weights_init(net, init_type="kaiming", init_gain=0.02):
+    """Initialize network weights.
+    Parameters:
+        net (network)     -- network to be initialized
+        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+    """
+
+    def init_func(m):
+        classname = m.__class__.__name__
+        if hasattr(m, "weight") and classname.find("Conv") != -1:
+            if init_type == "normal":
+                init.normal_(m.weight.data, 0.0, init_gain)
+            elif init_type == "xavier":
+                init.xavier_normal_(m.weight.data, gain=init_gain)
+            elif init_type == "kaiming":
+                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
+            elif init_type == "orthogonal":
+                init.orthogonal_(m.weight.data, gain=init_gain)
+            else:
+                raise NotImplementedError("initialization method [%s] is not implemented" % init_type)
+        elif classname.find("BatchNorm2d") != -1:
+            init.normal_(m.weight.data, 1.0, 0.02)
+            init.constant_(m.bias.data, 0.0)
+        elif classname.find("Linear") != -1:
+            init.normal_(m.weight, 0, 0.01)
+            init.constant_(m.bias, 0)
+
+    # Apply the initialization function <init_func>
+    net.apply(init_func)
+
+
+# -----------------------------------------------
+# Generator
+# -----------------------------------------------
+# Input: masked image + mask
+# Output: filled image
+class GatedGenerator(nn.Module):
+    def __init__(self, opt):
+        super(GatedGenerator, self).__init__()
+        self.coarse = nn.Sequential(
+            # encoder
+            GatedConv2d(opt.in_channels, opt.latent_channels, 5, 1, 2, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels, opt.latent_channels * 2, 3, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 4, 3, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            # Bottleneck
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 2, dilation=2, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 4, dilation=4, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 8, dilation=8, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 16, dilation=16, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            # decoder
+            TransposeGatedConv2d(opt.latent_channels * 4, opt.latent_channels * 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            TransposeGatedConv2d(opt.latent_channels * 2, opt.latent_channels, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels, opt.latent_channels // 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels // 2, opt.out_channels, 3, 1, 1, pad_type=opt.pad_type, activation="none", norm=opt.norm),
+            nn.Tanh(),
+        )
+
+        self.refine_conv = nn.Sequential(
+            GatedConv2d(opt.in_channels, opt.latent_channels, 5, 1, 2, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels, opt.latent_channels, 3, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels, opt.latent_channels * 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 2, 3, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 2, dilation=2, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 4, dilation=4, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 8, dilation=8, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 16, dilation=16, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+        )
+        self.refine_atten_1 = nn.Sequential(
+            GatedConv2d(opt.in_channels, opt.latent_channels, 5, 1, 2, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels, opt.latent_channels, 3, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels, opt.latent_channels * 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 4, 3, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation="relu", norm=opt.norm),
+        )
+        self.refine_atten_2 = nn.Sequential(
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+        )
+        self.refine_combine = nn.Sequential(
+            GatedConv2d(opt.latent_channels * 8, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            TransposeGatedConv2d(opt.latent_channels * 4, opt.latent_channels * 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            TransposeGatedConv2d(opt.latent_channels * 2, opt.latent_channels, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels, opt.latent_channels // 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
+            GatedConv2d(opt.latent_channels // 2, opt.out_channels, 3, 1, 1, pad_type=opt.pad_type, activation="none", norm=opt.norm),
+            nn.Tanh(),
+        )
+
+        use_cuda = opt.use_cuda
+
+        self.context_attention = ContextualAttention(ksize=3, stride=1, rate=2, fuse_k=3, softmax_scale=10, fuse=True, use_cuda=use_cuda)
+
+    def forward(self, img, mask):
+        # img: entire img
+        # mask: 1 for mask region; 0 for unmask region
+        # Coarse
+        first_masked_img = img * (1 - mask) + mask
+        first_in = torch.cat((first_masked_img, mask), dim=1)  # in: [B, 4, H, W]
+        first_out = self.coarse(first_in)  # out: [B, 3, H, W]
+        first_out = nn.functional.interpolate(first_out, (img.shape[2], img.shape[3]), recompute_scale_factor=False)
+        # Refinement
+        second_masked_img = img * (1 - mask) + first_out * mask
+        second_in = torch.cat([second_masked_img, mask], dim=1)
+        refine_conv = self.refine_conv(second_in)
+        refine_atten = self.refine_atten_1(second_in)
+        mask_s = nn.functional.interpolate(mask, (refine_atten.shape[2], refine_atten.shape[3]), recompute_scale_factor=False)
+        refine_atten = self.context_attention(refine_atten, refine_atten, mask_s)
+        refine_atten = self.refine_atten_2(refine_atten)
+        second_out = torch.cat([refine_conv, refine_atten], dim=1)
+        second_out = self.refine_combine(second_out)
+        second_out = nn.functional.interpolate(second_out, (img.shape[2], img.shape[3]), recompute_scale_factor=False)
+        return first_out, second_out
+
+
+# -----------------------------------------------
+# Discriminator
+# -----------------------------------------------
+# Input: generated image / ground truth and mask
+# Output: patch based region, we set 30 * 30
+class PatchDiscriminator(nn.Module):
+    def __init__(self, opt):
+        super(PatchDiscriminator, self).__init__()
+        # Down sampling
+        self.block1 = Conv2dLayer(opt.in_channels, opt.latent_channels, 7, 1, 3, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm, sn=True)
+        self.block2 = Conv2dLayer(opt.latent_channels, opt.latent_channels * 2, 4, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm, sn=True)
+        self.block3 = Conv2dLayer(opt.latent_channels * 2, opt.latent_channels * 4, 4, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm, sn=True)
+        self.block4 = Conv2dLayer(opt.latent_channels * 4, opt.latent_channels * 4, 4, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm, sn=True)
+        self.block5 = Conv2dLayer(opt.latent_channels * 4, opt.latent_channels * 4, 4, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm, sn=True)
+        self.block6 = Conv2dLayer(opt.latent_channels * 4, 1, 4, 2, 1, pad_type=opt.pad_type, activation="none", norm="none", sn=True)
+
+    def forward(self, img, mask):
+        # the input x should contain 4 channels because it is a combination of recon image and mask
+        x = torch.cat((img, mask), 1)
+        x = self.block1(x)  # out: [B, 64, 256, 256]
+        x = self.block2(x)  # out: [B, 128, 128, 128]
+        x = self.block3(x)  # out: [B, 256, 64, 64]
+        x = self.block4(x)  # out: [B, 256, 32, 32]
+        x = self.block5(x)  # out: [B, 256, 16, 16]
+        x = self.block6(x)  # out: [B, 1, 8, 8]
+        return x
+
+
+# ----------------------------------------
+# Perceptual Network
+# ----------------------------------------
+# VGG-16 conv3_3 features
+class PerceptualNet(nn.Module):
+    def __init__(self):
+        super(PerceptualNet, self).__init__()
+        block = [torchvision.models.vgg16(pretrained=True).features[:15].eval()]
+        for p in block[0]:
+            p.requires_grad = False
+        self.block = torch.nn.ModuleList(block)
+        self.transform = torch.nn.functional.interpolate
+        self.register_buffer("mean", torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+        self.register_buffer("std", torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+    def forward(self, x):
+        x = (x - self.mean) / self.std
+        x = self.transform(x, mode="bilinear", size=(224, 224), align_corners=False)
+        for block in self.block:
+            x = block(x)
+        return x
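network.py defines the generator, the patch discriminator and the VGG-16 feature extractor, but no training loop is part of this upload. The fragment below is a minimal, hedged sketch of how the three networks are typically combined in DeepFill-v2-style training; the `opt` values, the loss weights and the hinge-loss formulation are assumptions for illustration, not something defined in these files.

import torch
import torch.nn as nn
from argparse import Namespace

from deepfillv2 import network

# Assumed options (not part of this upload); input is image (3) + mask (1) = 4 channels.
opt = Namespace(in_channels=4, out_channels=3, latent_channels=64,
                pad_type="zero", activation="lrelu", norm="in", use_cuda=False)

generator = network.GatedGenerator(opt)
discriminator = network.PatchDiscriminator(opt)
perceptualnet = network.PerceptualNet()          # downloads pretrained VGG-16 weights
l1 = nn.L1Loss()

img = torch.rand(1, 3, 256, 256)                 # dummy batch
mask = torch.zeros(1, 1, 256, 256)
mask[:, :, 96:160, 96:160] = 1.0                 # square hole

first_out, second_out = generator(img, mask)
completed = img * (1 - mask) + second_out * mask

# Discriminator step: hinge loss on real vs. completed patch maps (assumed formulation).
fake_patch = discriminator(completed.detach(), mask)
real_patch = discriminator(img, mask)
loss_D = torch.mean(torch.relu(1.0 - real_patch)) + torch.mean(torch.relu(1.0 + fake_patch))

# Generator step: L1 reconstruction + adversarial + perceptual terms (weights are placeholders).
loss_G = (
    l1(first_out, img)
    + l1(second_out, img)
    + 0.01 * (-torch.mean(discriminator(completed, mask)))
    + 0.05 * l1(perceptualnet(completed), perceptualnet(img))
)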
deepfillv2/network_module.py
ADDED
@@ -0,0 +1,596 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn import Parameter
+
+from deepfillv2.network_utils import *
+
+
+# -----------------------------------------------
+# Normal ConvBlock
+# -----------------------------------------------
+class Conv2dLayer(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, pad_type="zero", activation="elu", norm="none", sn=False):
+        super(Conv2dLayer, self).__init__()
+        # Initialize the padding scheme
+        if pad_type == "reflect":
+            self.pad = nn.ReflectionPad2d(padding)
+        elif pad_type == "replicate":
+            self.pad = nn.ReplicationPad2d(padding)
+        elif pad_type == "zero":
+            self.pad = nn.ZeroPad2d(padding)
+        else:
+            assert 0, "Unsupported padding type: {}".format(pad_type)
+
+        # Initialize the normalization type
+        if norm == "bn":
+            self.norm = nn.BatchNorm2d(out_channels)
+        elif norm == "in":
+            self.norm = nn.InstanceNorm2d(out_channels)
+        elif norm == "ln":
+            self.norm = LayerNorm(out_channels)
+        elif norm == "none":
+            self.norm = None
+        else:
+            assert 0, "Unsupported normalization: {}".format(norm)
+
+        # Initialize the activation function
+        if activation == "relu":
+            self.activation = nn.ReLU(inplace=True)
+        elif activation == "lrelu":
+            self.activation = nn.LeakyReLU(0.2, inplace=True)
+        elif activation == "elu":
+            self.activation = nn.ELU(inplace=True)
+        elif activation == "selu":
+            self.activation = nn.SELU(inplace=True)
+        elif activation == "tanh":
+            self.activation = nn.Tanh()
+        elif activation == "sigmoid":
+            self.activation = nn.Sigmoid()
+        elif activation == "none":
+            self.activation = None
+        else:
+            assert 0, "Unsupported activation: {}".format(activation)
+
+        # Initialize the convolution layers
+        if sn:
+            self.conv2d = SpectralNorm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation))
+        else:
+            self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation)
+
+    def forward(self, x):
+        x = self.pad(x)
+        x = self.conv2d(x)
+        if self.norm:
+            x = self.norm(x)
+        if self.activation:
+            x = self.activation(x)
+        return x
+
+
+class TransposeConv2dLayer(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, pad_type="zero", activation="lrelu", norm="none", sn=False, scale_factor=2):
+        super(TransposeConv2dLayer, self).__init__()
+        # Initialize the conv scheme
+        self.scale_factor = scale_factor
+        self.conv2d = Conv2dLayer(in_channels, out_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)
+
+    def forward(self, x):
+        x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest", recompute_scale_factor=False)
+        x = self.conv2d(x)
+        return x
+
+
+# -----------------------------------------------
+# Gated ConvBlock
+# -----------------------------------------------
+class GatedConv2d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, pad_type="reflect", activation="elu", norm="none", sn=False):
+        super(GatedConv2d, self).__init__()
+        # Initialize the padding scheme
+        if pad_type == "reflect":
+            self.pad = nn.ReflectionPad2d(padding)
+        elif pad_type == "replicate":
+            self.pad = nn.ReplicationPad2d(padding)
+        elif pad_type == "zero":
+            self.pad = nn.ZeroPad2d(padding)
+        else:
+            assert 0, "Unsupported padding type: {}".format(pad_type)
+
+        # Initialize the normalization type
+        if norm == "bn":
+            self.norm = nn.BatchNorm2d(out_channels)
+        elif norm == "in":
+            self.norm = nn.InstanceNorm2d(out_channels)
+        elif norm == "ln":
+            self.norm = LayerNorm(out_channels)
+        elif norm == "none":
+            self.norm = None
+        else:
+            assert 0, "Unsupported normalization: {}".format(norm)
+
+        # Initialize the activation function
+        if activation == "relu":
+            self.activation = nn.ReLU(inplace=True)
+        elif activation == "lrelu":
+            self.activation = nn.LeakyReLU(0.2, inplace=True)
+        elif activation == "elu":
+            self.activation = nn.ELU()
+        elif activation == "selu":
+            self.activation = nn.SELU(inplace=True)
+        elif activation == "tanh":
+            self.activation = nn.Tanh()
+        elif activation == "sigmoid":
+            self.activation = nn.Sigmoid()
+        elif activation == "none":
+            self.activation = None
+        else:
+            assert 0, "Unsupported activation: {}".format(activation)
+
+        # Initialize the convolution layers
+        if sn:
+            self.conv2d = SpectralNorm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation))
+            self.mask_conv2d = SpectralNorm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation))
+        else:
+            self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation)
+            self.mask_conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation)
+        self.sigmoid = torch.nn.Sigmoid()
+
+    def forward(self, x):
+        x = self.pad(x)
+        conv = self.conv2d(x)
+        mask = self.mask_conv2d(x)
+        gated_mask = self.sigmoid(mask)
+        if self.activation:
+            conv = self.activation(conv)
+        x = conv * gated_mask
+        return x
+
+
+class TransposeGatedConv2d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, pad_type="zero", activation="lrelu", norm="none", sn=True, scale_factor=2):
+        super(TransposeGatedConv2d, self).__init__()
+        # Initialize the conv scheme
+        self.scale_factor = scale_factor
+        self.gated_conv2d = GatedConv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)
+
+    def forward(self, x):
+        x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest", recompute_scale_factor=False)
+        x = self.gated_conv2d(x)
+        return x
+
+
+# ----------------------------------------
+# Layer Norm
+# ----------------------------------------
+class LayerNorm(nn.Module):
+    def __init__(self, num_features, eps=1e-8, affine=True):
+        super(LayerNorm, self).__init__()
+        self.num_features = num_features
+        self.affine = affine
+        self.eps = eps
+
+        if self.affine:
+            self.gamma = Parameter(torch.Tensor(num_features).uniform_())
+            self.beta = Parameter(torch.zeros(num_features))
+
+    def forward(self, x):
+        # layer norm
+        shape = [-1] + [1] * (x.dim() - 1)  # for 4d input: [-1, 1, 1, 1]
+        if x.size(0) == 1:
+            # These two lines run much faster in pytorch 0.4 than the two lines listed below.
+            mean = x.view(-1).mean().view(*shape)
+            std = x.view(-1).std().view(*shape)
+        else:
+            mean = x.view(x.size(0), -1).mean(1).view(*shape)
+            std = x.view(x.size(0), -1).std(1).view(*shape)
+        x = (x - mean) / (std + self.eps)
+        # if it is learnable
+        if self.affine:
+            shape = [1, -1] + [1] * (x.dim() - 2)  # for 4d input: [1, -1, 1, 1]
+            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+        return x
+
+
+# -----------------------------------------------
+# SpectralNorm
+# -----------------------------------------------
+def l2normalize(v, eps=1e-12):
+    return v / (v.norm() + eps)
+
+
+class SpectralNorm(nn.Module):
+    def __init__(self, module, name="weight", power_iterations=1):
+        super(SpectralNorm, self).__init__()
+        self.module = module
+        self.name = name
+        self.power_iterations = power_iterations
+        if not self._made_params():
+            self._make_params()
+
+    def _update_u_v(self):
+        u = getattr(self.module, self.name + "_u")
+        v = getattr(self.module, self.name + "_v")
+        w = getattr(self.module, self.name + "_bar")
+
+        height = w.data.shape[0]
+        for _ in range(self.power_iterations):
+            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
+            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
+
+        # sigma = torch.dot(u.data, torch.mv(w.view(height, -1).data, v.data))
+        sigma = u.dot(w.view(height, -1).mv(v))
+        setattr(self.module, self.name, w / sigma.expand_as(w))
+
+    def _made_params(self):
+        try:
+            getattr(self.module, self.name + "_u")
+            getattr(self.module, self.name + "_v")
+            getattr(self.module, self.name + "_bar")
+            return True
+        except AttributeError:
+            return False
+
+    def _make_params(self):
+        w = getattr(self.module, self.name)
+
+        height = w.data.shape[0]
+        width = w.view(height, -1).data.shape[1]
+
+        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
+        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
+        u.data = l2normalize(u.data)
+        v.data = l2normalize(v.data)
+        w_bar = Parameter(w.data)
+
+        del self.module._parameters[self.name]
+
+        self.module.register_parameter(self.name + "_u", u)
+        self.module.register_parameter(self.name + "_v", v)
+        self.module.register_parameter(self.name + "_bar", w_bar)
+
+    def forward(self, *args):
+        self._update_u_v()
+        return self.module.forward(*args)
+
+
+class ContextualAttention(nn.Module):
+    def __init__(self, ksize=3, stride=1, rate=1, fuse_k=3, softmax_scale=10, fuse=True, use_cuda=True, device_ids=None):
+        super(ContextualAttention, self).__init__()
+        self.ksize = ksize
+        self.stride = stride
+        self.rate = rate
+        self.fuse_k = fuse_k
+        self.softmax_scale = softmax_scale
+        self.fuse = fuse
+        self.use_cuda = use_cuda
+        self.device_ids = device_ids
+
+    def forward(self, f, b, mask=None):
+        """Contextual attention layer implementation.
+        Contextual attention was first introduced in the publication:
+        Generative Image Inpainting with Contextual Attention, Yu et al.
+        Args:
+            f: Input feature to match (foreground).
+            b: Input feature for match (background).
+            mask: Input mask for b, indicating patches not available.
+            ksize: Kernel size for contextual attention.
+            stride: Stride for extracting patches from b.
+            rate: Dilation for matching.
+            softmax_scale: Scaled softmax for attention.
+        Returns:
+            torch.Tensor: output
+        """
+        # get shapes
+        raw_int_fs = list(f.size())  # b*c*h*w
+        raw_int_bs = list(b.size())  # b*c*h*w
+
+        # extract patches from background with stride and rate
+        kernel = 2 * self.rate
+        # raw_w is extracted for reconstruction
+        raw_w = extract_image_patches(b, ksizes=[kernel, kernel], strides=[self.rate * self.stride, self.rate * self.stride], rates=[1, 1], padding="same")  # [N, C*k*k, L]
+        # raw_shape: [N, C, k, k, L] [4, 192, 4, 4, 1024]
+        raw_w = raw_w.view(raw_int_bs[0], raw_int_bs[1], kernel, kernel, -1)
+        raw_w = raw_w.permute(0, 4, 1, 2, 3)  # raw_shape: [N, L, C, k, k]
+        raw_w_groups = torch.split(raw_w, 1, dim=0)
+
+        # downscaling foreground option: downscaling both foreground and
+        # background for matching and use original background for reconstruction.
+        f = F.interpolate(f, scale_factor=1.0 / self.rate, mode="nearest", recompute_scale_factor=False)
+        b = F.interpolate(b, scale_factor=1.0 / self.rate, mode="nearest", recompute_scale_factor=False)
+        int_fs = list(f.size())  # b*c*h*w
+        int_bs = list(b.size())
+        f_groups = torch.split(f, 1, dim=0)  # split tensors along the batch dimension
+        # w shape: [N, C*k*k, L]
+        w = extract_image_patches(b, ksizes=[self.ksize, self.ksize], strides=[self.stride, self.stride], rates=[1, 1], padding="same")
+        # w shape: [N, C, k, k, L]
+        w = w.view(int_bs[0], int_bs[1], self.ksize, self.ksize, -1)
+        w = w.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
+        w_groups = torch.split(w, 1, dim=0)
+
+        # process mask
+        mask = F.interpolate(mask, scale_factor=1.0 / self.rate, mode="nearest", recompute_scale_factor=False)
+        int_ms = list(mask.size())
+        # m shape: [N, C*k*k, L]
+        m = extract_image_patches(mask, ksizes=[self.ksize, self.ksize], strides=[self.stride, self.stride], rates=[1, 1], padding="same")
+
+        # m shape: [N, C, k, k, L]
+        m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1)
+        m = m.permute(0, 4, 1, 2, 3)  # m shape: [N, L, C, k, k]
+        m = m[0]  # m shape: [L, C, k, k]
+        # mm shape: [L, 1, 1, 1]
+        mm = (reduce_mean(m, axis=[1, 2, 3], keepdim=True) == 0.0).to(torch.float32)
+        mm = mm.permute(1, 0, 2, 3)  # mm shape: [1, L, 1, 1]
+
+        y = []
+        offsets = []
+        k = self.fuse_k
+        scale = self.softmax_scale  # to fit the PyTorch tensor image value range
+        fuse_weight = torch.eye(k).view(1, 1, k, k)  # 1*1*k*k
+        if self.use_cuda:
+            fuse_weight = fuse_weight.cuda()
+
+        for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
+            """
+            O => output channel as a conv filter
+            I => input channel as a conv filter
+            xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
+            wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
+            raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
+            """
+            # conv for compare
+            escape_NaN = torch.FloatTensor([1e-4])
+            if self.use_cuda:
+                escape_NaN = escape_NaN.cuda()
+            wi = wi[0]  # [L, C, k, k]
+            max_wi = torch.sqrt(reduce_sum(torch.pow(wi, 2) + escape_NaN, axis=[1, 2, 3], keepdim=True))
+            wi_normed = wi / max_wi
+            # xi shape: [1, C, H, W], yi shape: [1, L, H, W]
+            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
+            yi = F.conv2d(xi, wi_normed, stride=1)  # [1, L, H, W]
+            # conv implementation for fuse scores to encourage large patches
+            if self.fuse:
+                # make all of depth to spatial resolution
+                yi = yi.view(1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3])  # (B=1, I=1, H=32*32, W=32*32)
+                yi = same_padding(yi, [k, k], [1, 1], [1, 1])
+                yi = F.conv2d(yi, fuse_weight, stride=1)  # (B=1, C=1, H=32*32, W=32*32)
+                yi = yi.contiguous().view(1, int_bs[2], int_bs[3], int_fs[2], int_fs[3])  # (B=1, 32, 32, 32, 32)
+                yi = yi.permute(0, 2, 1, 4, 3)
+                yi = yi.contiguous().view(1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3])
+                yi = same_padding(yi, [k, k], [1, 1], [1, 1])
+                yi = F.conv2d(yi, fuse_weight, stride=1)
+                yi = yi.contiguous().view(1, int_bs[3], int_bs[2], int_fs[3], int_fs[2])
+                yi = yi.permute(0, 2, 1, 4, 3).contiguous()
+            yi = yi.view(1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3])  # (B=1, C=32*32, H=32, W=32)
+            # softmax to match
+            yi = yi * mm
+            yi = F.softmax(yi * scale, dim=1)
+            yi = yi * mm  # [1, L, H, W]
+
+            offset = torch.argmax(yi, dim=1, keepdim=True)  # 1*1*H*W
+
+            if int_bs != int_fs:
+                # Normalize the offset value to match foreground dimension
+                times = float(int_fs[2] * int_fs[3]) / float(int_bs[2] * int_bs[3])
+                offset = ((offset + 1).float() * times - 1).to(torch.int64)
+            offset = torch.cat([offset // int_fs[3], offset % int_fs[3]], dim=1)  # 1*2*H*W
+
+            # deconv for patch pasting
+            wi_center = raw_wi[0]
+            # yi = F.pad(yi, [0, 1, 0, 1])  # here may need conv_transpose same padding
+            yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4.0  # (B=1, C=128, H=64, W=64)
+            y.append(yi)
+            offsets.append(offset)
+
+        y = torch.cat(y, dim=0)  # back to the mini-batch
+        y = y.contiguous().view(raw_int_fs)
+
+        return y
deepfillv2/network_utils.py
ADDED
@@ -0,0 +1,79 @@
+# for contextual attention
+import torch
+
+
+def extract_image_patches(images, ksizes, strides, rates, padding="same"):
+    """
+    Extract patches from images and put them in the C output dimension.
+    :param padding:
+    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
+    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
+        each dimension of images
+    :param strides: [stride_rows, stride_cols]
+    :param rates: [dilation_rows, dilation_cols]
+    :return: A Tensor
+    """
+    assert len(images.size()) == 4
+    assert padding in ["same", "valid"]
+    batch_size, channel, height, width = images.size()
+
+    if padding == "same":
+        images = same_padding(images, ksizes, strides, rates)
+    elif padding == "valid":
+        pass
+    else:
+        raise NotImplementedError(
+            'Unsupported padding type: {}. Only "same" or "valid" are supported.'.format(padding)
+        )
+
+    unfold = torch.nn.Unfold(kernel_size=ksizes, dilation=rates, padding=0, stride=strides)
+    patches = unfold(images)
+    return patches  # [N, C*k*k, L], L is the total number of such blocks
+
+
+def same_padding(images, ksizes, strides, rates):
+    assert len(images.size()) == 4
+    batch_size, channel, rows, cols = images.size()
+    out_rows = (rows + strides[0] - 1) // strides[0]
+    out_cols = (cols + strides[1] - 1) // strides[1]
+    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
+    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
+    padding_rows = max(0, (out_rows - 1) * strides[0] + effective_k_row - rows)
+    padding_cols = max(0, (out_cols - 1) * strides[1] + effective_k_col - cols)
+    # Pad the input
+    padding_top = int(padding_rows / 2.0)
+    padding_left = int(padding_cols / 2.0)
+    padding_bottom = padding_rows - padding_top
+    padding_right = padding_cols - padding_left
+    paddings = (padding_left, padding_right, padding_top, padding_bottom)
+    images = torch.nn.ZeroPad2d(paddings)(images)
+    return images
+
+
+def reduce_mean(x, axis=None, keepdim=False):
+    if not axis:
+        axis = range(len(x.shape))
+    for i in sorted(axis, reverse=True):
+        x = torch.mean(x, dim=i, keepdim=keepdim)
+    return x
+
+
+def reduce_std(x, axis=None, keepdim=False):
+    if not axis:
+        axis = range(len(x.shape))
+    for i in sorted(axis, reverse=True):
+        x = torch.std(x, dim=i, keepdim=keepdim)
+    return x
+
+
+def reduce_sum(x, axis=None, keepdim=False):
+    if not axis:
+        axis = range(len(x.shape))
+    for i in sorted(axis, reverse=True):
+        x = torch.sum(x, dim=i, keepdim=keepdim)
+    return x
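As a quick sanity check of the helpers above, the hedged snippet below (shapes follow standard torch.nn.Unfold semantics; the tensor sizes are illustrative) shows what extract_image_patches returns with "same" padding:

import torch
from deepfillv2.network_utils import extract_image_patches

x = torch.randn(1, 128, 32, 32)
patches = extract_image_patches(x, ksizes=[3, 3], strides=[1, 1], rates=[1, 1], padding="same")
# With "same" padding every spatial location yields one 3x3 patch:
# [N, C*k*k, L] == [1, 128*3*3, 32*32] == [1, 1152, 1024]
print(patches.shape)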
deepfillv2/test_dataset.py
ADDED
@@ -0,0 +1,47 @@
+import cv2
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+from config import *
+
+
+class InpaintDataset(Dataset):
+    def __init__(self):
+        self.imglist = [INIMAGE]
+        self.masklist = [MASKIMAGE]
+        self.setsize = RESIZE_TO
+
+    def __len__(self):
+        return len(self.imglist)
+
+    def __getitem__(self, index):
+        # image
+        img = cv2.imread(self.imglist[index])
+        mask = cv2.imread(self.masklist[index])[:, :, 0]
+        ## COMMENTING FOR NOW
+        # h, w = mask.shape
+        # # img = cv2.resize(img, (w, h))
+        img = cv2.resize(img, self.setsize)
+        mask = cv2.resize(mask, self.setsize)
+        ##
+        # find the minimum bounding rectangle in the mask
+        """
+        contours, hier = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+        for cidx, cnt in enumerate(contours):
+            (x, y, w, h) = cv2.boundingRect(cnt)
+            mask[y:y+h, x:x+w] = 255
+        """
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+        img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
+        mask = torch.from_numpy(mask.astype(np.float32) / 255.0).unsqueeze(0).contiguous()
+        return img, mask
deepfillv2/utils.py
ADDED
@@ -0,0 +1,145 @@
+import os
+import numpy as np
+import cv2
+import torch
+from deepfillv2 import network
+import skimage
+
+from config import GPU_DEVICE
+
+
+# ----------------------------------------
+# Network
+# ----------------------------------------
+def create_generator(opt):
+    # Initialize the networks
+    generator = network.GatedGenerator(opt)
+    print("-- Generator is created! --")
+    network.weights_init(generator, init_type=opt.init_type, init_gain=opt.init_gain)
+    print("-- Initialized generator with %s type --" % opt.init_type)
+    return generator
+
+
+def create_discriminator(opt):
+    # Initialize the networks
+    discriminator = network.PatchDiscriminator(opt)
+    print("-- Discriminator is created! --")
+    network.weights_init(discriminator, init_type=opt.init_type, init_gain=opt.init_gain)
+    print("-- Initialized discriminator with %s type --" % opt.init_type)
+    return discriminator
+
+
+def create_perceptualnet():
+    # Get the first 15 layers of vgg16, which is conv3_3
+    perceptualnet = network.PerceptualNet()
+    print("-- Perceptual network is created! --")
+    return perceptualnet
+
+
+# ----------------------------------------
+# PATH processing
+# ----------------------------------------
+def text_readlines(filename):
+    # Try to read a txt file and return a list. Return [] if there was a mistake.
+    try:
+        file = open(filename, "r")
+    except IOError:
+        error = []
+        return error
+    content = file.readlines()
+    # This for loop deletes the EOF (like \n)
+    for i in range(len(content)):
+        content[i] = content[i][: len(content[i]) - 1]
+    file.close()
+    return content
+
+
+def savetxt(name, loss_log):
+    np_loss_log = np.array(loss_log)
+    np.savetxt(name, np_loss_log)
+
+
+def get_files(path, mask=False):
+    # read a folder, return the complete path of every file (skipping .DS_Store)
+    ret = []
+    for root, dirs, files in os.walk(path):
+        for filespath in files:
+            if filespath == ".DS_Store":
+                continue
+            ret.append(os.path.join(root, filespath))
+    return ret
+
+
+def get_names(path):
+    # read a folder, return the image name
+    ret = []
+    for root, dirs, files in os.walk(path):
+        for filespath in files:
+            ret.append(filespath)
+    return ret
+
+
+def text_save(content, filename, mode="a"):
+    # save a list to a txt
+    # Try to save a list variable in txt file.
+    file = open(filename, mode)
+    for i in range(len(content)):
+        file.write(str(content[i]) + "\n")
+    file.close()
+
+
+def check_path(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+# ----------------------------------------
+# Validation and Sample at training
+# ----------------------------------------
+def save_sample_png(sample_folder, sample_name, img_list, name_list, pixel_max_cnt=255):
+    # Save image one-by-one
+    for i in range(len(img_list)):
+        img = img_list[i]
+        # Recover normalization: * 255 because the tensors are in the [0, 1] range
+        img = img * 255
+        # Process img_copy and do not destroy the data of img
+        img_copy = img.clone().data.permute(0, 2, 3, 1)[0, :, :, :].to("cpu").numpy()
+        img_copy = np.clip(img_copy, 0, pixel_max_cnt)
+        img_copy = img_copy.astype(np.uint8)
+        img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
+        # Save to certain path
+        save_img_path = os.path.join(sample_folder, sample_name)
+        cv2.imwrite(save_img_path, img_copy)
+
+
+def psnr(pred, target, pixel_max_cnt=255):
+    mse = torch.mul(target - pred, target - pred)
+    rmse_avg = (torch.mean(mse).item()) ** 0.5
+    p = 20 * np.log10(pixel_max_cnt / rmse_avg)
+    return p
+
+
+def grey_psnr(pred, target, pixel_max_cnt=255):
+    pred = torch.sum(pred, dim=0)
+    target = torch.sum(target, dim=0)
+    mse = torch.mul(target - pred, target - pred)
+    rmse_avg = (torch.mean(mse).item()) ** 0.5
+    p = 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)
+    return p
+
+
+def ssim(pred, target):
+    pred = pred.clone().data.permute(0, 2, 3, 1).to(GPU_DEVICE).numpy()
+    target = target.clone().data.permute(0, 2, 3, 1).to(GPU_DEVICE).numpy()
+    target = target[0]
+    pred = pred[0]
+    ssim = skimage.measure.compare_ssim(target, pred, multichannel=True)
+    return ssim