Spaces:
Runtime error
Runtime error
add discriminator.
Browse files- README.md +1 -1
- basicsr/archs/vqgan_arch.py +46 -1
README.md
CHANGED
@@ -16,7 +16,7 @@ S-Lab, Nanyang Technological University
|
|
16 |
<img src="assets/network.jpg" width="800px"/>
|
17 |
|
18 |
|
19 |
-
:star: If CodeFormer is helpful to your
|
20 |
|
21 |
### Updates
|
22 |
|
|
|
16 |
<img src="assets/network.jpg" width="800px"/>
|
17 |
|
18 |
|
19 |
+
:star: If CodeFormer is helpful to your projects, please help star this repo. Thanks! :hugs:
|
20 |
|
21 |
### Updates
|
22 |
|
basicsr/archs/vqgan_arch.py
CHANGED
@@ -387,4 +387,49 @@ class VQAutoEncoder(nn.Module):
|
|
387 |
x = self.encoder(x)
|
388 |
quant, codebook_loss, quant_stats = self.quantize(x)
|
389 |
x = self.generator(quant)
|
390 |
-
return x, codebook_loss, quant_stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
387 |
x = self.encoder(x)
|
388 |
quant, codebook_loss, quant_stats = self.quantize(x)
|
389 |
x = self.generator(quant)
|
390 |
+
return x, codebook_loss, quant_stats
|
391 |
+
|
392 |
+
|
393 |
+
|
394 |
+
# patch based discriminator
|
395 |
+
@ARCH_REGISTRY.register()
|
396 |
+
class VQGANDiscriminator(nn.Module):
|
397 |
+
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
|
398 |
+
super().__init__()
|
399 |
+
|
400 |
+
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
|
401 |
+
ndf_mult = 1
|
402 |
+
ndf_mult_prev = 1
|
403 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
404 |
+
ndf_mult_prev = ndf_mult
|
405 |
+
ndf_mult = min(2 ** n, 8)
|
406 |
+
layers += [
|
407 |
+
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
|
408 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
409 |
+
nn.LeakyReLU(0.2, True)
|
410 |
+
]
|
411 |
+
|
412 |
+
ndf_mult_prev = ndf_mult
|
413 |
+
ndf_mult = min(2 ** n_layers, 8)
|
414 |
+
|
415 |
+
layers += [
|
416 |
+
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
|
417 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
418 |
+
nn.LeakyReLU(0.2, True)
|
419 |
+
]
|
420 |
+
|
421 |
+
layers += [
|
422 |
+
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
|
423 |
+
self.main = nn.Sequential(*layers)
|
424 |
+
|
425 |
+
if model_path is not None:
|
426 |
+
chkpt = torch.load(model_path, map_location='cpu')
|
427 |
+
if 'params_d' in chkpt:
|
428 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
|
429 |
+
elif 'params' in chkpt:
|
430 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
431 |
+
else:
|
432 |
+
raise ValueError(f'Wrong params!')
|
433 |
+
|
434 |
+
def forward(self, x):
|
435 |
+
return self.main(x)
|