Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from TTS.tts.layers.generic.res_conv_bn import Conv1dBN, Conv1dBNBlock, ResidualConv1dBNBlock | |
from TTS.tts.layers.generic.transformer import FFTransformerBlock | |
from TTS.tts.layers.generic.wavenet import WNBlocks | |
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer | |
class WaveNetDecoder(nn.Module):
    """WaveNet based decoder with a prenet and a postnet.

    prenet: conv1d_1x1
    postnet: 3 x [conv1d_1x1 -> relu] -> conv1d_1x1

    TODO: Integrate speaker conditioning vector.

    Note:
        default wavenet parameters;
            params = {
                "num_blocks": 12,
                "hidden_channels":192,
                "kernel_size": 5,
                "dilation_rate": 1,
                "num_layers": 4,
                "dropout_p": 0.05
            }

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels of the postnet. The prenet output and the
            WaveNet layers use ``params["hidden_channels"]`` instead.
        c_in_channels (int): number of channels of the conditioning input ``g``, forwarded to ``WNBlocks``.
        params (dict): keyword arguments for ``WNBlocks``; must contain ``"hidden_channels"``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, c_in_channels, params):
        super().__init__()
        wn_channels = params["hidden_channels"]
        # prenet: 1x1 projection from the input width to the WaveNet width.
        self.prenet = torch.nn.Conv1d(in_channels, wn_channels, 1)
        # wavenet layers; `hidden_channels` for WNBlocks is supplied through **params.
        self.wn = WNBlocks(wn_channels, c_in_channels=c_in_channels, **params)
        # postnet: 1x1 convs with ReLUs in between, ending in the output projection.
        # Built directly as nn.Sequential; submodule indices (0, 2, 4, 6) are unchanged.
        self.postnet = nn.Sequential(
            torch.nn.Conv1d(wn_channels, hidden_channels, 1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(hidden_channels, hidden_channels, 1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(hidden_channels, hidden_channels, 1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(hidden_channels, out_channels, 1),
        )

    def forward(self, x, x_mask=None, g=None):
        """Decode ``x`` through prenet -> WaveNet blocks -> postnet.

        Args:
            x: input tensor [B, in_channels, T].
            x_mask: optional mask, presumably [B, 1, T]; ``None`` disables masking.
            g: optional conditioning tensor forwarded to ``WNBlocks``.

        Returns:
            Tensor with ``out_channels`` channels, multiplied by the mask.
        """
        # The signature advertises x_mask=None; use a no-op multiplier in that case instead of
        # failing with `tensor * None`.
        mask = 1 if x_mask is None else x_mask
        x = self.prenet(x) * mask
        # Forward the caller's mask unchanged so WNBlocks sees exactly what was given.
        x = self.wn(x, x_mask, g)
        o = self.postnet(x) * mask
        return o
class RelativePositionTransformerDecoder(nn.Module):
    """Decoder with Relative Positional Transformer.

    Note:
        Default params
            params={
                'hidden_channels_ffn': 128,
                'num_heads': 2,
                "kernel_size": 3,
                "dropout_p": 0.1,
                "num_layers": 8,
                "rel_attn_window_size": 4,
                "input_length": None
            }

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels including Transformer layers.
        params (dict): keyword arguments forwarded to ``RelativePositionTransformer``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        self.prenet = Conv1dBN(in_channels, hidden_channels, 1, 1)
        # The transformer consumes the prenet output, which has `hidden_channels` channels, so its
        # input width is `hidden_channels`, not `in_channels`. `Decoder` always builds this class with
        # in_channels == hidden_channels, so existing models are unaffected.
        self.rel_pos_transformer = RelativePositionTransformer(hidden_channels, out_channels, hidden_channels, **params)

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        """Apply the prenet and the relative-position transformer.

        Args:
            x: input tensor with ``in_channels`` channels, presumably [B, in_channels, T].
            x_mask: mask multiplied into the prenet output and passed to the transformer.
                Required in practice: the ``None`` default fails at the multiplication below.
            g: unused; accepted for interface compatibility.

        Returns:
            The transformer output (not re-masked here).
        """
        o = self.prenet(x) * x_mask
        o = self.rel_pos_transformer(o, x_mask)
        return o
class FFTransformerDecoder(nn.Module):
    """Decoder made of a FeedForwardTransformer block followed by a 1x1 output convolution.

    Note:
        Default params
            params={
                'hidden_channels_ffn': 1024,
                'num_heads': 2,
                "dropout_p": 0.1,
                "num_layers": 6,
            }

    Args:
        in_channels (int): number of input channels; the transformer block output keeps this width.
        out_channels (int): number of output channels.
        params (dict): keyword arguments forwarded to ``FFTransformerBlock``.
    """

    def __init__(self, in_channels, out_channels, params):
        super().__init__()
        self.transformer_block = FFTransformerBlock(in_channels, **params)
        self.postnet = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        """Run the transformer block and the output projection, masking both results.

        Args:
            x: input tensor, presumably [B, in_channels, T].
            x_mask: optional mask multiplied into both outputs; ``None`` disables masking.
            g: unused; accepted for interface compatibility.

        Returns:
            Tensor with ``out_channels`` channels.
        """
        # TODO: handle multi-speaker
        mask = x_mask if x_mask is not None else 1
        # NOTE(review): the mask is applied only after the block and is not passed into
        # `transformer_block`, so padded frames may still take part in attention — confirm.
        hidden = self.transformer_block(x) * mask
        return self.postnet(hidden) * mask
class ResidualConv1dBNDecoder(nn.Module):
    """Residual Convolutional Decoder as in the original Speedy Speech paper

    TODO: Integrate speaker conditioning vector.

    Note:
        Default params
            params = {
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4, 8] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 17
            }

    Args:
        in_channels (int): number of input channels. Must equal ``hidden_channels``, because the
            input is added onto the ``post_conv`` output (skip connection).
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels including ResidualConv1dBNBlock layers.
        params (dict): keyword arguments for ``ResidualConv1dBNBlock``; ``params["kernel_size"]``
            is also used by the postnet.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        self.res_conv_block = ResidualConv1dBNBlock(in_channels, hidden_channels, hidden_channels, **params)
        self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
        # NOTE(review): the postnet uses num_conv_blocks=2 regardless of params["num_conv_blocks"].
        # Kept as-is; changing it would likely alter the module layout expected by trained checkpoints.
        self.postnet = nn.Sequential(
            Conv1dBNBlock(
                hidden_channels, hidden_channels, hidden_channels, params["kernel_size"], 1, num_conv_blocks=2
            ),
            nn.Conv1d(hidden_channels, out_channels, 1),
        )

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        """Decode ``x`` through the residual stack, a skip connection and the postnet.

        Args:
            x: input tensor with ``in_channels`` (== ``hidden_channels``) channels.
            x_mask: optional mask; passed unchanged to the residual stack and multiplied into the
                output. ``None`` disables the final masking.
            g: unused; accepted for interface compatibility.

        Returns:
            Tensor with ``out_channels`` channels.
        """
        # The signature advertises x_mask=None; avoid `tensor * None` on the final masking.
        mask = 1 if x_mask is None else x_mask
        o = self.res_conv_block(x, x_mask)
        # Long skip connection around the residual stack.
        o = self.post_conv(o) + x
        return self.postnet(o) * mask
class Decoder(nn.Module):
    """Decodes the expanded phoneme encoding into spectrograms

    Args:
        out_channels (int): number of output channels.
        in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers.
        decoder_type (str): decoder layer type, matched case-insensitively. One of
            'relative_position_transformer', 'residual_conv_bn', 'wavenet' or 'fftransformer'.
            Default 'residual_conv_bn'.
        decoder_params (dict): model parameters for the specified decoder type. When None, the
            default 'residual_conv_bn' parameters (kernel_size=4, dilations=4 * [1, 2, 4, 8] + [1],
            num_conv_blocks=2, num_res_blocks=17) are used.
        c_in_channels (int): number of channels for conditional input. Only used by the 'wavenet' decoder.

    Raises:
        ValueError: if ``decoder_type`` is not one of the supported types.

    Shapes:
        - input: (B, C, T)
    """

    def __init__(
        self,
        out_channels,
        in_hidden_channels,
        decoder_type="residual_conv_bn",
        decoder_params=None,
        c_in_channels=0,
    ):
        super().__init__()
        if decoder_params is None:
            # Build a fresh dict per instance instead of sharing one mutable default argument.
            decoder_params = {
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4, 8] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 17,
            }
        kind = decoder_type.lower()
        if kind == "relative_position_transformer":
            self.decoder = RelativePositionTransformerDecoder(
                in_channels=in_hidden_channels,
                out_channels=out_channels,
                hidden_channels=in_hidden_channels,
                params=decoder_params,
            )
        elif kind == "residual_conv_bn":
            self.decoder = ResidualConv1dBNDecoder(
                in_channels=in_hidden_channels,
                out_channels=out_channels,
                hidden_channels=in_hidden_channels,
                params=decoder_params,
            )
        elif kind == "wavenet":
            self.decoder = WaveNetDecoder(
                in_channels=in_hidden_channels,
                out_channels=out_channels,
                hidden_channels=in_hidden_channels,
                c_in_channels=c_in_channels,
                params=decoder_params,
            )
        elif kind == "fftransformer":
            self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params)
        else:
            raise ValueError(f"[!] Unknown decoder type - {decoder_type}")

    def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
        """
        Args:
            x: [B, C, T]
            x_mask: [B, 1, T]
            g: [B, C_g, 1]

        Returns:
            Output of the selected decoder.
        """
        # TODO: implement multi-speaker
        o = self.decoder(x, x_mask, g)
        return o