# mtts/models/fs2_model.py — FastSpeech2 acoustic model definition.
from typing import List, Optional
import torch
import torch.nn as nn
from torch import Tensor
from mtts.models.decoder import Decoder
from mtts.models.encoder import FS2TransformerEncoder
from mtts.models.fs2_variance import VarianceAdaptor
from mtts.models.postnet import PostNet
from mtts.utils.logging import get_logger
# Whitelist of encoder classes selectable via config['encoder']['encoder_type'].
ENCODERS = [
    FS2TransformerEncoder,
]
# Module-level logger named after this source file.
logger = get_logger(__file__)
def __read_vocab(file):
    """Read a vocabulary file and return its non-empty lines.

    Args:
        file: Path to a text file with one vocabulary entry per line.

    Returns:
        List of non-empty lines (vocabulary tokens), in file order.
    """
    # Explicit UTF-8 avoids platform-dependent default encodings; vocab
    # files for Mandarin TTS contain non-ASCII characters.
    with open(file, encoding='utf-8') as f:
        lines = f.read().split('\n')
    return [line for line in lines if len(line) > 0]
def _get_layer(emb_config: dict):
    """Build an ``nn.Embedding`` from one embedding config section.

    Args:
        emb_config: dict with keys 'enable', 'vocab', 'vocab_size',
            'dim' and 'weight'. When 'vocab' is a file path, the vocab
            size is taken from its line count; otherwise 'vocab_size'
            is used directly.

    Returns:
        ``(embedding_layer, weight)`` when enabled, ``(None, None)``
        otherwise. Index 0 is reserved as the padding index.
    """
    logger.info(f'building embedding with config: {emb_config}')
    if not emb_config['enable']:
        return None, None

    vocab_file = emb_config['vocab']
    if vocab_file is None:
        vocab_size = emb_config['vocab_size']
    else:
        vocab_size = len(__read_vocab(vocab_file))
    embedding = nn.Embedding(vocab_size, emb_config['dim'], padding_idx=0)
    return embedding, emb_config['weight']
def _build_embedding_layers(config):
    """Build the pinyin / hanzi / speaker embedding layers.

    Args:
        config: top-level model config containing the three
            '*_embedding' sections consumed by :func:`_get_layer`.

    Returns:
        ``(layers, weights)`` where ``layers`` is an ``nn.ModuleList``
        of the enabled embeddings (disabled ones are skipped) and
        ``weights`` holds the matching per-layer weights.
    """
    layers = nn.ModuleList()
    weights = []
    for section in ('pinyin_embedding', 'hanzi_embedding', 'speaker_embedding'):
        layer, weight = _get_layer(config[section])
        if layer is None:
            continue
        layers.append(layer)
        weights.append(weight)
    return layers, weights
def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from per-sequence lengths.

    Args:
        lengths: integer tensor of shape ``(batch,)`` with valid lengths.
        max_len: mask width; defaults to ``max(lengths)``.

    Returns:
        Bool tensor of shape ``(batch, max_len)``; ``True`` marks padded
        positions (position index >= sequence length).
    """
    if max_len is None:
        max_len = int(lengths.max().item())
    # Broadcast a (1, max_len) position row against (batch, 1) lengths.
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    return positions >= lengths.unsqueeze(1)
class FastSpeech2(nn.Module):
    """FastSpeech2 acoustic model.

    Pipeline: token embeddings -> encoder -> variance adaptor (duration
    prediction / length regulation) -> decoder -> linear mel projection
    -> PostNet residual refinement.
    """

    def __init__(self, config):
        """Build all sub-modules from a nested configuration dict.

        Args:
            config: dict providing the 'pinyin_embedding',
                'hanzi_embedding', 'speaker_embedding', 'encoder',
                'duration_predictor', 'decoder', 'fbank' and 'postnet'
                sections.

        Raises:
            ValueError: if ``config['encoder']['encoder_type']`` does not
                name a class in the ``ENCODERS`` whitelist.
        """
        super(FastSpeech2, self).__init__()
        emb_layers, emb_weights = _build_embedding_layers(config)
        # Resolve the encoder class by name against the whitelist instead
        # of eval(), so a config string can never execute arbitrary code.
        encoder_name = config['encoder']['encoder_type']
        try:
            EncoderClass = {cls.__name__: cls for cls in ENCODERS}[encoder_name]
        except KeyError:
            raise ValueError(f'unsupported encoder_type: {encoder_name}') from None
        encoder_conf = config['encoder']['conf']
        encoder_conf.update({'emb_layers': emb_layers})
        # NOTE: 'embeding_weights' (sic) must match the encoder's kwarg name.
        encoder_conf.update({'embeding_weights': emb_weights})
        self.encoder = EncoderClass(**encoder_conf)

        dur_config = config['duration_predictor']
        self.variance_adaptor = VarianceAdaptor(**dur_config)
        decoder_config = config['decoder']
        self.decoder = Decoder(**decoder_config)

        # Project decoder hidden states to mel-spectrogram bins.
        n_mels = config['fbank']['n_mels']
        mel_linear_input_dim = decoder_config['hidden_dim']
        self.mel_linear = nn.Linear(mel_linear_input_dim, n_mels)
        self.postnet = PostNet(**config['postnet'])

    def forward(self,
                input_seqs: List[Tensor],
                seq_len: Tensor,
                mel_len: Optional[Tensor] = None,
                d_target: Optional[Tensor] = None,
                max_src_len=None,
                max_mel_len=None,
                d_control=1.0,
                p_control=1.0,
                e_control=1.0):
        """Run a forward pass.

        Args:
            input_seqs: per-stream token-id tensors consumed by the encoder.
            seq_len: ``(batch,)`` source sequence lengths.
            mel_len: ``(batch,)`` target mel lengths (training); ``None``
                at inference.
            d_target: ground-truth durations (training); ``None`` at
                inference, in which case predicted durations drive length
                regulation and determine ``mel_len``/``mel_mask``.
            max_src_len: optional max source length for the padding mask.
            max_mel_len: optional max mel length for the padding mask.
            d_control: duration scaling factor (speech-rate control).
            p_control: pitch control; currently unused, kept for
                interface compatibility.
            e_control: energy control; currently unused, kept for
                interface compatibility.

        Returns:
            Tuple ``(mel_pred, postnet_output, d_prediction, src_mask,
            mel_mask, mel_len)``.
        """
        src_mask = get_mask_from_lengths(seq_len, max_src_len)
        if mel_len is not None:
            mel_mask = get_mask_from_lengths(mel_len, max_mel_len)
        else:
            mel_mask = None

        encoder_output = self.encoder(input_seqs, src_mask)
        variance_adaptor_output, d_prediction, pred_mel_len, pred_mel_mask = self.variance_adaptor(
            encoder_output, src_mask, mel_mask, d_target, max_mel_len, d_control)
        if d_target is None:
            # Inference: mel length/mask come from the predicted durations.
            mel_len, mel_mask = pred_mel_len, pred_mel_mask

        decoder_output = self.decoder(variance_adaptor_output, mel_mask)
        mel_pred = self.mel_linear(decoder_output)
        # PostNet takes an extra dim-1 axis and returns a residual that is
        # added back onto the raw mel prediction.
        postnet_input = torch.unsqueeze(mel_pred, 1)
        postnet_output = self.postnet(postnet_input) + mel_pred
        return mel_pred, postnet_output, d_prediction, src_mask, mel_mask, mel_len
if __name__ == "__main__":
# Test
import yaml
with open('../../examples/aishell3/config.yaml') as f:
config = yaml.safe_load(f)
model = FastSpeech2(config)
print(model)
print(sum(param.numel() for param in model.parameters()))