import copy
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, spectral_norm

from einops import rearrange


class HiFiGANPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN period discriminator module."""

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        period=3,
        kernel_sizes=[5, 3],
        channels=32,
        downsample_scales=[3, 3, 3, 3, 1],
        channel_increasing_factor=4,
        max_downsample_channels=1024,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.1},
        use_weight_norm=True,
    ):
        """Initialize HiFiGANPeriodDiscriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            period (int): Period.
            kernel_sizes (list): Kernel sizes of initial conv layers and the final conv layer.
            channels (int): Number of initial channels.
            downsample_scales (list): List of downsampling scales.
            channel_increasing_factor (int): Factor by which the number of channels
                increases after each downsampling layer.
            max_downsample_channels (int): Maximum number of channels in downsampling layers.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for the activation function.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.

        """
        super().__init__()
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number."
        assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number."

        self.period = period
        self.convs = torch.nn.ModuleList()
        in_chs = in_channels
        out_chs = channels
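        # Stack of strided 2D convolutions. Each block downsamples along the
        # folded time axis only (stride (downsample_scale, 1)) and widens the
        # channels by channel_increasing_factor, up to max_downsample_channels.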
        for downsample_scale in downsample_scales:
            self.convs += [
                torch.nn.Sequential(
                    torch.nn.Conv2d(
                        in_chs,
                        out_chs,
                        (kernel_sizes[0], 1),
                        (downsample_scale, 1),
                        padding=((kernel_sizes[0] - 1) // 2, 0),
                    ),
                    getattr(torch.nn, nonlinear_activation)(
                        **nonlinear_activation_params
                    ),
                )
            ]
            in_chs = out_chs
            out_chs = min(out_chs * channel_increasing_factor, max_downsample_channels)
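        # Final ((kernel_sizes[1] - 1) x 1) conv that maps the last feature map
        # to out_channels before it is flattened into logits in forward().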
        self.output_conv = torch.nn.Conv2d(
            in_chs,
            out_channels,
            (kernel_sizes[1] - 1, 1),
            1,
            padding=((kernel_sizes[1] - 1) // 2, 0),
        )

        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).

        Returns:
            list: List of each layer's output tensors.

        """
        b, c, t = x.shape
        if t % self.period != 0:
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t += n_pad
        x = x.view(b, c, t // self.period, self.period)

        outs = []
        for layer in self.convs:
            x = layer(x)
            outs += [x]
        x = self.output_conv(x)
        x = torch.flatten(x, 1, -1)
        outs += [x]

        return outs

    def apply_weight_norm(self):
        """Apply weight normalization to all of the Conv2d layers."""

        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)


class HiFiGANMultiPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN multi-period discriminator module."""

    def __init__(
        self,
        periods=[2, 3, 5, 7, 11],
        **kwargs,
    ):
        """Initialize HiFiGANMultiPeriodDiscriminator module.

        Args:
            periods (list): List of periods.
            **kwargs: Keyword arguments for HiFiGANPeriodDiscriminator.
                The period parameter will be overwritten for each discriminator.

        """
        super().__init__()
        self.discriminators = torch.nn.ModuleList()
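        # Build one period discriminator per period; each gets its own deep
        # copy of the shared kwargs with its period overwritten.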
        for period in periods:
            params = copy.deepcopy(kwargs)
            params["period"] = period
            self.discriminators += [HiFiGANPeriodDiscriminator(**params)]

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input signal (B, 1, T).

        Returns:
            List: List of each discriminator's outputs, each of which is a list
                of the layer output tensors.

        """
        outs = []
        for f in self.discriminators:
            outs += [f(x)]

        return outs
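# A minimal smoke test, not part of the original module: the batch size and
# signal length below are illustrative assumptions. It runs random audio
# through the multi-period discriminator with its default hyperparameters and
# prints the number of tensors and the flattened logit shape produced by each
# sub-discriminator.
if __name__ == "__main__":
    mpd = HiFiGANMultiPeriodDiscriminator()
    wav = torch.randn(2, 1, 8192)  # (B, 1, T) dummy waveform
    outs = mpd(wav)
    for d, layer_outs in zip(mpd.discriminators, outs):
        # The last entry of each per-discriminator list is the flattened
        # output of output_conv; earlier entries are intermediate feature maps.
        print(
            f"period={d.period}: {len(layer_outs)} tensors, "
            f"final logits shape={tuple(layer_outs[-1].shape)}"
        )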