import copy

import torch
import torch.nn.functional as F

from Modules.Vocoder.Avocodo_Discriminators import MultiCoMBDiscriminator
from Modules.Vocoder.Avocodo_Discriminators import MultiSubBandDiscriminator
from Modules.Vocoder.SAN_modules import SANConv1d
from Modules.Vocoder.SAN_modules import SANConv2d


class HiFiGANPeriodDiscriminator(torch.nn.Module):

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 period=3,
                 kernel_sizes=(5, 3),
                 channels=32,
                 downsample_scales=(3, 3, 3, 3, 1),
                 max_downsample_channels=1024,
                 bias=True,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.1},
                 use_weight_norm=True,
                 use_spectral_norm=False, ):
        """
        Initialize HiFiGANPeriodDiscriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            period (int): Period.
            kernel_sizes (list): Kernel sizes of initial conv layers and the final conv layer.
            channels (int): Number of initial channels.
            downsample_scales (list): List of downsampling scales.
            max_downsample_channels (int): Number of maximum downsampling channels.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_spectral_norm (bool): Whether to use spectral norm.
                If set to true, it will be applied to all of the conv layers.
        """
        super().__init__()
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number."
        assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number."

        self.period = period
        self.convs = torch.nn.ModuleList()
        in_chs = in_channels
        out_chs = channels
        for downsample_scale in downsample_scales:
            self.convs += [torch.nn.Sequential(torch.nn.Conv2d(in_chs,
                                                               out_chs,
                                                               (kernel_sizes[0], 1),
                                                               (downsample_scale, 1),
                                                               padding=((kernel_sizes[0] - 1) // 2, 0), ),
                                               getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), )]
            in_chs = out_chs
            out_chs = min(out_chs * 4, max_downsample_channels)
        self.output_conv = SANConv2d(out_chs, out_channels, (kernel_sizes[1] - 1, 1), 1, padding=((kernel_sizes[1] - 1) // 2, 0))

        if use_weight_norm and use_spectral_norm:
            raise ValueError("Either use use_weight_norm or use_spectral_norm.")

        if use_weight_norm:
            self.apply_weight_norm()

        if use_spectral_norm:
            self.apply_spectral_norm()

    def forward(self, x, discriminator_train_flag):
        """
        Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).
            discriminator_train_flag (bool): Flag passed to the SAN output conv.

        Returns:
            Tensor: Output of the SAN output conv.
            list: List of each conv layer's output tensors.
        """
        b, c, t = x.shape
        if t % self.period != 0:
            # pad the signal so that its length is a multiple of the period
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        outs = []
        for layer in self.convs:
            x = layer(x)
            outs = outs + [x]
        x = self.output_conv(x, discriminator_train_flag)
        return x, outs

    def apply_weight_norm(self):
        """
        Apply weight normalization to all of the conv layers.
        """

        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def apply_spectral_norm(self):
        """
        Apply spectral normalization to all of the conv layers.
        """

        def _apply_spectral_norm(m):
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.spectral_norm(m)

        self.apply(_apply_spectral_norm)


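# Note on the tensor flow above (an illustrative summary, not enforced by the class):
# with the default settings, a waveform of shape (B, 1, T) is reflection-padded to a
# multiple of `period`, reshaped to (B, 1, T' // period, period), and then downsampled
# along the time axis by each strided Conv2d. `forward` therefore returns one logit map
# from the SAN output conv plus the list of intermediate feature maps, which the GAN
# training code can use for feature matching.

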
class HiFiGANMultiPeriodDiscriminator(torch.nn.Module):

    def __init__(self,
                 periods=(2, 3, 5, 7, 11),
                 discriminator_params={"in_channels"                : 1,
                                       "out_channels"               : 1,
                                       "kernel_sizes"               : [5, 3],
                                       "channels"                   : 32,
                                       "downsample_scales"          : [3, 3, 3, 3, 1],
                                       "max_downsample_channels"    : 1024,
                                       "bias"                       : True,
                                       "nonlinear_activation"       : "LeakyReLU",
                                       "nonlinear_activation_params": {"negative_slope": 0.1},
                                       "use_weight_norm"            : True,
                                       "use_spectral_norm"          : False, }, ):
        """
        Initialize HiFiGANMultiPeriodDiscriminator module.

        Args:
            periods (list): List of periods.
            discriminator_params (dict): Parameters for the hifi-gan period discriminator module.
                The period parameter will be overwritten.
        """
        super().__init__()
        self.discriminators = torch.nn.ModuleList()
        for period in periods:
            params = copy.deepcopy(discriminator_params)
            params["period"] = period
            self.discriminators += [HiFiGANPeriodDiscriminator(**params)]

    def forward(self, x, discriminator_train_flag):
        """
        Calculate forward propagation.

        Args:
            x (Tensor): Input waveform signal (B, 1, T).

        Returns:
            list: List of each discriminator's output tensor.
            list: List of the intermediate feature maps of all discriminators.
        """
        outs = []
        feats = []
        for f in self.discriminators:
            d_out, d_feats = f(x, discriminator_train_flag)
            outs = outs + [d_out]
            feats = feats + d_feats

        return outs, feats


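# Note (background on the defaults used above, following the HiFi-GAN formulation):
# the default periods (2, 3, 5, 7, 11) are prime, so the individual period
# discriminators look at largely non-overlapping equidistant sample patterns of the
# same waveform.

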
class HiFiGANScaleDiscriminator(torch.nn.Module):

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_sizes=(15, 41, 5, 3),
                 channels=128,
                 max_downsample_channels=1024,
                 max_groups=16,
                 bias=True,
                 downsample_scales=(2, 2, 4, 4, 1),
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.1},
                 use_weight_norm=True,
                 use_spectral_norm=False, ):
        """
        Initialize HiFiGAN scale discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_sizes (list): List of four kernel sizes. The first is used for the first conv layer,
                the second for the downsampling layers, and the remaining two for the output layers.
            channels (int): Initial number of channels for conv layers.
            max_downsample_channels (int): Maximum number of channels for downsampling layers.
            max_groups (int): Maximum number of groups for the grouped convolutions in the downsampling layers.
            bias (bool): Whether to add bias parameter in convolution layers.
            downsample_scales (list): List of downsampling scales.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_spectral_norm (bool): Whether to use spectral norm.
                If set to true, it will be applied to all of the conv layers.
        """
        super().__init__()
        self.layers = torch.nn.ModuleList()

        # check that the kernel sizes are valid
        assert len(kernel_sizes) == 4
        for ks in kernel_sizes:
            assert ks % 2 == 1

        # add the first layer
        self.layers += [torch.nn.Sequential(torch.nn.Conv1d(in_channels,
                                                            channels,
                                                            kernel_sizes[0],
                                                            bias=bias,
                                                            padding=(kernel_sizes[0] - 1) // 2, ),
                                            getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), )]

        # add the downsampling layers
        in_chs = channels
        out_chs = channels
        groups = 4
        for downsample_scale in downsample_scales:
            self.layers += [torch.nn.Sequential(torch.nn.Conv1d(in_chs,
                                                                out_chs,
                                                                kernel_size=kernel_sizes[1],
                                                                stride=downsample_scale,
                                                                padding=(kernel_sizes[1] - 1) // 2,
                                                                groups=groups,
                                                                bias=bias, ),
                                                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), )]
            in_chs = out_chs
            out_chs = min(in_chs * 2, max_downsample_channels)
            groups = min(groups * 4, max_groups)

        # add the final layers
        out_chs = min(in_chs * 2, max_downsample_channels)
        self.layers += [torch.nn.Sequential(torch.nn.Conv1d(in_chs,
                                                            out_chs,
                                                            kernel_size=kernel_sizes[2],
                                                            stride=1,
                                                            padding=(kernel_sizes[2] - 1) // 2,
                                                            bias=bias, ),
                                            getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), )]
        self.post_conv = SANConv1d(out_chs, out_channels, kernel_sizes[3], padding=(kernel_sizes[3] - 1) // 2)

        if use_weight_norm and use_spectral_norm:
            raise ValueError("Either use use_weight_norm or use_spectral_norm.")

        if use_weight_norm:
            self.apply_weight_norm()

        if use_spectral_norm:
            self.apply_spectral_norm()

    def forward(self, x, discriminator_train_flag):
        """
        Calculate forward propagation.

        Args:
            x (Tensor): Input waveform signal (B, 1, T).
            discriminator_train_flag (bool): Flag passed to the SAN post conv.

        Returns:
            Tensor: Output of the SAN post conv.
            list: List of each conv layer's output tensors.
        """
        outs = []
        for f in self.layers:
            x = f(x)
            outs = outs + [x]

        x = self.post_conv(x, discriminator_train_flag)

        return x, outs

    def apply_weight_norm(self):
        """
        Apply weight normalization to all of the conv layers.
        """

        def _apply_weight_norm(m):
            # this discriminator is built from Conv1d layers, so the norm is applied to those
            if isinstance(m, torch.nn.Conv1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def apply_spectral_norm(self):
        """
        Apply spectral normalization to all of the conv layers.
        """

        def _apply_spectral_norm(m):
            if isinstance(m, torch.nn.Conv1d):
                torch.nn.utils.spectral_norm(m)

        self.apply(_apply_spectral_norm)


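# Note (illustrative summary, not enforced anywhere): each scale discriminator is a stack
# of grouped, strided Conv1d layers that operate directly on the waveform; the multi-scale
# wrapper below feeds it progressively average-pooled versions of the same signal, so with
# the default pooling parameters the three scales see the waveform at roughly full, half
# and quarter resolution.

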
class HiFiGANMultiScaleDiscriminator(torch.nn.Module):

    def __init__(self,
                 scales=3,
                 downsample_pooling="AvgPool1d",
                 downsample_pooling_params={"kernel_size": 4,
                                            "stride"     : 2,
                                            "padding"    : 2, },
                 discriminator_params={"in_channels"                : 1,
                                       "out_channels"               : 1,
                                       "kernel_sizes"               : [15, 41, 5, 3],
                                       "channels"                   : 128,
                                       "max_downsample_channels"    : 1024,
                                       "max_groups"                 : 16,
                                       "bias"                       : True,
                                       "downsample_scales"          : [2, 2, 4, 4, 1],
                                       "nonlinear_activation"       : "LeakyReLU",
                                       "nonlinear_activation_params": {"negative_slope": 0.1}, },
                 follow_official_norm=False, ):
        """
        Initialize HiFiGAN multi-scale discriminator module.

        Args:
            scales (int): Number of multi-scales.
            downsample_pooling (str): Pooling module name for downsampling of the inputs.
            downsample_pooling_params (dict): Parameters for the above pooling module.
            discriminator_params (dict): Parameters for the hifi-gan scale discriminator module.
            follow_official_norm (bool): Whether to follow the norm setting of the official
                implementation. The first discriminator uses spectral norm and the other
                discriminators use weight norm.
        """
        super().__init__()
        self.discriminators = torch.nn.ModuleList()

        # add the discriminators
        for i in range(scales):
            params = copy.deepcopy(discriminator_params)
            if follow_official_norm:
                if i == 0:
                    params["use_weight_norm"] = False
                    params["use_spectral_norm"] = True
                else:
                    params["use_weight_norm"] = True
                    params["use_spectral_norm"] = False
            self.discriminators += [HiFiGANScaleDiscriminator(**params)]
        self.pooling = getattr(torch.nn, downsample_pooling)(**downsample_pooling_params)

    def forward(self, x, discriminator_train_flag):
        """
        Calculate forward propagation.

        Args:
            x (Tensor): Input waveform signal (B, 1, T).

        Returns:
            list: List of each discriminator's output tensor.
            list: List of the intermediate feature maps of all discriminators.
        """
        outs = []
        feats = []
        for f in self.discriminators:
            out, d_feats = f(x, discriminator_train_flag)
            feats = feats + d_feats
            outs = outs + [out]
            # downsample the input before passing it to the next scale
            x = self.pooling(x)

        return outs, feats


class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):

    def __init__(self,
                 # multi-scale discriminator related
                 scales=3,
                 scale_downsample_pooling="AvgPool1d",
                 scale_downsample_pooling_params={"kernel_size": 4,
                                                  "stride"     : 2,
                                                  "padding"    : 2, },
                 scale_discriminator_params={"in_channels"                : 1,
                                             "out_channels"               : 1,
                                             "kernel_sizes"               : [15, 41, 5, 3],
                                             "channels"                   : 128,
                                             "max_downsample_channels"    : 1024,
                                             "max_groups"                 : 16,
                                             "bias"                       : True,
                                             "downsample_scales"          : [4, 4, 4, 4, 1],
                                             "nonlinear_activation"       : "LeakyReLU",
                                             "nonlinear_activation_params": {"negative_slope": 0.1}, },
                 follow_official_norm=True,
                 # multi-period discriminator related
                 periods=[2, 3, 5, 7, 11],
                 period_discriminator_params={"in_channels"                : 1,
                                              "out_channels"               : 1,
                                              "kernel_sizes"               : [5, 3],
                                              "channels"                   : 32,
                                              "downsample_scales"          : [3, 3, 3, 3, 1],
                                              "max_downsample_channels"    : 1024,
                                              "bias"                       : True,
                                              "nonlinear_activation"       : "LeakyReLU",
                                              "nonlinear_activation_params": {"negative_slope": 0.1},
                                              "use_weight_norm"            : True,
                                              "use_spectral_norm"          : False, }, ):
        """
        Initialize HiFiGAN multi-scale + multi-period discriminator module.

        Args:
            scales (int): Number of multi-scales.
            scale_downsample_pooling (str): Pooling module name for downsampling of the inputs.
            scale_downsample_pooling_params (dict): Parameters for the above pooling module.
            scale_discriminator_params (dict): Parameters for the hifi-gan scale discriminator module.
            follow_official_norm (bool): Whether to follow the norm setting of the official
                implementation. The first discriminator uses spectral norm and the other
                discriminators use weight norm.
            periods (list): List of periods.
            period_discriminator_params (dict): Parameters for the hifi-gan period discriminator module.
                The period parameter will be overwritten.
        """
        super().__init__()
        self.msd = HiFiGANMultiScaleDiscriminator(scales=scales,
                                                  downsample_pooling=scale_downsample_pooling,
                                                  downsample_pooling_params=scale_downsample_pooling_params,
                                                  discriminator_params=scale_discriminator_params,
                                                  follow_official_norm=follow_official_norm, )
        self.mpd = HiFiGANMultiPeriodDiscriminator(periods=periods,
                                                   discriminator_params=period_discriminator_params, )

    def forward(self, x, discriminator_train_flag):
        """
        Calculate forward propagation.

        Args:
            x (Tensor): Input waveform signal (B, 1, T).
            discriminator_train_flag (bool): Flag passed through to the SAN layers.

        Returns:
            list: List of each discriminator's output tensor.
                Multi-scale and multi-period outputs are concatenated.
            list: List of the intermediate feature maps of all discriminators.
        """
        msd_outs, msd_feats = self.msd(x, discriminator_train_flag)
        mpd_outs, mpd_feats = self.mpd(x, discriminator_train_flag)
        return msd_outs + mpd_outs, msd_feats + mpd_feats


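# Note (summary of the composition below): the joint discriminator stacks the HiFi-GAN
# multi-scale and multi-period discriminators with Avocodo's collaborative multi-band
# (MultiCoMBDiscriminator) and multi-sub-band (MultiSubBandDiscriminator) discriminators.
# The CoMB branch can additionally receive the generator's intermediate waveforms from
# before its last two upsampling stages, and the SAN train flag is threaded through every
# branch.

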
class AvocodoHiFiGANJointDiscriminator(torch.nn.Module):

    def __init__(self,
                 # multi-scale discriminator related
                 scales=3,
                 scale_downsample_pooling="AvgPool1d",
                 scale_downsample_pooling_params={"kernel_size": 4,
                                                  "stride"     : 2,
                                                  "padding"    : 2, },
                 scale_discriminator_params={"in_channels"                : 1,
                                             "out_channels"               : 1,
                                             "kernel_sizes"               : [15, 41, 5, 3],
                                             "channels"                   : 128,
                                             "max_downsample_channels"    : 1024,
                                             "max_groups"                 : 16,
                                             "bias"                       : True,
                                             "downsample_scales"          : [4, 4, 4, 4, 1],
                                             "nonlinear_activation"       : "LeakyReLU",
                                             "nonlinear_activation_params": {"negative_slope": 0.1}, },
                 follow_official_norm=True,
                 # multi-period discriminator related
                 periods=(2, 3, 5, 7, 11),
                 period_discriminator_params={"in_channels"                : 1,
                                              "out_channels"               : 1,
                                              "kernel_sizes"               : [5, 3],
                                              "channels"                   : 32,
                                              "downsample_scales"          : [3, 3, 3, 3, 1],
                                              "max_downsample_channels"    : 1024,
                                              "bias"                       : True,
                                              "nonlinear_activation"       : "LeakyReLU",
                                              "nonlinear_activation_params": {"negative_slope": 0.1},
                                              "use_weight_norm"            : True,
                                              "use_spectral_norm"          : False, },
                 # CoMB discriminator related
                 kernels=((7, 11, 11, 11, 11, 5),
                          (11, 21, 21, 21, 21, 5),
                          (15, 41, 41, 41, 41, 5)),
                 channels=(16, 64, 256, 1024, 1024, 1024),
                 groups=(1, 4, 16, 64, 256, 1),
                 strides=(1, 1, 4, 4, 4, 1),
                 # sub-band discriminator related
                 tkernels=(7, 5, 3),
                 fkernel=5,
                 tchannels=(64, 128, 256, 256, 256),
                 fchannels=(32, 64, 128, 128, 128),
                 tstrides=((1, 1, 3, 3, 1),
                           (1, 1, 3, 3, 1),
                           (1, 1, 3, 3, 1)),
                 fstride=(1, 1, 3, 3, 1),
                 tdilations=(((5, 7, 11), (5, 7, 11), (5, 7, 11), (5, 7, 11), (5, 7, 11), (5, 7, 11)),
                             ((3, 5, 7), (3, 5, 7), (3, 5, 7), (3, 5, 7), (3, 5, 7)),
                             ((1, 2, 3), (1, 2, 3), (1, 2, 3), (1, 2, 3), (1, 2, 3))),
                 fdilations=((1, 2, 3),
                             (1, 2, 3),
                             (1, 2, 3),
                             (2, 3, 5),
                             (2, 3, 5)),
                 tsubband=(6, 11, 16),
                 n=16,
                 m=64,
                 freq_init_ch=192):
        super().__init__()
        self.msd = HiFiGANMultiScaleDiscriminator(scales=scales,
                                                  downsample_pooling=scale_downsample_pooling,
                                                  downsample_pooling_params=scale_downsample_pooling_params,
                                                  discriminator_params=scale_discriminator_params,
                                                  follow_official_norm=follow_official_norm, )
        self.mpd = HiFiGANMultiPeriodDiscriminator(periods=periods,
                                                   discriminator_params=period_discriminator_params, )
        self.mcmbd = MultiCoMBDiscriminator(kernels, channels, groups, strides)
        self.msbd = MultiSubBandDiscriminator(tkernels, fkernel, tchannels, fchannels, tstrides, fstride, tdilations, fdilations, tsubband, n, m, freq_init_ch)

    def forward(self, wave, intermediate_wave_upsampled_twice=None, intermediate_wave_upsampled_once=None, discriminator_train_flag=False):
        """
        Calculate forward propagation.

        Args:
            wave: The predicted or gold waveform
            intermediate_wave_upsampled_twice: the wave before the final upsampling in the generator
            intermediate_wave_upsampled_once: the wave before the second to last upsampling in the generator
            discriminator_train_flag: flag passed through to the SAN layers

        Returns:
            list: List of the output tensors of all discriminators.
            list: List of the intermediate feature maps of all discriminators.
        """
        msd_outs, msd_feats = self.msd(wave, discriminator_train_flag)
        mpd_outs, mpd_feats = self.mpd(wave, discriminator_train_flag)
        mcmbd_outs, mcmbd_feats = self.mcmbd(wave_final=wave,
                                             intermediate_wave_upsampled_twice=intermediate_wave_upsampled_twice,
                                             intermediate_wave_upsampled_once=intermediate_wave_upsampled_once,
                                             discriminator_train_flag=discriminator_train_flag)
        msbd_outs, msbd_feats = self.msbd(wave, discriminator_train_flag)
        return msd_outs + mpd_outs + mcmbd_outs + msbd_outs, msd_feats + mpd_feats + mcmbd_feats + msbd_feats
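

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the training pipeline):
    # run a random waveform through the multi-period discriminator defined above and
    # report how many logit maps and intermediate feature maps come back. The batch
    # size, waveform length and flag value are arbitrary assumptions for the sketch.
    dummy_wave = torch.randn(2, 1, 8192)
    mpd = HiFiGANMultiPeriodDiscriminator()
    outs, feats = mpd(dummy_wave, discriminator_train_flag=False)
    print(f"multi-period discriminator returned {len(outs)} outputs and {len(feats)} feature maps")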