sczhou commited on
Commit
9ea1225
·
1 Parent(s): 7cc7518

add discriminator.

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. basicsr/archs/vqgan_arch.py +46 -1
README.md CHANGED
@@ -16,7 +16,7 @@ S-Lab, Nanyang Technological University
16
  <img src="assets/network.jpg" width="800px"/>
17
 
18
 
19
- :star: If CodeFormer is helpful to your pothos or projects, please help star this repo. Thanks! :hugs:
20
 
21
  ### Updates
22
 
 
16
  <img src="assets/network.jpg" width="800px"/>
17
 
18
 
19
+ :star: If CodeFormer is helpful to your projects, please help star this repo. Thanks! :hugs:
20
 
21
  ### Updates
22
 
basicsr/archs/vqgan_arch.py CHANGED
@@ -387,4 +387,49 @@ class VQAutoEncoder(nn.Module):
387
  x = self.encoder(x)
388
  quant, codebook_loss, quant_stats = self.quantize(x)
389
  x = self.generator(quant)
390
- return x, codebook_loss, quant_stats
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
  x = self.encoder(x)
388
  quant, codebook_loss, quant_stats = self.quantize(x)
389
  x = self.generator(quant)
390
+ return x, codebook_loss, quant_stats
391
+
392
+
393
+
394
+ # patch based discriminator
395
+ @ARCH_REGISTRY.register()
396
+ class VQGANDiscriminator(nn.Module):
397
+ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
398
+ super().__init__()
399
+
400
+ layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
401
+ ndf_mult = 1
402
+ ndf_mult_prev = 1
403
+ for n in range(1, n_layers): # gradually increase the number of filters
404
+ ndf_mult_prev = ndf_mult
405
+ ndf_mult = min(2 ** n, 8)
406
+ layers += [
407
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
408
+ nn.BatchNorm2d(ndf * ndf_mult),
409
+ nn.LeakyReLU(0.2, True)
410
+ ]
411
+
412
+ ndf_mult_prev = ndf_mult
413
+ ndf_mult = min(2 ** n_layers, 8)
414
+
415
+ layers += [
416
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
417
+ nn.BatchNorm2d(ndf * ndf_mult),
418
+ nn.LeakyReLU(0.2, True)
419
+ ]
420
+
421
+ layers += [
422
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
423
+ self.main = nn.Sequential(*layers)
424
+
425
+ if model_path is not None:
426
+ chkpt = torch.load(model_path, map_location='cpu')
427
+ if 'params_d' in chkpt:
428
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
429
+ elif 'params' in chkpt:
430
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
431
+ else:
432
+ raise ValueError(f'Wrong params!')
433
+
434
+ def forward(self, x):
435
+ return self.main(x)