Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from utils.utils import weights_init | |
from model.discriminator import JCU_Discriminator, Discriminator | |
class MultiScaleDiscriminator(nn.Module):
    """Apply ``num_D`` ``Discriminator`` instances at successively coarser scales.

    Discriminator ``i`` sees the input after ``i`` passes through an
    ``nn.AvgPool1d`` with stride 2, so each scale works on a signal roughly
    half as long as the previous one.

    Args:
        num_D: Number of discriminators (one per scale).
        ndf, n_layers, downsampling_factor, disc_out: Forwarded positionally to
            every ``Discriminator``; their meaning is defined there.
            ``downsampling_factor`` is additionally used as the kernel size of
            the inter-scale average pooling.
    """

    def __init__(self, num_D=3, ndf=16, n_layers=3, downsampling_factor=4, disc_out=512):
        super().__init__()
        self.model = nn.ModuleDict()
        for i in range(num_D):
            self.model[f"disc_{i}"] = Discriminator(
                ndf, n_layers, downsampling_factor, disc_out
            )
        # NOTE(review): the pooling kernel size is tied to ``downsampling_factor``;
        # with the default of 4 this is kernel 4 / stride 2 / padding 1.
        self.downsample = nn.AvgPool1d(
            downsampling_factor, stride=2, padding=1, count_include_pad=False
        )
        # Initialise every submodule with the project's ``weights_init``.
        self.apply(weights_init)

    def forward(self, x):
        """Score ``x`` at every scale.

        Args:
            x: Input accepted by ``nn.AvgPool1d`` and ``Discriminator``;
                presumably ``(batch, channels, samples)`` — TODO confirm.

        Returns:
            A list with one entry per discriminator, finest scale first. Each
            entry is whatever ``Discriminator.forward`` returns.
        """
        results = []
        for scale, disc in enumerate(self.model.values()):
            # Pool only *between* scales, so no pooling is computed and then
            # discarded after the last discriminator.
            if scale > 0:
                x = self.downsample(x)
            results.append(disc(x))
        return results
class MultiScaleDiscriminatorJCU(nn.Module):
    """Apply ``num_D`` ``JCU_Discriminator`` instances at successively coarser scales.

    Both the signal ``x`` and the conditioning input ``mel`` go through the same
    ``nn.AvgPool1d`` (stride 2) between scales, so discriminator ``i`` sees
    inputs pooled ``i`` times.

    Args:
        num_D: Number of discriminators (one per scale).
        downsampling_factor: Kernel size of the inter-scale average pooling.
    """

    def __init__(self, num_D=3, downsampling_factor=4):
        # Zero-argument super(): naming a different class here (as in
        # ``super(MultiScaleDiscriminator, self)``) raises TypeError at construction.
        super().__init__()
        self.model = nn.ModuleDict()
        for i in range(num_D):
            self.model[f"disc_{i}"] = JCU_Discriminator()
        self.downsample = nn.AvgPool1d(
            downsampling_factor, stride=2, padding=1, count_include_pad=False
        )
        # NOTE(review): unlike MultiScaleDiscriminator, ``weights_init`` is not
        # applied here — confirm this is intentional.

    def forward(self, x, mel):
        """Score ``(x, mel)`` at every scale.

        Args:
            x: Signal input; presumably ``(batch, channels, samples)`` — TODO confirm.
            mel: Conditioning input, pooled with the same layer as ``x``;
                presumably ``(batch, n_mels, frames)`` — TODO confirm that its
                time axis is meant to shrink in step with ``x``.

        Returns:
            A list with one entry per discriminator, finest scale first. Each
            entry is whatever ``JCU_Discriminator.forward(x, mel)`` returns
            (expected to be an ``[uncond, cond]`` pair — see JCU_Discriminator).
        """
        results = []
        for scale, disc in enumerate(self.model.values()):
            # Pool only *between* scales, so no pooling is computed and then
            # discarded after the last discriminator.
            if scale > 0:
                x = self.downsample(x)
                mel = self.downsample(mel)
            results.append(disc(x, mel))
        return results
if __name__ == "__main__":
    # Smoke test: build the default multi-scale discriminator, run it on random
    # input, then report the output structure and trainable-parameter count.
    model = MultiScaleDiscriminator()
    # Shape (3, 1, 22050); presumably (batch, channels, samples) — TODO confirm.
    x = torch.randn(3, 1, 22050)
    print(x.shape)
    print(model)

    per_scale_outputs = model(x)
    # Each per-scale entry is unpacked as (features, score); this relies on
    # Discriminator.forward returning such a pair — TODO confirm.
    for features, score in per_scale_outputs:
        print("Length of features : ", len(features))
        print("Length of score : ", len(score))
        for feature_map in features:
            print(feature_map.shape)
        print(score.shape)

    trainable_counts = (
        param.numel() for param in model.parameters() if param.requires_grad
    )
    n_trainable = sum(trainable_counts)
    print(n_trainable)