File size: 8,844 Bytes
dc9eaa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268

import re
import logging
import torch
import torchaudio
import random
import speechbrain
from speechbrain.inference.interfaces import Pretrained
from speechbrain.inference.text import GraphemeToPhoneme

logger = logging.getLogger(__name__)

class TTSInferencing(Pretrained):
    """
    A ready-to-use wrapper for TTS (text -> mel_spec).
    Arguments
    ---------
    hparams
        Hyperparameters (from HyperPyYAML)
    """

    HPARAMS_NEEDED = ["modules", "input_encoder"]

    MODULES_NEEDED = ["encoder_prenet", "pos_emb_enc",
                      "decoder_prenet", "pos_emb_dec",
                      "Seq2SeqTransformer", "mel_lin",
                      "stop_lin", "decoder_postnet"]


    def __init__(self, *args, **kwargs):
        """Builds the wrapper: text encoder, model modules and G2P front-end.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor.
        """
        super().__init__(*args, **kwargs)

        # Text encoder: the "@@" symbol is prepended to the lexicon from the
        # hyperparameters before the encoder is updated, then an unknown
        # label is added.
        self.input_encoder = self.hparams.input_encoder
        self.input_encoder.update_from_iterable(
            ["@@"] + self.hparams.lexicon,
            sequence_input=False,
        )
        self.input_encoder.add_unk()

        # NOTE(review): if the parent class is a torch.nn.Module, this
        # attribute shadows nn.Module.modules() -- confirm this is intended.
        self.modules = self.hparams.modules

        # Grapheme-to-phoneme model used to phonemize the input text
        self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")




    def generate_padded_phonemes(self, texts):
        """Computes mel-spectrogram for a list of texts

        Arguments
        ---------
        texts: List[str]
            texts to be converted to spectrogram

        Returns
        -------
        tensors of output spectrograms
        """

        # Preprocessing required at the inference time for the input text
        # "label" below contains input text
        # "phoneme_labels" contain the phoneme sequences corresponding to input text labels

        phoneme_labels = list()

        for label in texts:

          phoneme_label = list()

          label = self.custom_clean(label).upper()

          words = label.split()
          words = [word.strip() for word in words]
          words_phonemes = self.g2p(words)

          for i in range(len(words_phonemes)):
              words_phonemes_seq = words_phonemes[i]
              for phoneme in words_phonemes_seq:
                  if not phoneme.isspace():
                      phoneme_label.append(phoneme)
          phoneme_labels.append(phoneme_label)


        # encode the phonemes with input text encoder
        encoded_phonemes = list()
        for i in range(len(phoneme_labels)):
            phoneme_label = phoneme_labels[i]
            encoded_phoneme =  torch.LongTensor(self.input_encoder.encode_sequence(phoneme_label)).to(self.device)
            encoded_phonemes.append(encoded_phoneme)


        # Right zero-pad all one-hot text sequences to max input length
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x) for x in encoded_phonemes]), dim=0, descending=True
        )

        max_input_len = input_lengths[0]

        phoneme_padded = torch.LongTensor(len(encoded_phonemes), max_input_len).to(self.device)
        phoneme_padded.zero_()

        for seq_idx, seq in enumerate(encoded_phonemes):
            phoneme_padded[seq_idx, : len(seq)] = seq


        return phoneme_padded.to(self.device, non_blocking=True).float()


    def encode_batch(self, texts):
        """Computes mel-spectrograms for a list of texts

        The texts are converted to padded phoneme ids, embedded by the
        encoder prenet plus positional embeddings, and then decoded
        autoregressively: on every step the last predicted frame is appended
        to the decoder input. Decoding always runs for a fixed 100 steps;
        stop-token predictions are tracked but do not end the loop early.

        NOTE(review): this docstring used to require texts sorted in
        decreasing order of length; nothing in this method checks that --
        confirm whether it is still needed.

        Arguments
        ---------
        texts: List[str]
            texts to be encoded into spectrogram

        Returns
        -------
        torch.Tensor
            the accumulated decoder input with the start frame removed, i.e.
            one predicted frame per decoding step along the last dimension
        """

        # generate phonemes and pad the input texts
        # NOTE: this encoder-side part runs outside the torch.no_grad() block
        encoded_phoneme_padded = self.generate_padded_phonemes(texts)
        phoneme_prenet_emb = self.modules['encoder_prenet'](encoded_phoneme_padded)
        # Positional Embeddings
        phoneme_pos_emb =  self.modules['pos_emb_enc'](encoded_phoneme_padded)
        # Summing up embeddings (the prenet output has its last two dims swapped first)
        enc_phoneme_emb = phoneme_prenet_emb.permute(0,2,1)  + phoneme_pos_emb
        enc_phoneme_emb = enc_phoneme_emb.to(self.device)


        with torch.no_grad():

          # generate sequential predictions via transformer decoder
          # Start frame: an (80, 1) column of zeros with row 1 set to 2,
          # repeated once per batch item.
          # presumably 80 is the number of mel channels -- TODO confirm
          start_token = torch.full((80, 1), fill_value= 0)
          start_token[1] = 2
          decoder_input = start_token.repeat(enc_phoneme_emb.size(0), 1, 1)
          decoder_input = decoder_input.to(self.device, non_blocking=True).float()

          num_itr = 0
          # One flag per batch item; updated below but never read by the loop condition
          stop_condition = [False] * decoder_input.size(0)
          # Hard-coded number of decoding steps
          max_iter = 100

          # Early stopping on the stop flags is disabled; the loop always
          # runs max_iter times. The previous condition was:
          # while not all(stop_condition) and num_itr < max_iter:
          while num_itr < max_iter:

            # Decoder Prenet (applied to everything decoded so far)
            mel_prenet_emb =  self.modules['decoder_prenet'](decoder_input).to(self.device).permute(0,2,1)

            # Positional Embeddings
            mel_pos_emb =  self.modules['pos_emb_dec'](mel_prenet_emb).to(self.device)
            # Summing up Embeddings
            dec_mel_spec = mel_prenet_emb + mel_pos_emb

            # Getting the target mask to avoid looking ahead
            tgt_mask = self.hparams.lookahead_mask(dec_mel_spec).to(self.device)

            # Getting the source mask: an all-zero square matrix sized by
            # dimension 1 of the encoder embeddings
            src_mask = torch.zeros(enc_phoneme_emb.shape[1], enc_phoneme_emb.shape[1]).to(self.device)

            # Padding masks for source and targets
            # NOTE(review): both masks are computed from embeddings rather than
            # from the phoneme ids -- confirm this matches how the model was trained
            src_key_padding_mask = self.hparams.padding_mask(enc_phoneme_emb, pad_idx = self.hparams.blank_index).to(self.device)
            tgt_key_padding_mask = self.hparams.padding_mask(dec_mel_spec, pad_idx = self.hparams.blank_index).to(self.device)


            # Running the Seq2Seq Transformer
            decoder_outputs = self.modules['Seq2SeqTransformer'](src = enc_phoneme_emb, tgt = dec_mel_spec, src_mask = src_mask, tgt_mask = tgt_mask,
                                                              src_key_padding_mask = src_key_padding_mask, tgt_key_padding_mask = tgt_key_padding_mask)

            # Mel Linears
            mel_linears =  self.modules['mel_lin'](decoder_outputs).permute(0,2,1)
            mel_postnet = self.modules['decoder_postnet'](mel_linears) # postnet output, added to the linear output below
            mel_pred = mel_linears + mel_postnet # mel tensor output

            # Stop-token prediction from the same decoder outputs
            stop_token_pred =  self.modules['stop_lin'](decoder_outputs).squeeze(-1)

            stop_condition_list = self.check_stop_condition(stop_token_pred)


            # update the values of main stop conditions
            # (a flag that has become True stays True)
            stop_condition_update = [True if stop_condition_list[i] else stop_condition[i] for i in range(len(stop_condition))]
            stop_condition = stop_condition_update


            # Prepare the transformer input for the next iteration:
            # keep only the newest predicted frame and append it along dim 2
            current_output = mel_pred[:, :, -1:]

            decoder_input=torch.cat([decoder_input,current_output],dim=2)
            num_itr = num_itr+1

        # Drop the start frame so only predicted frames are returned
        mel_outputs =  decoder_input[:, :, 1:]

        return mel_outputs



    def encode_text(self, text):
        """Runs inference for a single text str"""
        return self.encode_batch([text])


    def forward(self, text_list):
        "Encodes the input texts."
        return self.encode_batch(text_list)


    def check_stop_condition(self, stop_token_pred, threshold=0.8):
        """
        Checks, for each item in the batch, whether the stop token / EOS
        criterion is met.

        An item is reported as stopped only when the sigmoid of *every* one
        of its stop-token predictions is strictly greater than ``threshold``.

        Arguments
        ---------
        stop_token_pred : torch.Tensor
            Raw (pre-sigmoid) stop-token predictions; expected to be 2-D with
            one row per batch item.
        threshold : float
            Probability each prediction must exceed (default: 0.8, the value
            that was previously hard-coded).

        Returns
        -------
        stop_output : List[bool]
            One flag per batch item.
        """

        # Applying sigmoid to perform binary classification
        stop_probabilities = torch.sigmoid(stop_token_pred)

        # Reduce over the last dimension in one tensor op instead of looping
        # over individual elements in Python
        return (stop_probabilities > threshold).all(dim=-1).tolist()



    def custom_clean(self, text):
        """
        Normalizes spacing and spells out common abbreviations.

        Runs of spaces are collapsed into a single space. Then every known
        abbreviation that is followed by a period (matched at a word boundary,
        ignoring case) is replaced by its lowercase expansion; the period is
        consumed by the replacement.

        Arguments
        ---------
        text : str
            Input text to be cleaned

        Returns
        -------
        text : str
            Cleaned text
        """

        # (abbreviation, expansion) pairs, applied in this order
        expansions = (
            ("mrs", "missus"),
            ("mr", "mister"),
            ("dr", "doctor"),
            ("st", "saint"),
            ("co", "company"),
            ("jr", "junior"),
            ("maj", "major"),
            ("gen", "general"),
            ("drs", "doctors"),
            ("rev", "reverend"),
            ("lt", "lieutenant"),
            ("hon", "honorable"),
            ("sgt", "sergeant"),
            ("capt", "captain"),
            ("esq", "esquire"),
            ("ltd", "limited"),
            ("col", "colonel"),
            ("ft", "fort"),
        )

        cleaned = re.sub(" +", " ", text)

        for abbreviation, spoken_form in expansions:
            pattern = re.compile("\\b%s\\." % abbreviation, re.IGNORECASE)
            cleaned = pattern.sub(spoken_form, cleaned)

        return cleaned