"""Encoder modules for the Speedy Speech model."""
from torch import nn

from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
from TTS.tts.layers.generic.transformer import FFTransformerBlock
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
class RelativePositionTransformerEncoder(nn.Module):
    """Speedy speech encoder: a residual convolutional prenet followed by a
    Transformer with Relative Position encoding.

    TODO: Integrate speaker conditioning vector.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels
        params (dict): keyword arguments forwarded to ``RelativePositionTransformer``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        # Fixed-size convolutional prenet: 3 residual blocks, no dilation.
        self.prenet = ResidualConv1dBNBlock(
            in_channels,
            hidden_channels,
            hidden_channels,
            kernel_size=5,
            num_res_blocks=3,
            num_conv_blocks=1,
            dilations=[1, 1, 1],
        )
        self.rel_pos_transformer = RelativePositionTransformer(hidden_channels, out_channels, hidden_channels, **params)

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        # Without an explicit mask, multiply by the scalar 1 (i.e. no masking).
        mask = 1 if x_mask is None else x_mask
        hidden = self.prenet(x) * mask
        return self.rel_pos_transformer(hidden, mask)
class ResidualConv1dBNEncoder(nn.Module):
    """Residual Convolutional Encoder as in the original Speedy Speech paper.

    TODO: Integrate speaker conditioning vector.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels
        params (dict): keyword arguments forwarded to ``ResidualConv1dBNBlock``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        # 1x1 projection into the hidden width before the residual stack.
        self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1), nn.ReLU())
        self.res_conv_block = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **params)
        # 1x1 conv head projecting to the output width.
        self.postnet = nn.Sequential(
            nn.Conv1d(hidden_channels, hidden_channels, 1),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_channels),
            nn.Conv1d(hidden_channels, out_channels, 1),
        )

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        # Without an explicit mask, multiply by the scalar 1 (i.e. no masking).
        mask = 1 if x_mask is None else x_mask
        out = self.prenet(x) * mask
        out = self.res_conv_block(out, mask)
        # Residual skip from the network input (requires in_channels == hidden_channels).
        out = self.postnet(out + x) * mask
        return out * mask
class Encoder(nn.Module):
    """Factory class for Speedy Speech encoder enables different encoder types internally.

    Args:
        in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers.
        out_channels (int): number of output channels.
        encoder_type (str): encoder layer types. 'relative_position_transformer', 'residual_conv_bn' or
            'fftransformer'. Default 'residual_conv_bn'.
        encoder_params (dict): model parameters for specified encoder type. If ``None``, the default
            'residual_conv_bn' parameters shown below are used.
        c_in_channels (int): number of channels for conditional input.

    Note:
        Default encoder_params to be set in config.json...

        ```python
        # for 'relative_position_transformer'
        encoder_params={
            'hidden_channels_ffn': 128,
            'num_heads': 2,
            "kernel_size": 3,
            "dropout_p": 0.1,
            "num_layers": 6,
            "rel_attn_window_size": 4,
            "input_length": None
        },
        # for 'residual_conv_bn'
        encoder_params = {
            "kernel_size": 4,
            "dilations": 4 * [1, 2, 4] + [1],
            "num_conv_blocks": 2,
            "num_res_blocks": 13
        }
        # for 'fftransformer'
        encoder_params = {
            "hidden_channels_ffn": 1024 ,
            "num_heads": 2,
            "num_layers": 6,
            "dropout_p": 0.1
        }
        ```
    """

    def __init__(
        self,
        in_hidden_channels,
        out_channels,
        encoder_type="residual_conv_bn",
        encoder_params=None,
        c_in_channels=0,
    ):
        super().__init__()
        # Use a None sentinel instead of a mutable default dict so the default
        # configuration is built fresh per instance (avoids shared-state bugs).
        if encoder_params is None:
            encoder_params = {
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 13,
            }
        self.out_channels = out_channels
        self.in_channels = in_hidden_channels
        self.hidden_channels = in_hidden_channels
        self.encoder_type = encoder_type
        self.c_in_channels = c_in_channels
        # init encoder
        if encoder_type.lower() == "relative_position_transformer":
            # text encoder
            # pylint: disable=unexpected-keyword-arg
            self.encoder = RelativePositionTransformerEncoder(
                in_hidden_channels, out_channels, in_hidden_channels, encoder_params
            )
        elif encoder_type.lower() == "residual_conv_bn":
            self.encoder = ResidualConv1dBNEncoder(in_hidden_channels, out_channels, in_hidden_channels, encoder_params)
        elif encoder_type.lower() == "fftransformer":
            assert (
                in_hidden_channels == out_channels
            ), "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'"
            # pylint: disable=unexpected-keyword-arg
            self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params)
        else:
            raise NotImplementedError(" [!] unknown encoder type.")

    def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
        """
        Shapes:
            x: [B, C, T]
            x_mask: [B, 1, T]
            g: [B, C, 1]
        """
        o = self.encoder(x, x_mask)
        return o * x_mask