Spaces:
Running
Running
File size: 5,913 Bytes
9b2107c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
from torch import nn
from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
from TTS.tts.layers.generic.transformer import FFTransformerBlock
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
class RelativePositionTransformerEncoder(nn.Module):
    """Speedy Speech encoder: residual conv prenet followed by a relative-position Transformer.

    TODO: Integrate speaker conditioning vector.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels.
        params (dict): keyword arguments forwarded to ``RelativePositionTransformer``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        # Prenet hyper-parameters are fixed here; they are not configurable through ``params``.
        prenet_kwargs = {
            "kernel_size": 5,
            "num_res_blocks": 3,
            "num_conv_blocks": 1,
            "dilations": [1, 1, 1],
        }
        self.prenet = ResidualConv1dBNBlock(in_channels, hidden_channels, hidden_channels, **prenet_kwargs)
        self.rel_pos_transformer = RelativePositionTransformer(
            hidden_channels, out_channels, hidden_channels, **params
        )

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        """Run the prenet, mask its output and feed it to the transformer.

        Args:
            x: input tensor (presumably ``[B, C, T]`` -- confirm against caller).
            x_mask: optional mask multiplied into the prenet output and passed to
                the transformer. When ``None`` the scalar ``1`` is used instead.
            g: unused; reserved for speaker conditioning.

        Returns:
            The transformer output (no additional masking is applied here).
        """
        # NOTE(review): with ``x_mask=None`` the integer 1 is handed to the
        # transformer as its mask -- confirm the transformer accepts a scalar.
        mask = 1 if x_mask is None else x_mask
        hidden = self.prenet(x) * mask
        return self.rel_pos_transformer(hidden, mask)
class ResidualConv1dBNEncoder(nn.Module):
    """Residual convolutional encoder, as in the original Speedy Speech paper.

    The network is a pointwise-conv prenet, a stack of residual conv/batch-norm
    blocks, and a pointwise-conv postnet that projects to ``out_channels``.

    TODO: Integrate speaker conditioning vector.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels.
        params (dict): keyword arguments forwarded to ``ResidualConv1dBNBlock``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        input_projection = nn.Conv1d(in_channels, hidden_channels, 1)
        self.prenet = nn.Sequential(input_projection, nn.ReLU())
        self.res_conv_block = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **params)
        self.postnet = nn.Sequential(
            nn.Conv1d(hidden_channels, hidden_channels, 1),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_channels),
            nn.Conv1d(hidden_channels, out_channels, 1),
        )

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        """Encode ``x`` and return the masked result.

        Args:
            x: input tensor (presumably ``[B, C, T]`` -- confirm against caller).
            x_mask: optional mask multiplied into the intermediate and final
                activations. When ``None`` the scalar ``1`` is used instead.
            g: unused; reserved for speaker conditioning.

        Returns:
            The postnet output multiplied by the mask.
        """
        mask = 1 if x_mask is None else x_mask
        hidden = self.prenet(x) * mask
        hidden = self.res_conv_block(hidden, mask)
        # Skip connection from the raw input; this needs ``x`` to be
        # shape-compatible with the hidden activations.
        projected = self.postnet(hidden + x) * mask
        # The mask is applied once more on the way out, so results stay
        # identical even for non-binary masks.
        return projected * mask
class Encoder(nn.Module):
    """Factory class for Speedy Speech encoder enables different encoder types internally.

    Args:
        in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers.
        out_channels (int): number of output channels.
        encoder_type (str): encoder layer type, matched case-insensitively. One of
            'relative_position_transformer', 'residual_conv_bn' or 'fftransformer'.
            Default 'residual_conv_bn'.
        encoder_params (dict): model parameters for specified encoder type. When ``None``,
            the 'residual_conv_bn' parameters listed in the note below are used.
        c_in_channels (int): number of channels for conditional input. Stored on the
            instance; not otherwise used by this class.

    Raises:
        AssertionError: if ``encoder_type`` is 'fftransformer' and
            ``in_hidden_channels != out_channels``.
        NotImplementedError: if ``encoder_type`` is not one of the supported types.

    Note:
        Default encoder_params to be set in config.json...

        ```python
        # for 'relative_position_transformer'
        encoder_params={
            'hidden_channels_ffn': 128,
            'num_heads': 2,
            "kernel_size": 3,
            "dropout_p": 0.1,
            "num_layers": 6,
            "rel_attn_window_size": 4,
            "input_length": None
        },

        # for 'residual_conv_bn'
        encoder_params = {
            "kernel_size": 4,
            "dilations": 4 * [1, 2, 4] + [1],
            "num_conv_blocks": 2,
            "num_res_blocks": 13
        }

        # for 'fftransformer'
        encoder_params = {
            "hidden_channels_ffn": 1024 ,
            "num_heads": 2,
            "num_layers": 6,
            "dropout_p": 0.1
        }
        ```
    """

    def __init__(
        self,
        in_hidden_channels,
        out_channels,
        encoder_type="residual_conv_bn",
        encoder_params=None,
        c_in_channels=0,
    ):
        super().__init__()
        if encoder_params is None:
            # Build the default per instance so no dict object is shared
            # between encoders (avoids the mutable-default-argument pitfall).
            encoder_params = {
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 13,
            }

        self.out_channels = out_channels
        self.in_channels = in_hidden_channels
        self.hidden_channels = in_hidden_channels
        self.encoder_type = encoder_type
        self.c_in_channels = c_in_channels

        # init encoder
        kind = encoder_type.lower()
        if kind == "relative_position_transformer":
            # text encoder
            # pylint: disable=unexpected-keyword-arg
            self.encoder = RelativePositionTransformerEncoder(
                in_hidden_channels, out_channels, in_hidden_channels, encoder_params
            )
        elif kind == "residual_conv_bn":
            self.encoder = ResidualConv1dBNEncoder(in_hidden_channels, out_channels, in_hidden_channels, encoder_params)
        elif kind == "fftransformer":
            # FFTransformerBlock is built with a single channel size, so input
            # and output widths have to match.
            assert (
                in_hidden_channels == out_channels
            ), "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'"
            # pylint: disable=unexpected-keyword-arg
            self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params)
        else:
            raise NotImplementedError(" [!] unknown encoder type.")

    def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
        """Encode ``x`` with the selected encoder and mask the result.

        Args:
            x: input features.
            x_mask: mask passed to the encoder and multiplied into its output.
            g: conditioning input; currently unused.

        Returns:
            The encoder output multiplied by ``x_mask``.

        Shapes:
            x: [B, C, T]
            x_mask: [B, 1, T]
            g: [B, C, 1]
        """
        o = self.encoder(x, x_mask)
        return o * x_mask
|