NeuralFalcon committed
Commit 712b45c · verified · 1 parent: 2b2eb3f

Upload 7 files
deepfillv2/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2020 Qiang Wen
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
deepfillv2/__init__.py ADDED
@@ -0,0 +1 @@
+
deepfillv2/network.py ADDED
@@ -0,0 +1,666 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.init as init
4
+ import torchvision
5
+
6
+ from deepfillv2.network_module import *
7
+
8
+
9
+ def weights_init(net, init_type="kaiming", init_gain=0.02):
10
+ """Initialize network weights.
11
+ Parameters:
12
+ net (network) -- network to be initialized
13
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
14
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
15
+ """
16
+
17
+ def init_func(m):
18
+ classname = m.__class__.__name__
19
+ if hasattr(m, "weight") and classname.find("Conv") != -1:
20
+ if init_type == "normal":
21
+ init.normal_(m.weight.data, 0.0, init_gain)
22
+ elif init_type == "xavier":
23
+ init.xavier_normal_(m.weight.data, gain=init_gain)
24
+ elif init_type == "kaiming":
25
+ init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
26
+ elif init_type == "orthogonal":
27
+ init.orthogonal_(m.weight.data, gain=init_gain)
28
+ else:
29
+ raise NotImplementedError(
30
+ "initialization method [%s] is not implemented" % init_type
31
+ )
32
+ elif classname.find("BatchNorm2d") != -1:
33
+ init.normal_(m.weight.data, 1.0, 0.02)
34
+ init.constant_(m.bias.data, 0.0)
35
+ elif classname.find("Linear") != -1:
36
+ init.normal_(m.weight, 0, 0.01)
37
+ init.constant_(m.bias, 0)
38
+
39
+ # Apply the initialization function <init_func>
40
+ net.apply(init_func)
41
+
42
+
43
+ # -----------------------------------------------
44
+ # Generator
45
+ # -----------------------------------------------
46
+ # Input: masked image + mask
47
+ # Output: filled image
48
+ class GatedGenerator(nn.Module):
49
+ def __init__(self, opt):
50
+ super(GatedGenerator, self).__init__()
51
+ self.coarse = nn.Sequential(
52
+ # encoder
53
+ GatedConv2d(
54
+ opt.in_channels,
55
+ opt.latent_channels,
56
+ 5,
57
+ 1,
58
+ 2,
59
+ pad_type=opt.pad_type,
60
+ activation=opt.activation,
61
+ norm=opt.norm,
62
+ ),
63
+ GatedConv2d(
64
+ opt.latent_channels,
65
+ opt.latent_channels * 2,
66
+ 3,
67
+ 2,
68
+ 1,
69
+ pad_type=opt.pad_type,
70
+ activation=opt.activation,
71
+ norm=opt.norm,
72
+ ),
73
+ GatedConv2d(
74
+ opt.latent_channels * 2,
75
+ opt.latent_channels * 2,
76
+ 3,
77
+ 1,
78
+ 1,
79
+ pad_type=opt.pad_type,
80
+ activation=opt.activation,
81
+ norm=opt.norm,
82
+ ),
83
+ GatedConv2d(
84
+ opt.latent_channels * 2,
85
+ opt.latent_channels * 4,
86
+ 3,
87
+ 2,
88
+ 1,
89
+ pad_type=opt.pad_type,
90
+ activation=opt.activation,
91
+ norm=opt.norm,
92
+ ),
93
+ # Bottleneck
94
+ GatedConv2d(
95
+ opt.latent_channels * 4,
96
+ opt.latent_channels * 4,
97
+ 3,
98
+ 1,
99
+ 1,
100
+ pad_type=opt.pad_type,
101
+ activation=opt.activation,
102
+ norm=opt.norm,
103
+ ),
104
+ GatedConv2d(
105
+ opt.latent_channels * 4,
106
+ opt.latent_channels * 4,
107
+ 3,
108
+ 1,
109
+ 1,
110
+ pad_type=opt.pad_type,
111
+ activation=opt.activation,
112
+ norm=opt.norm,
113
+ ),
114
+ GatedConv2d(
115
+ opt.latent_channels * 4,
116
+ opt.latent_channels * 4,
117
+ 3,
118
+ 1,
119
+ 2,
120
+ dilation=2,
121
+ pad_type=opt.pad_type,
122
+ activation=opt.activation,
123
+ norm=opt.norm,
124
+ ),
125
+ GatedConv2d(
126
+ opt.latent_channels * 4,
127
+ opt.latent_channels * 4,
128
+ 3,
129
+ 1,
130
+ 4,
131
+ dilation=4,
132
+ pad_type=opt.pad_type,
133
+ activation=opt.activation,
134
+ norm=opt.norm,
135
+ ),
136
+ GatedConv2d(
137
+ opt.latent_channels * 4,
138
+ opt.latent_channels * 4,
139
+ 3,
140
+ 1,
141
+ 8,
142
+ dilation=8,
143
+ pad_type=opt.pad_type,
144
+ activation=opt.activation,
145
+ norm=opt.norm,
146
+ ),
147
+ GatedConv2d(
148
+ opt.latent_channels * 4,
149
+ opt.latent_channels * 4,
150
+ 3,
151
+ 1,
152
+ 16,
153
+ dilation=16,
154
+ pad_type=opt.pad_type,
155
+ activation=opt.activation,
156
+ norm=opt.norm,
157
+ ),
158
+ GatedConv2d(
159
+ opt.latent_channels * 4,
160
+ opt.latent_channels * 4,
161
+ 3,
162
+ 1,
163
+ 1,
164
+ pad_type=opt.pad_type,
165
+ activation=opt.activation,
166
+ norm=opt.norm,
167
+ ),
168
+ GatedConv2d(
169
+ opt.latent_channels * 4,
170
+ opt.latent_channels * 4,
171
+ 3,
172
+ 1,
173
+ 1,
174
+ pad_type=opt.pad_type,
175
+ activation=opt.activation,
176
+ norm=opt.norm,
177
+ ),
178
+ # decoder
179
+ TransposeGatedConv2d(
180
+ opt.latent_channels * 4,
181
+ opt.latent_channels * 2,
182
+ 3,
183
+ 1,
184
+ 1,
185
+ pad_type=opt.pad_type,
186
+ activation=opt.activation,
187
+ norm=opt.norm,
188
+ ),
189
+ GatedConv2d(
190
+ opt.latent_channels * 2,
191
+ opt.latent_channels * 2,
192
+ 3,
193
+ 1,
194
+ 1,
195
+ pad_type=opt.pad_type,
196
+ activation=opt.activation,
197
+ norm=opt.norm,
198
+ ),
199
+ TransposeGatedConv2d(
200
+ opt.latent_channels * 2,
201
+ opt.latent_channels,
202
+ 3,
203
+ 1,
204
+ 1,
205
+ pad_type=opt.pad_type,
206
+ activation=opt.activation,
207
+ norm=opt.norm,
208
+ ),
209
+ GatedConv2d(
210
+ opt.latent_channels,
211
+ opt.latent_channels // 2,
212
+ 3,
213
+ 1,
214
+ 1,
215
+ pad_type=opt.pad_type,
216
+ activation=opt.activation,
217
+ norm=opt.norm,
218
+ ),
219
+ GatedConv2d(
220
+ opt.latent_channels // 2,
221
+ opt.out_channels,
222
+ 3,
223
+ 1,
224
+ 1,
225
+ pad_type=opt.pad_type,
226
+ activation="none",
227
+ norm=opt.norm,
228
+ ),
229
+ nn.Tanh(),
230
+ )
231
+
232
+ self.refine_conv = nn.Sequential(
233
+ GatedConv2d(
234
+ opt.in_channels,
235
+ opt.latent_channels,
236
+ 5,
237
+ 1,
238
+ 2,
239
+ pad_type=opt.pad_type,
240
+ activation=opt.activation,
241
+ norm=opt.norm,
242
+ ),
243
+ GatedConv2d(
244
+ opt.latent_channels,
245
+ opt.latent_channels,
246
+ 3,
247
+ 2,
248
+ 1,
249
+ pad_type=opt.pad_type,
250
+ activation=opt.activation,
251
+ norm=opt.norm,
252
+ ),
253
+ GatedConv2d(
254
+ opt.latent_channels,
255
+ opt.latent_channels * 2,
256
+ 3,
257
+ 1,
258
+ 1,
259
+ pad_type=opt.pad_type,
260
+ activation=opt.activation,
261
+ norm=opt.norm,
262
+ ),
263
+ GatedConv2d(
264
+ opt.latent_channels * 2,
265
+ opt.latent_channels * 2,
266
+ 3,
267
+ 2,
268
+ 1,
269
+ pad_type=opt.pad_type,
270
+ activation=opt.activation,
271
+ norm=opt.norm,
272
+ ),
273
+ GatedConv2d(
274
+ opt.latent_channels * 2,
275
+ opt.latent_channels * 4,
276
+ 3,
277
+ 1,
278
+ 1,
279
+ pad_type=opt.pad_type,
280
+ activation=opt.activation,
281
+ norm=opt.norm,
282
+ ),
283
+ GatedConv2d(
284
+ opt.latent_channels * 4,
285
+ opt.latent_channels * 4,
286
+ 3,
287
+ 1,
288
+ 1,
289
+ pad_type=opt.pad_type,
290
+ activation=opt.activation,
291
+ norm=opt.norm,
292
+ ),
293
+ GatedConv2d(
294
+ opt.latent_channels * 4,
295
+ opt.latent_channels * 4,
296
+ 3,
297
+ 1,
298
+ 2,
299
+ dilation=2,
300
+ pad_type=opt.pad_type,
301
+ activation=opt.activation,
302
+ norm=opt.norm,
303
+ ),
304
+ GatedConv2d(
305
+ opt.latent_channels * 4,
306
+ opt.latent_channels * 4,
307
+ 3,
308
+ 1,
309
+ 4,
310
+ dilation=4,
311
+ pad_type=opt.pad_type,
312
+ activation=opt.activation,
313
+ norm=opt.norm,
314
+ ),
315
+ GatedConv2d(
316
+ opt.latent_channels * 4,
317
+ opt.latent_channels * 4,
318
+ 3,
319
+ 1,
320
+ 8,
321
+ dilation=8,
322
+ pad_type=opt.pad_type,
323
+ activation=opt.activation,
324
+ norm=opt.norm,
325
+ ),
326
+ GatedConv2d(
327
+ opt.latent_channels * 4,
328
+ opt.latent_channels * 4,
329
+ 3,
330
+ 1,
331
+ 16,
332
+ dilation=16,
333
+ pad_type=opt.pad_type,
334
+ activation=opt.activation,
335
+ norm=opt.norm,
336
+ ),
337
+ )
338
+ self.refine_atten_1 = nn.Sequential(
339
+ GatedConv2d(
340
+ opt.in_channels,
341
+ opt.latent_channels,
342
+ 5,
343
+ 1,
344
+ 2,
345
+ pad_type=opt.pad_type,
346
+ activation=opt.activation,
347
+ norm=opt.norm,
348
+ ),
349
+ GatedConv2d(
350
+ opt.latent_channels,
351
+ opt.latent_channels,
352
+ 3,
353
+ 2,
354
+ 1,
355
+ pad_type=opt.pad_type,
356
+ activation=opt.activation,
357
+ norm=opt.norm,
358
+ ),
359
+ GatedConv2d(
360
+ opt.latent_channels,
361
+ opt.latent_channels * 2,
362
+ 3,
363
+ 1,
364
+ 1,
365
+ pad_type=opt.pad_type,
366
+ activation=opt.activation,
367
+ norm=opt.norm,
368
+ ),
369
+ GatedConv2d(
370
+ opt.latent_channels * 2,
371
+ opt.latent_channels * 4,
372
+ 3,
373
+ 2,
374
+ 1,
375
+ pad_type=opt.pad_type,
376
+ activation=opt.activation,
377
+ norm=opt.norm,
378
+ ),
379
+ GatedConv2d(
380
+ opt.latent_channels * 4,
381
+ opt.latent_channels * 4,
382
+ 3,
383
+ 1,
384
+ 1,
385
+ pad_type=opt.pad_type,
386
+ activation=opt.activation,
387
+ norm=opt.norm,
388
+ ),
389
+ GatedConv2d(
390
+ opt.latent_channels * 4,
391
+ opt.latent_channels * 4,
392
+ 3,
393
+ 1,
394
+ 1,
395
+ pad_type=opt.pad_type,
396
+ activation="relu",
397
+ norm=opt.norm,
398
+ ),
399
+ )
400
+ self.refine_atten_2 = nn.Sequential(
401
+ GatedConv2d(
402
+ opt.latent_channels * 4,
403
+ opt.latent_channels * 4,
404
+ 3,
405
+ 1,
406
+ 1,
407
+ pad_type=opt.pad_type,
408
+ activation=opt.activation,
409
+ norm=opt.norm,
410
+ ),
411
+ GatedConv2d(
412
+ opt.latent_channels * 4,
413
+ opt.latent_channels * 4,
414
+ 3,
415
+ 1,
416
+ 1,
417
+ pad_type=opt.pad_type,
418
+ activation=opt.activation,
419
+ norm=opt.norm,
420
+ ),
421
+ )
422
+ self.refine_combine = nn.Sequential(
423
+ GatedConv2d(
424
+ opt.latent_channels * 8,
425
+ opt.latent_channels * 4,
426
+ 3,
427
+ 1,
428
+ 1,
429
+ pad_type=opt.pad_type,
430
+ activation=opt.activation,
431
+ norm=opt.norm,
432
+ ),
433
+ GatedConv2d(
434
+ opt.latent_channels * 4,
435
+ opt.latent_channels * 4,
436
+ 3,
437
+ 1,
438
+ 1,
439
+ pad_type=opt.pad_type,
440
+ activation=opt.activation,
441
+ norm=opt.norm,
442
+ ),
443
+ TransposeGatedConv2d(
444
+ opt.latent_channels * 4,
445
+ opt.latent_channels * 2,
446
+ 3,
447
+ 1,
448
+ 1,
449
+ pad_type=opt.pad_type,
450
+ activation=opt.activation,
451
+ norm=opt.norm,
452
+ ),
453
+ GatedConv2d(
454
+ opt.latent_channels * 2,
455
+ opt.latent_channels * 2,
456
+ 3,
457
+ 1,
458
+ 1,
459
+ pad_type=opt.pad_type,
460
+ activation=opt.activation,
461
+ norm=opt.norm,
462
+ ),
463
+ TransposeGatedConv2d(
464
+ opt.latent_channels * 2,
465
+ opt.latent_channels,
466
+ 3,
467
+ 1,
468
+ 1,
469
+ pad_type=opt.pad_type,
470
+ activation=opt.activation,
471
+ norm=opt.norm,
472
+ ),
473
+ GatedConv2d(
474
+ opt.latent_channels,
475
+ opt.latent_channels // 2,
476
+ 3,
477
+ 1,
478
+ 1,
479
+ pad_type=opt.pad_type,
480
+ activation=opt.activation,
481
+ norm=opt.norm,
482
+ ),
483
+ GatedConv2d(
484
+ opt.latent_channels // 2,
485
+ opt.out_channels,
486
+ 3,
487
+ 1,
488
+ 1,
489
+ pad_type=opt.pad_type,
490
+ activation="none",
491
+ norm=opt.norm,
492
+ ),
493
+ nn.Tanh(),
494
+ )
495
+
496
+ use_cuda = opt.use_cuda
497
+
498
+ self.context_attention = ContextualAttention(
499
+ ksize=3,
500
+ stride=1,
501
+ rate=2,
502
+ fuse_k=3,
503
+ softmax_scale=10,
504
+ fuse=True,
505
+ use_cuda=use_cuda,
506
+ )
507
+
508
+ def forward(self, img, mask):
509
+ # img: entire img
510
+ # mask: 1 for mask region; 0 for unmask region
511
+ # Coarse
512
+ first_masked_img = img * (1 - mask) + mask
513
+ first_in = torch.cat(
514
+ (first_masked_img, mask), dim=1
515
+ ) # in: [B, 4, H, W]
516
+ first_out = self.coarse(first_in) # out: [B, 3, H, W]
517
+ first_out = nn.functional.interpolate(
518
+ first_out,
519
+ (img.shape[2], img.shape[3]),
520
+ recompute_scale_factor=False,
521
+ )
522
+ # Refinement
523
+ second_masked_img = img * (1 - mask) + first_out * mask
524
+ second_in = torch.cat([second_masked_img, mask], dim=1)
525
+ refine_conv = self.refine_conv(second_in)
526
+ refine_atten = self.refine_atten_1(second_in)
527
+ mask_s = nn.functional.interpolate(
528
+ mask,
529
+ (refine_atten.shape[2], refine_atten.shape[3]),
530
+ recompute_scale_factor=False,
531
+ )
532
+ refine_atten = self.context_attention(
533
+ refine_atten, refine_atten, mask_s
534
+ )
535
+ refine_atten = self.refine_atten_2(refine_atten)
536
+ second_out = torch.cat([refine_conv, refine_atten], dim=1)
537
+ second_out = self.refine_combine(second_out)
538
+ second_out = nn.functional.interpolate(
539
+ second_out,
540
+ (img.shape[2], img.shape[3]),
541
+ recompute_scale_factor=False,
542
+ )
543
+ return first_out, second_out
544
+
545
+
546
+ # -----------------------------------------------
547
+ # Discriminator
548
+ # -----------------------------------------------
549
+ # Input: generated image / ground truth and mask
550
+ # Output: a patch-based score map (spatial size depends on the input resolution; [B, 1, 8, 8] for a 256 x 256 input)
551
+ class PatchDiscriminator(nn.Module):
552
+ def __init__(self, opt):
553
+ super(PatchDiscriminator, self).__init__()
554
+ # Down sampling
555
+ self.block1 = Conv2dLayer(
556
+ opt.in_channels,
557
+ opt.latent_channels,
558
+ 7,
559
+ 1,
560
+ 3,
561
+ pad_type=opt.pad_type,
562
+ activation=opt.activation,
563
+ norm=opt.norm,
564
+ sn=True,
565
+ )
566
+ self.block2 = Conv2dLayer(
567
+ opt.latent_channels,
568
+ opt.latent_channels * 2,
569
+ 4,
570
+ 2,
571
+ 1,
572
+ pad_type=opt.pad_type,
573
+ activation=opt.activation,
574
+ norm=opt.norm,
575
+ sn=True,
576
+ )
577
+ self.block3 = Conv2dLayer(
578
+ opt.latent_channels * 2,
579
+ opt.latent_channels * 4,
580
+ 4,
581
+ 2,
582
+ 1,
583
+ pad_type=opt.pad_type,
584
+ activation=opt.activation,
585
+ norm=opt.norm,
586
+ sn=True,
587
+ )
588
+ self.block4 = Conv2dLayer(
589
+ opt.latent_channels * 4,
590
+ opt.latent_channels * 4,
591
+ 4,
592
+ 2,
593
+ 1,
594
+ pad_type=opt.pad_type,
595
+ activation=opt.activation,
596
+ norm=opt.norm,
597
+ sn=True,
598
+ )
599
+ self.block5 = Conv2dLayer(
600
+ opt.latent_channels * 4,
601
+ opt.latent_channels * 4,
602
+ 4,
603
+ 2,
604
+ 1,
605
+ pad_type=opt.pad_type,
606
+ activation=opt.activation,
607
+ norm=opt.norm,
608
+ sn=True,
609
+ )
610
+ self.block6 = Conv2dLayer(
611
+ opt.latent_channels * 4,
612
+ 1,
613
+ 4,
614
+ 2,
615
+ 1,
616
+ pad_type=opt.pad_type,
617
+ activation="none",
618
+ norm="none",
619
+ sn=True,
620
+ )
621
+
622
+ def forward(self, img, mask):
623
+ # the input x has 4 channels: the reconstructed (or real) image concatenated with the mask
624
+ x = torch.cat((img, mask), 1)
625
+ x = self.block1(x) # out: [B, 64, 256, 256]
626
+ x = self.block2(x) # out: [B, 128, 128, 128]
627
+ x = self.block3(x) # out: [B, 256, 64, 64]
628
+ x = self.block4(x) # out: [B, 256, 32, 32]
629
+ x = self.block5(x) # out: [B, 256, 16, 16]
630
+ x = self.block6(x) # out: [B, 1, 8, 8]
631
+ return x
632
+
633
+
634
+ # ----------------------------------------
635
+ # Perceptual Network
636
+ # ----------------------------------------
637
+ # VGG-16 conv3_3 features
638
+ class PerceptualNet(nn.Module):
639
+ def __init__(self):
640
+ super(PerceptualNet, self).__init__()
641
+ block = [
642
+ torchvision.models.vgg16(pretrained=True).features[:15].eval()
643
+ ]
644
+ for p in block[0]:
645
+ p.requires_grad = False
646
+ self.block = torch.nn.ModuleList(block)
647
+ self.transform = torch.nn.functional.interpolate
648
+ self.register_buffer(
649
+ "mean", torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
650
+ )
651
+ self.register_buffer(
652
+ "std", torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
653
+ )
654
+
655
+ def forward(self, x):
656
+ x = (x - self.mean) / self.std
657
+ x = self.transform(
658
+ x,
659
+ mode="bilinear",
660
+ size=(224, 224),
661
+ align_corners=False,
662
+ recompute_scale_factor=False,
663
+ )
664
+ for block in self.block:
665
+ x = block(x)
666
+ return x
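
For reference, a minimal usage sketch of the generator defined above. It is not part of the uploaded files: the attribute names on `opt` mirror what `GatedGenerator` actually reads, while the concrete values (48 latent channels, a 256 x 256 input) are illustrative assumptions.

from types import SimpleNamespace
import torch
from deepfillv2 import network

# Hypothetical options object; only the attributes the network reads are set.
opt = SimpleNamespace(
    in_channels=4,       # masked RGB image + 1-channel mask
    out_channels=3,
    latent_channels=48,  # assumed base channel width
    pad_type="zero",
    activation="elu",
    norm="none",
    use_cuda=False,
)
generator = network.GatedGenerator(opt).eval()

img = torch.rand(1, 3, 256, 256)                    # image in [0, 1]
mask = (torch.rand(1, 1, 256, 256) > 0.9).float()   # 1 = hole, 0 = known pixel
with torch.no_grad():
    coarse, refined = generator(img, mask)          # both [1, 3, 256, 256]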
deepfillv2/network_module.py ADDED
@@ -0,0 +1,596 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from torch.nn import Parameter
5
+
6
+ from deepfillv2.network_utils import *
7
+
8
+
9
+ # -----------------------------------------------
10
+ # Normal ConvBlock
11
+ # -----------------------------------------------
12
+ class Conv2dLayer(nn.Module):
13
+ def __init__(
14
+ self,
15
+ in_channels,
16
+ out_channels,
17
+ kernel_size,
18
+ stride=1,
19
+ padding=0,
20
+ dilation=1,
21
+ pad_type="zero",
22
+ activation="elu",
23
+ norm="none",
24
+ sn=False,
25
+ ):
26
+ super(Conv2dLayer, self).__init__()
27
+ # Initialize the padding scheme
28
+ if pad_type == "reflect":
29
+ self.pad = nn.ReflectionPad2d(padding)
30
+ elif pad_type == "replicate":
31
+ self.pad = nn.ReplicationPad2d(padding)
32
+ elif pad_type == "zero":
33
+ self.pad = nn.ZeroPad2d(padding)
34
+ else:
35
+ assert 0, "Unsupported padding type: {}".format(pad_type)
36
+
37
+ # Initialize the normalization type
38
+ if norm == "bn":
39
+ self.norm = nn.BatchNorm2d(out_channels)
40
+ elif norm == "in":
41
+ self.norm = nn.InstanceNorm2d(out_channels)
42
+ elif norm == "ln":
43
+ self.norm = LayerNorm(out_channels)
44
+ elif norm == "none":
45
+ self.norm = None
46
+ else:
47
+ assert 0, "Unsupported normalization: {}".format(norm)
48
+
49
+ # Initialize the activation function
50
+ if activation == "relu":
51
+ self.activation = nn.ReLU(inplace=True)
52
+ elif activation == "lrelu":
53
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
54
+ elif activation == "elu":
55
+ self.activation = nn.ELU(inplace=True)
56
+ elif activation == "selu":
57
+ self.activation = nn.SELU(inplace=True)
58
+ elif activation == "tanh":
59
+ self.activation = nn.Tanh()
60
+ elif activation == "sigmoid":
61
+ self.activation = nn.Sigmoid()
62
+ elif activation == "none":
63
+ self.activation = None
64
+ else:
65
+ assert 0, "Unsupported activation: {}".format(activation)
66
+
67
+ # Initialize the convolution layers
68
+ if sn:
69
+ self.conv2d = SpectralNorm(
70
+ nn.Conv2d(
71
+ in_channels,
72
+ out_channels,
73
+ kernel_size,
74
+ stride,
75
+ padding=0,
76
+ dilation=dilation,
77
+ )
78
+ )
79
+ else:
80
+ self.conv2d = nn.Conv2d(
81
+ in_channels,
82
+ out_channels,
83
+ kernel_size,
84
+ stride,
85
+ padding=0,
86
+ dilation=dilation,
87
+ )
88
+
89
+ def forward(self, x):
90
+ x = self.pad(x)
91
+ x = self.conv2d(x)
92
+ if self.norm:
93
+ x = self.norm(x)
94
+ if self.activation:
95
+ x = self.activation(x)
96
+ return x
97
+
98
+
99
+ class TransposeConv2dLayer(nn.Module):
100
+ def __init__(
101
+ self,
102
+ in_channels,
103
+ out_channels,
104
+ kernel_size,
105
+ stride=1,
106
+ padding=0,
107
+ dilation=1,
108
+ pad_type="zero",
109
+ activation="lrelu",
110
+ norm="none",
111
+ sn=False,
112
+ scale_factor=2,
113
+ ):
114
+ super(TransposeConv2dLayer, self).__init__()
115
+ # Initialize the conv scheme
116
+ self.scale_factor = scale_factor
117
+ self.conv2d = Conv2dLayer(
118
+ in_channels,
119
+ out_channels,
120
+ kernel_size,
121
+ stride,
122
+ padding,
123
+ dilation,
124
+ pad_type,
125
+ activation,
126
+ norm,
127
+ sn,
128
+ )
129
+
130
+ def forward(self, x):
131
+ x = F.interpolate(
132
+ x,
133
+ scale_factor=self.scale_factor,
134
+ mode="nearest",
135
+ recompute_scale_factor=False,
136
+ )
137
+ x = self.conv2d(x)
138
+ return x
139
+
140
+
141
+ # -----------------------------------------------
142
+ # Gated ConvBlock
143
+ # -----------------------------------------------
144
+ class GatedConv2d(nn.Module):
145
+ def __init__(
146
+ self,
147
+ in_channels,
148
+ out_channels,
149
+ kernel_size,
150
+ stride=1,
151
+ padding=0,
152
+ dilation=1,
153
+ pad_type="reflect",
154
+ activation="elu",
155
+ norm="none",
156
+ sn=False,
157
+ ):
158
+ super(GatedConv2d, self).__init__()
159
+ # Initialize the padding scheme
160
+ if pad_type == "reflect":
161
+ self.pad = nn.ReflectionPad2d(padding)
162
+ elif pad_type == "replicate":
163
+ self.pad = nn.ReplicationPad2d(padding)
164
+ elif pad_type == "zero":
165
+ self.pad = nn.ZeroPad2d(padding)
166
+ else:
167
+ assert 0, "Unsupported padding type: {}".format(pad_type)
168
+
169
+ # Initialize the normalization type
170
+ if norm == "bn":
171
+ self.norm = nn.BatchNorm2d(out_channels)
172
+ elif norm == "in":
173
+ self.norm = nn.InstanceNorm2d(out_channels)
174
+ elif norm == "ln":
175
+ self.norm = LayerNorm(out_channels)
176
+ elif norm == "none":
177
+ self.norm = None
178
+ else:
179
+ assert 0, "Unsupported normalization: {}".format(norm)
180
+
181
+ # Initialize the activation function
182
+ if activation == "relu":
183
+ self.activation = nn.ReLU(inplace=True)
184
+ elif activation == "lrelu":
185
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
186
+ elif activation == "elu":
187
+ self.activation = nn.ELU()
188
+ elif activation == "selu":
189
+ self.activation = nn.SELU(inplace=True)
190
+ elif activation == "tanh":
191
+ self.activation = nn.Tanh()
192
+ elif activation == "sigmoid":
193
+ self.activation = nn.Sigmoid()
194
+ elif activation == "none":
195
+ self.activation = None
196
+ else:
197
+ assert 0, "Unsupported activation: {}".format(activation)
198
+
199
+ # Initialize the convolution layers
200
+ if sn:
201
+ self.conv2d = SpectralNorm(
202
+ nn.Conv2d(
203
+ in_channels,
204
+ out_channels,
205
+ kernel_size,
206
+ stride,
207
+ padding=0,
208
+ dilation=dilation,
209
+ )
210
+ )
211
+ self.mask_conv2d = SpectralNorm(
212
+ nn.Conv2d(
213
+ in_channels,
214
+ out_channels,
215
+ kernel_size,
216
+ stride,
217
+ padding=0,
218
+ dilation=dilation,
219
+ )
220
+ )
221
+ else:
222
+ self.conv2d = nn.Conv2d(
223
+ in_channels,
224
+ out_channels,
225
+ kernel_size,
226
+ stride,
227
+ padding=0,
228
+ dilation=dilation,
229
+ )
230
+ self.mask_conv2d = nn.Conv2d(
231
+ in_channels,
232
+ out_channels,
233
+ kernel_size,
234
+ stride,
235
+ padding=0,
236
+ dilation=dilation,
237
+ )
238
+ self.sigmoid = torch.nn.Sigmoid()
239
+
240
+ def forward(self, x):
241
+ x = self.pad(x)
242
+ conv = self.conv2d(x)
243
+ mask = self.mask_conv2d(x)
244
+ gated_mask = self.sigmoid(mask)
245
+ if self.activation:
246
+ conv = self.activation(conv)
247
+ x = conv * gated_mask
248
+ return x
249
+
250
+
251
+ class TransposeGatedConv2d(nn.Module):
252
+ def __init__(
253
+ self,
254
+ in_channels,
255
+ out_channels,
256
+ kernel_size,
257
+ stride=1,
258
+ padding=0,
259
+ dilation=1,
260
+ pad_type="zero",
261
+ activation="lrelu",
262
+ norm="none",
263
+ sn=True,
264
+ scale_factor=2,
265
+ ):
266
+ super(TransposeGatedConv2d, self).__init__()
267
+ # Initialize the conv scheme
268
+ self.scale_factor = scale_factor
269
+ self.gated_conv2d = GatedConv2d(
270
+ in_channels,
271
+ out_channels,
272
+ kernel_size,
273
+ stride,
274
+ padding,
275
+ dilation,
276
+ pad_type,
277
+ activation,
278
+ norm,
279
+ sn,
280
+ )
281
+
282
+ def forward(self, x):
283
+ x = F.interpolate(
284
+ x,
285
+ scale_factor=self.scale_factor,
286
+ mode="nearest",
287
+ recompute_scale_factor=False,
288
+ )
289
+ x = self.gated_conv2d(x)
290
+ return x
291
+
292
+
293
+ # ----------------------------------------
294
+ # Layer Norm
295
+ # ----------------------------------------
296
+ class LayerNorm(nn.Module):
297
+ def __init__(self, num_features, eps=1e-8, affine=True):
298
+ super(LayerNorm, self).__init__()
299
+ self.num_features = num_features
300
+ self.affine = affine
301
+ self.eps = eps
302
+
303
+ if self.affine:
304
+ self.gamma = Parameter(torch.Tensor(num_features).uniform_())
305
+ self.beta = Parameter(torch.zeros(num_features))
306
+
307
+ def forward(self, x):
308
+ # layer norm
309
+ shape = [-1] + [1] * (x.dim() - 1) # for 4d input: [-1, 1, 1, 1]
310
+ if x.size(0) == 1:
311
+ # These two lines run much faster in pytorch 0.4 than the two lines listed below.
312
+ mean = x.view(-1).mean().view(*shape)
313
+ std = x.view(-1).std().view(*shape)
314
+ else:
315
+ mean = x.view(x.size(0), -1).mean(1).view(*shape)
316
+ std = x.view(x.size(0), -1).std(1).view(*shape)
317
+ x = (x - mean) / (std + self.eps)
318
+ # if it is learnable
319
+ if self.affine:
320
+ shape = [1, -1] + [1] * (
321
+ x.dim() - 2
322
+ ) # for 4d input: [1, -1, 1, 1]
323
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
324
+ return x
325
+
326
+
327
+ # -----------------------------------------------
328
+ # SpectralNorm
329
+ # -----------------------------------------------
330
+ def l2normalize(v, eps=1e-12):
331
+ return v / (v.norm() + eps)
332
+
333
+
334
+ class SpectralNorm(nn.Module):
335
+ def __init__(self, module, name="weight", power_iterations=1):
336
+ super(SpectralNorm, self).__init__()
337
+ self.module = module
338
+ self.name = name
339
+ self.power_iterations = power_iterations
340
+ if not self._made_params():
341
+ self._make_params()
342
+
343
+ def _update_u_v(self):
344
+ u = getattr(self.module, self.name + "_u")
345
+ v = getattr(self.module, self.name + "_v")
346
+ w = getattr(self.module, self.name + "_bar")
347
+
348
+ height = w.data.shape[0]
349
+ for _ in range(self.power_iterations):
350
+ v.data = l2normalize(
351
+ torch.mv(torch.t(w.view(height, -1).data), u.data)
352
+ )
353
+ u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
354
+
355
+ # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
356
+ sigma = u.dot(w.view(height, -1).mv(v))
357
+ setattr(self.module, self.name, w / sigma.expand_as(w))
358
+
359
+ def _made_params(self):
360
+ try:
361
+ u = getattr(self.module, self.name + "_u")
362
+ v = getattr(self.module, self.name + "_v")
363
+ w = getattr(self.module, self.name + "_bar")
364
+ return True
365
+ except AttributeError:
366
+ return False
367
+
368
+ def _make_params(self):
369
+ w = getattr(self.module, self.name)
370
+
371
+ height = w.data.shape[0]
372
+ width = w.view(height, -1).data.shape[1]
373
+
374
+ u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
375
+ v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
376
+ u.data = l2normalize(u.data)
377
+ v.data = l2normalize(v.data)
378
+ w_bar = Parameter(w.data)
379
+
380
+ del self.module._parameters[self.name]
381
+
382
+ self.module.register_parameter(self.name + "_u", u)
383
+ self.module.register_parameter(self.name + "_v", v)
384
+ self.module.register_parameter(self.name + "_bar", w_bar)
385
+
386
+ def forward(self, *args):
387
+ self._update_u_v()
388
+ return self.module.forward(*args)
389
+
390
+
391
+ class ContextualAttention(nn.Module):
392
+ def __init__(
393
+ self,
394
+ ksize=3,
395
+ stride=1,
396
+ rate=1,
397
+ fuse_k=3,
398
+ softmax_scale=10,
399
+ fuse=True,
400
+ use_cuda=True,
401
+ device_ids=None,
402
+ ):
403
+ super(ContextualAttention, self).__init__()
404
+ self.ksize = ksize
405
+ self.stride = stride
406
+ self.rate = rate
407
+ self.fuse_k = fuse_k
408
+ self.softmax_scale = softmax_scale
409
+ self.fuse = fuse
410
+ self.use_cuda = use_cuda
411
+ self.device_ids = device_ids
412
+
413
+ def forward(self, f, b, mask=None):
414
+ """Contextual attention layer implementation.
415
+ Contextual attention is first introduced in publication:
416
+ Generative Image Inpainting with Contextual Attention, Yu et al.
417
+ Args:
418
+ f: Input feature to match (foreground).
419
+ b: Input feature for match (background).
420
+ mask: Input mask for b, indicating patches not available.
421
+ ksize: Kernel size for contextual attention.
422
+ stride: Stride for extracting patches from b.
423
+ rate: Dilation for matching.
424
+ softmax_scale: Scaled softmax for attention.
425
+ Returns:
426
+ torch.tensor: output
427
+ """
428
+ # get shapes
429
+ raw_int_fs = list(f.size()) # b*c*h*w
430
+ raw_int_bs = list(b.size()) # b*c*h*w
431
+
432
+ # extract patches from background with stride and rate
433
+ kernel = 2 * self.rate
434
+ # raw_w is extracted for reconstruction
435
+ raw_w = extract_image_patches(
436
+ b,
437
+ ksizes=[kernel, kernel],
438
+ strides=[self.rate * self.stride, self.rate * self.stride],
439
+ rates=[1, 1],
440
+ padding="same",
441
+ ) # [N, C*k*k, L]
442
+ # raw_shape: [N, C, k, k, L] [4, 192, 4, 4, 1024]
443
+ raw_w = raw_w.view(raw_int_bs[0], raw_int_bs[1], kernel, kernel, -1)
444
+ raw_w = raw_w.permute(0, 4, 1, 2, 3) # raw_shape: [N, L, C, k, k]
445
+ raw_w_groups = torch.split(raw_w, 1, dim=0)
446
+
447
+ # downscaling foreground option: downscaling both foreground and
448
+ # background for matching and use original background for reconstruction.
449
+ f = F.interpolate(
450
+ f,
451
+ scale_factor=1.0 / self.rate,
452
+ mode="nearest",
453
+ recompute_scale_factor=False,
454
+ )
455
+ b = F.interpolate(
456
+ b,
457
+ scale_factor=1.0 / self.rate,
458
+ mode="nearest",
459
+ recompute_scale_factor=False,
460
+ )
461
+ int_fs = list(f.size()) # b*c*h*w
462
+ int_bs = list(b.size())
463
+ f_groups = torch.split(
464
+ f, 1, dim=0
465
+ ) # split tensors along the batch dimension
466
+ # w shape: [N, C*k*k, L]
467
+ w = extract_image_patches(
468
+ b,
469
+ ksizes=[self.ksize, self.ksize],
470
+ strides=[self.stride, self.stride],
471
+ rates=[1, 1],
472
+ padding="same",
473
+ )
474
+ # w shape: [N, C, k, k, L]
475
+ w = w.view(int_bs[0], int_bs[1], self.ksize, self.ksize, -1)
476
+ w = w.permute(0, 4, 1, 2, 3) # w shape: [N, L, C, k, k]
477
+ w_groups = torch.split(w, 1, dim=0)
478
+
479
+ # process mask
480
+ mask = F.interpolate(
481
+ mask,
482
+ scale_factor=1.0 / self.rate,
483
+ mode="nearest",
484
+ recompute_scale_factor=False,
485
+ )
486
+ int_ms = list(mask.size())
487
+ # m shape: [N, C*k*k, L]
488
+ m = extract_image_patches(
489
+ mask,
490
+ ksizes=[self.ksize, self.ksize],
491
+ strides=[self.stride, self.stride],
492
+ rates=[1, 1],
493
+ padding="same",
494
+ )
495
+
496
+ # m shape: [N, C, k, k, L]
497
+ m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1)
498
+ m = m.permute(0, 4, 1, 2, 3) # m shape: [N, L, C, k, k]
499
+ m = m[0] # m shape: [L, C, k, k]
500
+ # mm shape: [L, 1, 1, 1]
501
+ mm = (reduce_mean(m, axis=[1, 2, 3], keepdim=True) == 0.0).to(
502
+ torch.float32
503
+ )
504
+ mm = mm.permute(1, 0, 2, 3) # mm shape: [1, L, 1, 1]
505
+
506
+ y = []
507
+ offsets = []
508
+ k = self.fuse_k
509
+ scale = (
510
+ self.softmax_scale
511
+ ) # to fit the PyTorch tensor image value range
512
+ fuse_weight = torch.eye(k).view(1, 1, k, k) # 1*1*k*k
513
+ if self.use_cuda:
514
+ fuse_weight = fuse_weight.cuda()
515
+
516
+ for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
517
+ """
518
+ O => output channel as a conv filter
519
+ I => input channel as a conv filter
520
+ xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
521
+ wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
522
+ raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
523
+ """
524
+ # conv for compare
525
+ escape_NaN = torch.FloatTensor([1e-4])
526
+ if self.use_cuda:
527
+ escape_NaN = escape_NaN.cuda()
528
+ wi = wi[0] # [L, C, k, k]
529
+ max_wi = torch.sqrt(
530
+ reduce_sum(
531
+ torch.pow(wi, 2) + escape_NaN, axis=[1, 2, 3], keepdim=True
532
+ )
533
+ )
534
+ wi_normed = wi / max_wi
535
+ # xi shape: [1, C, H, W], yi shape: [1, L, H, W]
536
+ xi = same_padding(
537
+ xi, [self.ksize, self.ksize], [1, 1], [1, 1]
538
+ ) # xi: 1*c*H*W
539
+ yi = F.conv2d(xi, wi_normed, stride=1) # [1, L, H, W]
540
+ # conv implementation for fuse scores to encourage large patches
541
+ if self.fuse:
542
+ # make all of depth to spatial resolution
543
+ yi = yi.view(
544
+ 1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3]
545
+ ) # (B=1, I=1, H=32*32, W=32*32)
546
+ yi = same_padding(yi, [k, k], [1, 1], [1, 1])
547
+ yi = F.conv2d(
548
+ yi, fuse_weight, stride=1
549
+ ) # (B=1, C=1, H=32*32, W=32*32)
550
+ yi = yi.contiguous().view(
551
+ 1, int_bs[2], int_bs[3], int_fs[2], int_fs[3]
552
+ ) # (B=1, 32, 32, 32, 32)
553
+ yi = yi.permute(0, 2, 1, 4, 3)
554
+ yi = yi.contiguous().view(
555
+ 1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3]
556
+ )
557
+ yi = same_padding(yi, [k, k], [1, 1], [1, 1])
558
+ yi = F.conv2d(yi, fuse_weight, stride=1)
559
+ yi = yi.contiguous().view(
560
+ 1, int_bs[3], int_bs[2], int_fs[3], int_fs[2]
561
+ )
562
+ yi = yi.permute(0, 2, 1, 4, 3).contiguous()
563
+ yi = yi.view(
564
+ 1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3]
565
+ ) # (B=1, C=32*32, H=32, W=32)
566
+ # softmax to match
567
+ yi = yi * mm
568
+ yi = F.softmax(yi * scale, dim=1)
569
+ yi = yi * mm # [1, L, H, W]
570
+
571
+ offset = torch.argmax(yi, dim=1, keepdim=True) # 1*1*H*W
572
+
573
+ if int_bs != int_fs:
574
+ # Normalize the offset value to match foreground dimension
575
+ times = float(int_fs[2] * int_fs[3]) / float(
576
+ int_bs[2] * int_bs[3]
577
+ )
578
+ offset = ((offset + 1).float() * times - 1).to(torch.int64)
579
+ offset = torch.cat(
580
+ [offset // int_fs[3], offset % int_fs[3]], dim=1
581
+ ) # 1*2*H*W
582
+
583
+ # deconv for patch pasting
584
+ wi_center = raw_wi[0]
585
+ # yi = F.pad(yi, [0, 1, 0, 1]) # here may need conv_transpose same padding
586
+ yi = (
587
+ F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1)
588
+ / 4.0
589
+ ) # (B=1, C=128, H=64, W=64)
590
+ y.append(yi)
591
+ offsets.append(offset)
592
+
593
+ y = torch.cat(y, dim=0) # back to the mini-batch
594
+ y = y.contiguous().view(raw_int_fs)  # reshape back to the original foreground feature size
595
+
596
+ return y
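
As a quick illustration of the gating used throughout this file (a standalone sketch, not part of the upload): `GatedConv2d` runs two parallel convolutions over the padded input and multiplies the activated feature branch by a sigmoid gate, so the gate can learn to suppress features coming from hole regions.

import torch
from deepfillv2.network_module import GatedConv2d

# 3 -> 8 channels, 3x3 kernel, stride 1, padding 1, reflect padding by default
layer = GatedConv2d(3, 8, 3, 1, 1)
x = torch.rand(2, 3, 16, 16)
y = layer(x)        # y = ELU(conv(x)) * sigmoid(mask_conv(x))
print(y.shape)      # torch.Size([2, 8, 16, 16])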
deepfillv2/network_utils.py ADDED
@@ -0,0 +1,79 @@
1
+ # for contextual attention
2
+ import torch
3
+
4
+
5
+ def extract_image_patches(images, ksizes, strides, rates, padding="same"):
6
+ """
7
+ Extract patches from images and put them in the C output dimension.
8
+ :param padding:
9
+ :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
10
+ :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
11
+ each dimension of images
12
+ :param strides: [stride_rows, stride_cols]
13
+ :param rates: [dilation_rows, dilation_cols]
14
+ :return: A Tensor
15
+ """
16
+ assert len(images.size()) == 4
17
+ assert padding in ["same", "valid"]
18
+ batch_size, channel, height, width = images.size()
19
+
20
+ if padding == "same":
21
+ images = same_padding(images, ksizes, strides, rates)
22
+ elif padding == "valid":
23
+ pass
24
+ else:
25
+ raise NotImplementedError(
26
+ 'Unsupported padding type: {}.\
27
+ Only "same" or "valid" are supported.'.format(
28
+ padding
29
+ )
30
+ )
31
+
32
+ unfold = torch.nn.Unfold(
33
+ kernel_size=ksizes, dilation=rates, padding=0, stride=strides
34
+ )
35
+ patches = unfold(images)
36
+ return patches # [N, C*k*k, L], L is the total number of such blocks
37
+
38
+
39
+ def same_padding(images, ksizes, strides, rates):
40
+ assert len(images.size()) == 4
41
+ batch_size, channel, rows, cols = images.size()
42
+ out_rows = (rows + strides[0] - 1) // strides[0]
43
+ out_cols = (cols + strides[1] - 1) // strides[1]
44
+ effective_k_row = (ksizes[0] - 1) * rates[0] + 1
45
+ effective_k_col = (ksizes[1] - 1) * rates[1] + 1
46
+ padding_rows = max(0, (out_rows - 1) * strides[0] + effective_k_row - rows)
47
+ padding_cols = max(0, (out_cols - 1) * strides[1] + effective_k_col - cols)
48
+ # Pad the input
49
+ padding_top = int(padding_rows / 2.0)
50
+ padding_left = int(padding_cols / 2.0)
51
+ padding_bottom = padding_rows - padding_top
52
+ padding_right = padding_cols - padding_left
53
+ paddings = (padding_left, padding_right, padding_top, padding_bottom)
54
+ images = torch.nn.ZeroPad2d(paddings)(images)
55
+ return images
56
+
57
+
58
+ def reduce_mean(x, axis=None, keepdim=False):
59
+ if not axis:
60
+ axis = range(len(x.shape))
61
+ for i in sorted(axis, reverse=True):
62
+ x = torch.mean(x, dim=i, keepdim=keepdim)
63
+ return x
64
+
65
+
66
+ def reduce_std(x, axis=None, keepdim=False):
67
+ if not axis:
68
+ axis = range(len(x.shape))
69
+ for i in sorted(axis, reverse=True):
70
+ x = torch.std(x, dim=i, keepdim=keepdim)
71
+ return x
72
+
73
+
74
+ def reduce_sum(x, axis=None, keepdim=False):
75
+ if not axis:
76
+ axis = range(len(x.shape))
77
+ for i in sorted(axis, reverse=True):
78
+ x = torch.sum(x, dim=i, keepdim=keepdim)
79
+ return x
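
A small shape check for the helpers above (illustrative only): with "same" padding, `extract_image_patches` returns one column per output location, each holding a flattened C x k x k patch.

import torch
from deepfillv2.network_utils import extract_image_patches

x = torch.rand(1, 3, 8, 8)
patches = extract_image_patches(
    x, ksizes=[3, 3], strides=[1, 1], rates=[1, 1], padding="same"
)
print(patches.shape)  # torch.Size([1, 27, 64]): C*k*k = 27 values per patch, L = 8*8 = 64 patches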
deepfillv2/test_dataset.py ADDED
@@ -0,0 +1,47 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+
6
+ from config import *
7
+
8
+
9
+ class InpaintDataset(Dataset):
10
+ def __init__(self):
11
+ self.imglist = [INIMAGE]
12
+ self.masklist = [MASKIMAGE]
13
+ self.setsize = RESIZE_TO
14
+
15
+ def __len__(self):
16
+ return len(self.imglist)
17
+
18
+ def __getitem__(self, index):
19
+ # image
20
+ img = cv2.imread(self.imglist[index])
21
+ mask = cv2.imread(self.masklist[index])[:, :, 0]
22
+ ## COMMENTING FOR NOW
23
+ # h, w = mask.shape
24
+ # # img = cv2.resize(img, (w, h))
25
+ img = cv2.resize(img, self.setsize)
26
+ mask = cv2.resize(mask, self.setsize)
27
+ ##
28
+ # find the Minimum bounding rectangle in the mask
29
+ """
30
+ contours, hier = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
31
+ for cidx, cnt in enumerate(contours):
32
+ (x, y, w, h) = cv2.boundingRect(cnt)
33
+ mask[y:y+h, x:x+w] = 255
34
+ """
35
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
36
+
37
+ img = (
38
+ torch.from_numpy(img.astype(np.float32) / 255.0)
39
+ .permute(2, 0, 1)
40
+ .contiguous()
41
+ )
42
+ mask = (
43
+ torch.from_numpy(mask.astype(np.float32) / 255.0)
44
+ .unsqueeze(0)
45
+ .contiguous()
46
+ )
47
+ return img, mask
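
A hedged usage sketch for this dataset: it assumes the project-level `config` module defines `INIMAGE` and `MASKIMAGE` (paths to the input image and its binary mask) and `RESIZE_TO` (a `(width, height)` tuple), since those names are referenced here but not defined in this commit.

from torch.utils.data import DataLoader
from deepfillv2.test_dataset import InpaintDataset

loader = DataLoader(InpaintDataset(), batch_size=1, shuffle=False)
img, mask = next(iter(loader))  # img: [1, 3, H, W] in [0, 1]; mask: [1, 1, H, W] with 1 = hole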
deepfillv2/utils.py ADDED
@@ -0,0 +1,145 @@
1
+ import os
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+ from deepfillv2 import network
6
+ import skimage
7
+
8
+ from config import GPU_DEVICE
9
+
10
+
11
+ # ----------------------------------------
12
+ # Network
13
+ # ----------------------------------------
14
+ def create_generator(opt):
15
+ # Initialize the networks
16
+ generator = network.GatedGenerator(opt)
17
+ print("-- Generator is created! --")
18
+ network.weights_init(
19
+ generator, init_type=opt.init_type, init_gain=opt.init_gain
20
+ )
21
+ print("-- Initialized generator with %s type --" % opt.init_type)
22
+ return generator
23
+
24
+
25
+ def create_discriminator(opt):
26
+ # Initialize the networks
27
+ discriminator = network.PatchDiscriminator(opt)
28
+ print("-- Discriminator is created! --")
29
+ network.weights_init(
30
+ discriminator, init_type=opt.init_type, init_gain=opt.init_gain
31
+ )
32
+ print("-- Initialized discriminator with %s type --" % opt.init_type)
33
+ return discriminator
34
+
35
+
36
+ def create_perceptualnet():
37
+ # Get the first 15 layers of vgg16, which is conv3_3
38
+ perceptualnet = network.PerceptualNet()
39
+ print("-- Perceptual network is created! --")
40
+ return perceptualnet
41
+
42
+
43
+ # ----------------------------------------
44
+ # PATH processing
45
+ # ----------------------------------------
46
+ def text_readlines(filename):
47
+ # Try to read a txt file and return its lines as a list; return [] if the file cannot be opened.
48
+ try:
49
+ file = open(filename, "r")
50
+ except IOError:
51
+ error = []
52
+ return error
53
+ content = file.readlines()
54
+ # Strip the trailing newline character from each line
55
+ for i in range(len(content)):
56
+ content[i] = content[i][: len(content[i]) - 1]
57
+ file.close()
58
+ return content
59
+
60
+
61
+ def savetxt(name, loss_log):
62
+ np_loss_log = np.array(loss_log)
63
+ np.savetxt(name, np_loss_log)
64
+
65
+
66
+ def get_files(path, mask=False):
67
+ # read a folder, return the complete path
68
+ ret = []
69
+ for root, dirs, files in os.walk(path):
70
+ for filespath in files:
71
+ if filespath == ".DS_Store":  # skip macOS metadata files
72
+ continue
73
+ ret.append(os.path.join(root, filespath))
74
+ return ret
75
+
76
+
77
+ def get_names(path):
78
+ # read a folder, return the image name
79
+ ret = []
80
+ for root, dirs, files in os.walk(path):
81
+ for filespath in files:
82
+ ret.append(filespath)
83
+ return ret
84
+
85
+
86
+ def text_save(content, filename, mode="a"):
87
+ # save a list to a txt
88
+ # Try to save a list variable in txt file.
89
+ file = open(filename, mode)
90
+ for i in range(len(content)):
91
+ file.write(str(content[i]) + "\n")
92
+ file.close()
93
+
94
+
95
+ def check_path(path):
96
+ if not os.path.exists(path):
97
+ os.makedirs(path)
98
+
99
+
100
+ # ----------------------------------------
101
+ # Validation and Sample at training
102
+ # ----------------------------------------
103
+ def save_sample_png(
104
+ sample_folder, sample_name, img_list, name_list, pixel_max_cnt=255
105
+ ):
106
+ # Save image one-by-one
107
+ for i in range(len(img_list)):
108
+ img = img_list[i]
109
+ # Recover normalization: scale from [0, 1] back to the [0, 255] pixel range
110
+ img = img * 255
111
+ # Process img_copy and do not destroy the data of img
112
+ img_copy = (
113
+ img.clone().data.permute(0, 2, 3, 1)[0, :, :, :].to("cpu").numpy()
114
+ )
115
+ img_copy = np.clip(img_copy, 0, pixel_max_cnt)
116
+ img_copy = img_copy.astype(np.uint8)
117
+ img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
118
+ # Save to certain path
119
+ save_img_path = os.path.join(sample_folder, sample_name)
120
+ cv2.imwrite(save_img_path, img_copy)
121
+
122
+
123
+ def psnr(pred, target, pixel_max_cnt=255):
124
+ mse = torch.mul(target - pred, target - pred)
125
+ rmse_avg = (torch.mean(mse).item()) ** 0.5
126
+ p = 20 * np.log10(pixel_max_cnt / rmse_avg)
127
+ return p
128
+
129
+
130
+ def grey_psnr(pred, target, pixel_max_cnt=255):
131
+ pred = torch.sum(pred, dim=0)
132
+ target = torch.sum(target, dim=0)
133
+ mse = torch.mul(target - pred, target - pred)
134
+ rmse_avg = (torch.mean(mse).item()) ** 0.5
135
+ p = 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)
136
+ return p
137
+
138
+
139
+ def ssim(pred, target):
140
+ pred = pred.clone().data.permute(0, 2, 3, 1).to(GPU_DEVICE).numpy()
141
+ target = target.clone().data.permute(0, 2, 3, 1).to(GPU_DEVICE).numpy()
142
+ target = target[0]
143
+ pred = pred[0]
144
+ ssim = skimage.measure.compare_ssim(target, pred, multichannel=True)
145
+ return ssim
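
Finally, a rough end-to-end sketch of how these helpers could fit together at inference time. Assumptions: the `config` module provides `GPU_DEVICE` (imported above), the `opt` fields below match the ones the networks read, and "pretrained.pth" is only a placeholder for whatever checkpoint you actually have.

import torch
from types import SimpleNamespace
from torch.utils.data import DataLoader
from deepfillv2 import utils
from deepfillv2.test_dataset import InpaintDataset

opt = SimpleNamespace(
    in_channels=4, out_channels=3, latent_channels=48,   # assumed widths
    pad_type="zero", activation="elu", norm="none",
    use_cuda=False, init_type="kaiming", init_gain=0.02,
)
generator = utils.create_generator(opt).eval()
# generator.load_state_dict(torch.load("pretrained.pth", map_location="cpu"))  # placeholder path

img, mask = next(iter(DataLoader(InpaintDataset(), batch_size=1)))
with torch.no_grad():
    _, refined = generator(img, mask)
out = img * (1 - mask) + refined * mask   # keep known pixels, paste the fill into the hole
utils.check_path("results")               # make sure the output folder exists
utils.save_sample_png("results", "inpainted.png", [out], ["inpainted"])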