import copy

import torch
import torch.nn.functional as F


class HiFiGANPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN period discriminator module."""

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        period=3,
        kernel_sizes=[5, 3],
        channels=32,
        downsample_scales=[3, 3, 3, 3, 1],
        channel_increasing_factor=4,
        max_downsample_channels=1024,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.1},
        use_weight_norm=True,
    ):
"""Initialize HiFiGANPeriodDiscriminator module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
period (int): Period.
kernel_sizes (list): Kernel sizes of initial conv layers and the final conv layer.
channels (int): Number of initial channels.
downsample_scales (list): List of downsampling scales.
max_downsample_channels (int): Number of maximum downsampling channels.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (dict): Hyperparameters for activation function.
use_weight_norm (bool): Whether to use weight norm.
If set to true, it will be applied to all of the conv layers.
"""
        super().__init__()
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number."
        assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number."

        self.period = period
        # Stack of downsampling conv blocks applied to the period-folded input.
        self.convs = torch.nn.ModuleList()
        in_chs = in_channels
        out_chs = channels
        for downsample_scale in downsample_scales:
            self.convs += [
                torch.nn.Sequential(
                    torch.nn.Conv2d(
                        in_chs,
                        out_chs,
                        (kernel_sizes[0], 1),
                        (downsample_scale, 1),
                        padding=((kernel_sizes[0] - 1) // 2, 0),
                    ),
                    getattr(torch.nn, nonlinear_activation)(
                        **nonlinear_activation_params
                    ),
                )
            ]
            in_chs = out_chs
            out_chs = min(out_chs * channel_increasing_factor, max_downsample_channels)
        self.output_conv = torch.nn.Conv2d(
            in_chs,
            out_channels,
            (kernel_sizes[1] - 1, 1),
            1,
            padding=((kernel_sizes[1] - 1) // 2, 0),
        )

        if use_weight_norm:
            self.apply_weight_norm()
    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).

        Returns:
            list: List of each layer's tensors.

        """
        # Pad so the length is divisible by the period, then fold the 1D
        # signal into a 2D representation (B, C, T // period, period).
        b, c, t = x.shape
        if t % self.period != 0:
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t += n_pad
        x = x.view(b, c, t // self.period, self.period)

        # Collect every layer's feature map for feature-matching losses.
        outs = []
        for layer in self.convs:
            x = layer(x)
            outs += [x]
        x = self.output_conv(x)
        x = torch.flatten(x, 1, -1)
        outs += [x]

        return outs
    def apply_weight_norm(self):
        """Apply weight normalization to all of the Conv2d layers."""

        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)
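
# NOTE: Example shape flow for the default configuration (illustrative
# comment, not part of the original file): an input of shape (2, 1, 8192)
# with period=3 is reflect-padded to (2, 1, 8193), folded to (2, 1, 2731, 3),
# and each stride-3 conv block then shrinks the folded time axis by roughly
# 3x (2731 -> 911 -> 304 -> 102 -> 34) before the final stride-1 and output
# convolutions.
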
class HiFiGANMultiPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN multi-period discriminator module."""

    def __init__(
        self,
        periods=[2, 3, 5, 7, 11],
        **kwargs,
    ):
        """Initialize HiFiGANMultiPeriodDiscriminator module.

        Args:
            periods (list): List of periods.
            **kwargs: Keyword arguments for HiFiGANPeriodDiscriminator.
                The period parameter will be overwritten for each discriminator.

        """
        super().__init__()
        self.discriminators = torch.nn.ModuleList()
        for period in periods:
            params = copy.deepcopy(kwargs)
            params["period"] = period
            self.discriminators += [HiFiGANPeriodDiscriminator(**params)]
    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input signal (B, 1, T).

        Returns:
            list: List of lists of each discriminator's outputs, each containing
                the layer-wise output tensors.

        """
        outs = []
        for f in self.discriminators:
            outs += [f(x)]
        return outs
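

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): runs a dummy
# waveform batch through both discriminators and inspects the per-layer
# feature maps that would feed adversarial and feature-matching losses.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    wav = torch.randn(2, 1, 8192)  # (B, in_channels, T) dummy waveform

    period_d = HiFiGANPeriodDiscriminator(period=3)
    feats = period_d(wav)
    # One tensor per conv block plus the flattened output of output_conv.
    print([tuple(f.shape) for f in feats])

    multi_d = HiFiGANMultiPeriodDiscriminator(periods=[2, 3, 5, 7, 11])
    outs = multi_d(wav)
    print(len(outs))  # 5: one list of feature maps per period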