"""FastSpeech2 acoustic model: encoder, variance adaptor, decoder and PostNet."""
from typing import List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from mtts.models.decoder import Decoder
from mtts.models.encoder import FS2TransformerEncoder
from mtts.models.fs2_variance import VarianceAdaptor
from mtts.models.postnet import PostNet
from mtts.utils.logging import get_logger
# Whitelist of encoder classes that config['encoder']['encoder_type'] may
# name; FastSpeech2.__init__ checks the resolved class is a member.
ENCODERS = [
    FS2TransformerEncoder,
]

logger = get_logger(__file__)
def __read_vocab(file):
    """Read a newline-separated vocabulary file.

    Args:
        file: path to a UTF-8 text file with one vocabulary entry per line.

    Returns:
        List of the file's non-empty lines, in file order.
    """
    # Explicit UTF-8: vocab files contain hanzi/pinyin and must not depend
    # on the platform's default locale encoding.
    with open(file, encoding='utf-8') as f:
        lines = f.read().split('\n')
    return [line for line in lines if len(line) > 0]
def _get_layer(emb_config: dict):
    """Build one nn.Embedding from an embedding config dict.

    Returns a ``(layer, weight)`` pair, or ``(None, None)`` when the config
    is disabled via ``emb_config['enable']``.
    """
    logger.info(f'building embedding with config: {emb_config}')
    if not emb_config['enable']:
        return None, None
    # Vocab size comes from an explicit vocab file when given, otherwise
    # directly from the config.
    vocab_path = emb_config['vocab']
    if vocab_path is None:
        vocab_size = emb_config['vocab_size']
    else:
        vocab_size = len(__read_vocab(vocab_path))
    embedding = nn.Embedding(vocab_size, emb_config['dim'], padding_idx=0)
    return embedding, emb_config['weight']
def _build_embedding_layers(config):
    """Collect the enabled pinyin/hanzi/speaker embedding layers.

    Returns an ``nn.ModuleList`` of embedding layers plus the parallel list
    of their mixing weights; disabled embeddings are skipped entirely.
    """
    sub_configs = (
        config['pinyin_embedding'],
        config['hanzi_embedding'],
        config['speaker_embedding'],
    )
    layers = nn.ModuleList()
    weights = []
    for sub in sub_configs:
        layer, weight = _get_layer(sub)
        if layer is None:
            continue
        layers.append(layer)
        weights.append(weight)
    return layers, weights
def get_mask_from_lengths(lengths: Tensor, max_len: Optional[int] = None) -> Tensor:
    """Build a boolean padding mask from a batch of sequence lengths.

    Args:
        lengths: 1-D integer tensor of shape (batch,) holding each
            sequence's true length.
        max_len: width of the mask; defaults to ``max(lengths)``.

    Returns:
        Bool tensor of shape (batch, max_len) where True marks padding
        positions (position index >= sequence length).
    """
    if max_len is None:
        max_len = int(torch.max(lengths).item())
    # Create the position ids directly on the input's device (the old
    # arange(...).to(device) built them on CPU and copied), and rely on
    # broadcasting instead of explicit expand() calls.
    ids = torch.arange(max_len, device=lengths.device)
    mask = ids.unsqueeze(0) >= lengths.unsqueeze(1)
    return mask
class FastSpeech2(nn.Module):
    """FastSpeech2 acoustic model.

    Pipeline: embedding encoder -> variance (duration) adaptor -> decoder
    -> linear mel projection -> PostNet residual refinement.
    """

    def __init__(self, config):
        """Build all sub-modules from a nested ``config`` dict.

        Expected keys: 'encoder' (with 'encoder_type' and 'conf'),
        'duration_predictor', 'decoder', 'fbank' (with 'n_mels'),
        'postnet', plus the embedding configs read by
        ``_build_embedding_layers``.
        """
        super().__init__()
        emb_layers, emb_weights = _build_embedding_layers(config)

        # Resolve the encoder class by name against the ENCODERS whitelist.
        # The old code used eval() on the config string, which executes
        # arbitrary code from an untrusted config file; a dict lookup gives
        # the same result safely and raises a clear error (the old assert
        # would be stripped under `python -O`).
        encoder_name = config['encoder']['encoder_type']
        known_encoders = {cls.__name__: cls for cls in ENCODERS}
        if encoder_name not in known_encoders:
            raise ValueError(f'unknown encoder_type: {encoder_name}')
        EncoderClass = known_encoders[encoder_name]

        encoder_conf = config['encoder']['conf']
        encoder_conf.update({'emb_layers': emb_layers})
        # NOTE: 'embeding_weights' (sic) is the key spelling the encoder expects.
        encoder_conf.update({'embeding_weights': emb_weights})
        self.encoder = EncoderClass(**encoder_conf)

        dur_config = config['duration_predictor']
        self.variance_adaptor = VarianceAdaptor(**dur_config)

        decoder_config = config['decoder']
        self.decoder = Decoder(**decoder_config)

        # Project decoder hidden states to mel-spectrogram bins.
        n_mels = config['fbank']['n_mels']
        mel_linear_input_dim = decoder_config['hidden_dim']
        self.mel_linear = nn.Linear(mel_linear_input_dim, n_mels)
        self.postnet = PostNet(**config['postnet'])

    def forward(self,
                input_seqs: List[Tensor],
                seq_len: Tensor,
                mel_len: Optional[Tensor] = None,
                d_target: Optional[Tensor] = None,
                max_src_len=None,
                max_mel_len=None,
                d_control=1.0,
                p_control=1.0,
                e_control=1.0):
        """Run a forward pass.

        Args:
            input_seqs: per-embedding input id tensors for the encoder.
            seq_len: (batch,) source sequence lengths.
            mel_len: (batch,) target mel lengths (training); None at inference.
            d_target: ground-truth durations (training); None at inference,
                in which case the variance adaptor predicts mel_len/mel_mask.
            max_src_len / max_mel_len: optional fixed mask widths.
            d_control: duration scaling factor.
            p_control / e_control: accepted for API compatibility; not used
                by the visible code path — TODO confirm against VarianceAdaptor.

        Returns:
            (mel_pred, postnet_output, d_prediction, src_mask, mel_mask, mel_len)
        """
        src_mask = get_mask_from_lengths(seq_len, max_src_len)
        if mel_len is not None:
            mel_mask = get_mask_from_lengths(mel_len, max_mel_len)
        else:
            mel_mask = None
        encoder_output = self.encoder(input_seqs, src_mask)
        if d_target is not None:
            # Teacher-forced durations: keep the caller-supplied mel lengths.
            variance_adaptor_output, d_prediction, _, _ = self.variance_adaptor(encoder_output, src_mask, mel_mask,
                                                                                d_target, max_mel_len, d_control)
        else:
            # Inference: the adaptor predicts the expanded lengths and mask.
            variance_adaptor_output, d_prediction, mel_len, mel_mask = self.variance_adaptor(
                encoder_output, src_mask, mel_mask, d_target, max_mel_len, d_control)
        decoder_output = self.decoder(variance_adaptor_output, mel_mask)
        mel_pred = self.mel_linear(decoder_output)
        # PostNet predicts a residual on top of the coarse mel prediction;
        # it takes a channel dimension, hence the unsqueeze.
        postnet_input = torch.unsqueeze(mel_pred, 1)
        postnet_output = self.postnet(postnet_input) + mel_pred
        return mel_pred, postnet_output, d_prediction, src_mask, mel_mask, mel_len
if __name__ == "__main__":
    # Smoke test: build the model from the example config and report its size.
    import yaml

    with open('../../examples/aishell3/config.yaml') as f:
        cfg = yaml.safe_load(f)
    net = FastSpeech2(cfg)
    print(net)
    total_params = sum(p.numel() for p in net.parameters())
    print(total_params)