import copy

import torch
import torch.nn.functional as F


class HiFiGANPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN period discriminator module."""

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        period=3,
        kernel_sizes=[5, 3],
        channels=32,
        downsample_scales=[3, 3, 3, 3, 1],
        channel_increasing_factor=4,
        max_downsample_channels=1024,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.1},
        use_weight_norm=True,
    ):
        """Initialize HiFiGANPeriodDiscriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            period (int): Period.
            kernel_sizes (list): Kernel sizes of the initial conv layers and
                the final conv layer.
            channels (int): Number of initial channels.
            downsample_scales (list): List of downsampling scales.
            channel_increasing_factor (int): Factor by which the number of
                channels grows after each downsampling conv layer.
            max_downsample_channels (int): Maximum number of downsampling
                channels.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for the
                activation function.
            use_weight_norm (bool): Whether to use weight norm. If set to
                true, it will be applied to all of the conv layers.

        """
        super().__init__()
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number."
        assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number."

        self.period = period
        self.convs = torch.nn.ModuleList()
        in_chs = in_channels
        out_chs = channels
        for downsample_scale in downsample_scales:
            self.convs += [
                torch.nn.Sequential(
                    torch.nn.Conv2d(
                        in_chs,
                        out_chs,
                        (kernel_sizes[0], 1),
                        (downsample_scale, 1),
                        padding=((kernel_sizes[0] - 1) // 2, 0),
                    ),
                    getattr(torch.nn, nonlinear_activation)(
                        **nonlinear_activation_params
                    ),
                )
            ]
            in_chs = out_chs
            out_chs = min(out_chs * channel_increasing_factor, max_downsample_channels)
        self.output_conv = torch.nn.Conv2d(
            in_chs,
            out_channels,
            (kernel_sizes[1] - 1, 1),
            1,
            padding=((kernel_sizes[1] - 1) // 2, 0),
        )

        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).

        Returns:
            list: List of each layer's output tensors.

        """
        # Pad so that T is divisible by the period, then fold the 1D signal
        # into a 2D map of shape (B, C, T // period, period).
        b, c, t = x.shape
        if t % self.period != 0:
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t += n_pad
        x = x.view(b, c, t // self.period, self.period)

        outs = []
        for layer in self.convs:
            x = layer(x)
            outs += [x]
        x = self.output_conv(x)
        x = torch.flatten(x, 1, -1)
        outs += [x]

        return outs

    def apply_weight_norm(self):
        """Apply weight normalization to all of the conv layers."""

        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)


class HiFiGANMultiPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN multi-period discriminator module."""

    def __init__(
        self,
        periods=[2, 3, 5, 7, 11],
        **kwargs,
    ):
        """Initialize HiFiGANMultiPeriodDiscriminator module.

        Args:
            periods (list): List of periods.
            **kwargs: Parameters for the HiFiGAN period discriminator modules.
                The period parameter will be overwritten.

        """
        super().__init__()
        self.discriminators = torch.nn.ModuleList()
        for period in periods:
            params = copy.deepcopy(kwargs)
            params["period"] = period
            self.discriminators += [HiFiGANPeriodDiscriminator(**params)]

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input signal (B, 1, T).

        Returns:
            list: List of lists of each discriminator's outputs, each of
                which consists of that discriminator's layer output tensors.

        """
        outs = []
        for f in self.discriminators:
            outs += [f(x)]
        return outs
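
# A minimal usage sketch, not part of the original module: build the
# multi-period discriminator with its defaults and run a dummy waveform
# through it. The batch size and sample count below are illustrative
# assumptions, not values from the source.
if __name__ == "__main__":
    discriminator = HiFiGANMultiPeriodDiscriminator()
    x = torch.randn(2, 1, 8192)  # (B, 1, T) dummy input signal
    outs = discriminator(x)
    assert len(outs) == 5  # one list per default period: [2, 3, 5, 7, 11]
    # Each inner list holds the feature maps of the five downsampling conv
    # layers; its last entry is the flattened output of the final conv.
    for period_outs in outs:
        print([tuple(o.shape) for o in period_outs])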