File size: 4,422 Bytes
14d1720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from typing import List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from mtts.models.decoder import Decoder
from mtts.models.encoder import FS2TransformerEncoder
from mtts.models.fs2_variance import VarianceAdaptor
from mtts.models.postnet import PostNet
from mtts.utils.logging import get_logger

ENCODERS = [
    FS2TransformerEncoder,
]

logger = get_logger(__file__)


def __read_vocab(file):
    """Load a vocabulary file and return its non-empty lines.

    Args:
        file: path to a text file with one vocabulary entry per line.

    Returns:
        List of the file's non-empty lines, in order.
    """
    with open(file) as vocab_file:
        raw = vocab_file.read()
    return [entry for entry in raw.split('\n') if entry]


def _get_layer(emb_config: dict):
    """Build one embedding layer from its config section, if enabled.

    Args:
        emb_config: dict with keys 'enable', 'vocab', 'vocab_size',
            'dim' and 'weight'. When 'vocab' is a path, the vocabulary
            size is taken from the file's line count; otherwise
            'vocab_size' is used directly.

    Returns:
        ``(nn.Embedding, weight)`` when ``emb_config['enable']`` is
        truthy, otherwise ``(None, None)``.
    """
    logger.info(f'building embedding with config: {emb_config}')
    if not emb_config['enable']:
        return None, None
    vocab_path = emb_config['vocab']
    if vocab_path is None:
        num_entries = emb_config['vocab_size']
    else:
        num_entries = len(__read_vocab(vocab_path))
    # padding_idx=0 keeps index 0 as a constant zero vector (padding token).
    embedding = nn.Embedding(num_entries, emb_config['dim'], padding_idx=0)
    return embedding, emb_config['weight']


def _build_embedding_layers(config):
    """Create the embedding layers for pinyin, hanzi and speaker inputs.

    Only sections whose config has 'enable' set contribute a layer;
    disabled sections are skipped entirely.

    Returns:
        ``(nn.ModuleList of embedding layers, list of their weights)``,
        in the fixed order pinyin, hanzi, speaker.
    """
    sub_configs = (
        config['pinyin_embedding'],
        config['hanzi_embedding'],
        config['speaker_embedding'],
    )
    layers = nn.ModuleList()
    weights = []
    for sub_config in sub_configs:
        layer, weight = _get_layer(sub_config)
        if layer is None:
            continue
        layers.append(layer)
        weights.append(weight)
    return layers, weights


def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from sequence lengths.

    Args:
        lengths: 1-D integer tensor of shape ``(batch,)`` with the valid
            length of each sequence.
        max_len: width of the mask; defaults to ``lengths.max()``.

    Returns:
        BoolTensor of shape ``(batch, max_len)`` where True marks padding
        positions (``index >= length``).
    """
    if max_len is None:
        max_len = int(torch.max(lengths).item())
    # Create the position ids directly on the lengths' device instead of
    # building them on CPU and copying with .to(); broadcasting
    # (1, max_len) >= (batch, 1) replaces the explicit expand() calls.
    ids = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    return ids >= lengths.unsqueeze(1)


class FastSpeech2(nn.Module):
    """FastSpeech2 acoustic model.

    Pipeline: embedding encoder -> variance adaptor (duration) ->
    decoder -> linear mel projection -> PostNet residual refinement.
    """
    def __init__(self, config):
        super(FastSpeech2, self).__init__()

        emb_layers, emb_weights = _build_embedding_layers(config)

        # Resolve the encoder class by name from the ENCODERS whitelist
        # instead of eval(): eval() would execute arbitrary expressions
        # coming from the config file.
        encoder_classes = {cls.__name__: cls for cls in ENCODERS}
        EncoderClass = encoder_classes[config['encoder']['encoder_type']]
        encoder_conf = config['encoder']['conf']
        encoder_conf.update({'emb_layers': emb_layers})
        # NOTE: 'embeding_weights' (sic) is the key the encoder expects.
        encoder_conf.update({'embeding_weights': emb_weights})
        self.encoder = EncoderClass(**encoder_conf)

        dur_config = config['duration_predictor']
        self.variance_adaptor = VarianceAdaptor(**dur_config)

        decoder_config = config['decoder']
        self.decoder = Decoder(**decoder_config)

        n_mels = config['fbank']['n_mels']
        mel_linear_input_dim = decoder_config['hidden_dim']
        self.mel_linear = nn.Linear(mel_linear_input_dim, n_mels)
        self.postnet = PostNet(**config['postnet'], )

    def forward(self,
                input_seqs: List[Tensor],
                seq_len: Tensor,
                mel_len: Optional[Tensor] = None,
                d_target: Optional[Tensor] = None,
                max_src_len=None,
                max_mel_len=None,
                d_control=1.0,
                p_control=1.0,
                e_control=1.0):
        """Run the full text-to-mel pipeline.

        Args:
            input_seqs: input id sequences fed to the encoder's embeddings.
            seq_len: per-item source lengths, used to build the source mask.
            mel_len: per-item mel lengths (training); predicted when None.
            d_target: ground-truth durations (training); when None the
                predicted durations drive length regulation.
            max_src_len / max_mel_len: optional fixed mask widths.
            d_control: duration scaling factor passed to the adaptor.
            p_control / e_control: accepted for API compatibility;
                not used by this forward pass.

        Returns:
            (mel_pred, postnet_output, d_prediction, src_mask, mel_mask,
            mel_len).
        """
        src_mask = get_mask_from_lengths(seq_len, max_src_len)

        if mel_len is not None:
            mel_mask = get_mask_from_lengths(mel_len, max_mel_len)
        else:
            mel_mask = None

        encoder_output = self.encoder(input_seqs, src_mask)

        if d_target is not None:
            # Teacher forcing: mel_len/mel_mask come from the targets.
            variance_adaptor_output, d_prediction, _, _ = self.variance_adaptor(encoder_output, src_mask, mel_mask,
                                                                                d_target, max_mel_len, d_control)
        else:
            # Inference: the adaptor predicts mel_len and mel_mask.
            variance_adaptor_output, d_prediction, mel_len, mel_mask = self.variance_adaptor(
                encoder_output, src_mask, mel_mask, d_target, max_mel_len, d_control)

        decoder_output = self.decoder(variance_adaptor_output, mel_mask)

        mel_pred = self.mel_linear(decoder_output)
        # PostNet consumes a channel dimension and predicts a residual
        # that is added back onto the coarse mel prediction.
        postnet_input = torch.unsqueeze(mel_pred, 1)
        postnet_output = self.postnet(postnet_input) + mel_pred

        return mel_pred, postnet_output, d_prediction, src_mask, mel_mask, mel_len


if __name__ == "__main__":
    # Smoke test: build the model from the example config, then report
    # its structure and total parameter count.
    import yaml
    with open('../../examples/aishell3/config.yaml') as f:
        config = yaml.safe_load(f)
    model = FastSpeech2(config)
    print(model)
    total_params = sum(p.numel() for p in model.parameters())
    print(total_params)