import re
import logging

import torch

from speechbrain.inference.interfaces import Pretrained
from speechbrain.inference.text import GraphemeToPhoneme

logger = logging.getLogger(__name__)


class TTSInferencing(Pretrained):
    """
    A ready-to-use wrapper for TTS (text -> mel_spec).

    Arguments
    ---------
    hparams
        Hyperparameters (from HyperPyYAML)
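
    Example
    -------
    Illustrative usage only; the ``source`` path below is a placeholder,
    not the location of a released checkpoint.

    >>> tts = TTSInferencing.from_hparams(source="<path/to/pretrained/tts>")  # doctest: +SKIP
    >>> mel_output = tts.encode_text("Mary had a little lamb")  # doctest: +SKIP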
""" | |
HPARAMS_NEEDED = ["modules", "input_encoder"] | |
MODULES_NEEDED = ["encoder_prenet", "pos_emb_enc", | |
"decoder_prenet", "pos_emb_dec", | |
"Seq2SeqTransformer", "mel_lin", | |
"stop_lin", "decoder_postnet"] | |

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        lexicon = self.hparams.lexicon
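        # Prepend "@@" so it occupies index 0 in the input encoder
        # (index 0 is also the value used when zero-padding sequences)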
lexicon = ["@@"] + lexicon | |
self.input_encoder = self.hparams.input_encoder | |
self.input_encoder.update_from_iterable(lexicon, sequence_input=False) | |
self.input_encoder.add_unk() | |
self.modules = self.hparams.modules | |
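        # Grapheme-to-phoneme model used to convert raw input text into
        # phoneme sequences at inference time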
        self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")

    def generate_padded_phonemes(self, texts):
        """Converts a list of texts into encoded, zero-padded phoneme sequences

        Arguments
        ---------
        texts: List[str]
            texts to be converted to phoneme sequences

        Returns
        -------
        tensor of encoded phoneme sequences, zero-padded to the max input length
        """
        # Preprocessing required at inference time for the input text:
        # "label" below holds one input text, and "phoneme_labels" collects
        # the phoneme sequence corresponding to each input text
        phoneme_labels = list()

        for label in texts:
            phoneme_label = list()
            label = self.custom_clean(label).upper()
            words = label.split()
            words = [word.strip() for word in words]
            words_phonemes = self.g2p(words)
            for words_phonemes_seq in words_phonemes:
                for phoneme in words_phonemes_seq:
                    if not phoneme.isspace():
                        phoneme_label.append(phoneme)
            phoneme_labels.append(phoneme_label)
        # Encode the phonemes with the input text encoder
        encoded_phonemes = list()
        for phoneme_label in phoneme_labels:
            encoded_phoneme = torch.LongTensor(
                self.input_encoder.encode_sequence(phoneme_label)
            ).to(self.device)
            encoded_phonemes.append(encoded_phoneme)

        # Right zero-pad all encoded phoneme sequences to the max input length
        max_input_len = max(len(x) for x in encoded_phonemes)
        phoneme_padded = torch.LongTensor(
            len(encoded_phonemes), max_input_len
        ).to(self.device)
        phoneme_padded.zero_()
        for seq_idx, seq in enumerate(encoded_phonemes):
            phoneme_padded[seq_idx, : len(seq)] = seq

        return phoneme_padded.to(self.device, non_blocking=True).float()

    def encode_batch(self, texts):
        """Computes mel-spectrograms for a list of texts

        Arguments
        ---------
        texts: List[str]
            texts to be encoded into spectrograms

        Returns
        -------
        tensor of output spectrograms
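
        Example
        -------
        Illustrative only; assumes ``tts`` was loaded via ``from_hparams``.

        >>> mels = tts.encode_batch(["Hello world", "How are you"])  # doctest: +SKIP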
""" | |
# generate phonemes and padd the input texts | |
encoded_phoneme_padded = self.generate_padded_phonemes(texts) | |
phoneme_prenet_emb = self.modules['encoder_prenet'](encoded_phoneme_padded) | |
# Positional Embeddings | |
phoneme_pos_emb = self.modules['pos_emb_enc'](encoded_phoneme_padded) | |
# Summing up embeddings | |
enc_phoneme_emb = phoneme_prenet_emb.permute(0,2,1) + phoneme_pos_emb | |
enc_phoneme_emb = enc_phoneme_emb.to(self.device) | |
with torch.no_grad(): | |
# generate sequential predictions via transformer decoder | |
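            # Decoding starts from a single 80-channel "start" frame: an
            # all-zeros mel frame with channel 1 set to 2 as a marker,
            # repeated across the batch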
            start_token = torch.full((80, 1), fill_value=0)
            start_token[1] = 2
            decoder_input = start_token.repeat(enc_phoneme_emb.size(0), 1, 1)
            decoder_input = decoder_input.to(self.device, non_blocking=True).float()

            num_itr = 0
            stop_condition = [False] * decoder_input.size(0)
            max_iter = 100
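            # NOTE: decoding always runs for max_iter frames; the learned stop
            # condition is tracked below but does not terminate the loop early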
            # while not all(stop_condition) and num_itr < max_iter:
            while num_itr < max_iter:
                # Decoder prenet
                mel_prenet_emb = (
                    self.modules['decoder_prenet'](decoder_input)
                    .to(self.device)
                    .permute(0, 2, 1)
                )

                # Positional embeddings
                mel_pos_emb = self.modules['pos_emb_dec'](mel_prenet_emb).to(self.device)

                # Summing up embeddings
                dec_mel_spec = mel_prenet_emb + mel_pos_emb

                # Target mask to avoid looking ahead
                tgt_mask = self.hparams.lookahead_mask(dec_mel_spec).to(self.device)

                # Source mask (all zeros: no positions are masked on the encoder side)
                src_mask = torch.zeros(
                    enc_phoneme_emb.shape[1], enc_phoneme_emb.shape[1]
                ).to(self.device)

                # Padding masks for source and target
                src_key_padding_mask = self.hparams.padding_mask(
                    enc_phoneme_emb, pad_idx=self.hparams.blank_index
                ).to(self.device)
                tgt_key_padding_mask = self.hparams.padding_mask(
                    dec_mel_spec, pad_idx=self.hparams.blank_index
                ).to(self.device)

                # Running the Seq2Seq transformer
                decoder_outputs = self.modules['Seq2SeqTransformer'](
                    src=enc_phoneme_emb,
                    tgt=dec_mel_spec,
                    src_mask=src_mask,
                    tgt_mask=tgt_mask,
                    src_key_padding_mask=src_key_padding_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                )

                # Mel linears and postnet refinement
                mel_linears = self.modules['mel_lin'](decoder_outputs).permute(0, 2, 1)
                mel_postnet = self.modules['decoder_postnet'](mel_linears)  # postnet residual
                mel_pred = mel_linears + mel_postnet  # refined mel output

                stop_token_pred = self.modules['stop_lin'](decoder_outputs).squeeze(-1)
                stop_condition_list = self.check_stop_condition(stop_token_pred)

                # Update the running stop conditions (once True, stays True)
                stop_condition = [
                    prev or cur for prev, cur in zip(stop_condition, stop_condition_list)
                ]

                # Append the newest frame to the decoder input for the next iteration
                current_output = mel_pred[:, :, -1:]
                decoder_input = torch.cat([decoder_input, current_output], dim=2)
                num_itr = num_itr + 1

        # Drop the start frame and return the generated mel frames
        mel_outputs = decoder_input[:, :, 1:]
        return mel_outputs

    def encode_text(self, text):
        """Runs inference for a single text string"""
        return self.encode_batch([text])

    def forward(self, text_list):
        """Encodes the input texts."""
        return self.encode_batch(text_list)

    def check_stop_condition(self, stop_token_pred):
        """
        Checks whether the stop token / EOS has been reached for each
        mel-spectrogram in the batch
        """
        # Apply sigmoid to get per-frame stop probabilities
        sigmoid_output = torch.sigmoid(stop_token_pred)

        # A sequence is considered finished when every frame's stop
        # probability exceeds the 0.8 threshold
        stop_results = sigmoid_output > 0.8
        stop_output = [all(result) for result in stop_results]
        return stop_output

    def custom_clean(self, text):
        """
        Uses custom criteria to clean text.

        Arguments
        ---------
        text : str
            Input text to be cleaned

        Returns
        -------
        text : str
            Cleaned text
        """
        _abbreviations = [
            (re.compile(r"\b%s\." % x[0], re.IGNORECASE), x[1])
            for x in [
                ("mrs", "missus"),
                ("mr", "mister"),
                ("dr", "doctor"),
                ("st", "saint"),
                ("co", "company"),
                ("jr", "junior"),
                ("maj", "major"),
                ("gen", "general"),
                ("drs", "doctors"),
                ("rev", "reverend"),
                ("lt", "lieutenant"),
                ("hon", "honorable"),
                ("sgt", "sergeant"),
                ("capt", "captain"),
                ("esq", "esquire"),
                ("ltd", "limited"),
                ("col", "colonel"),
                ("ft", "fort"),
            ]
        ]
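        # Collapse runs of spaces, then expand the abbreviations above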
text = re.sub(" +", " ", text) | |
for regex, replacement in _abbreviations: | |
text = re.sub(regex, replacement, text) | |
return text | |
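

if __name__ == "__main__":
    # Minimal smoke-test sketch, illustrative only: the source path below is
    # a placeholder (no released checkpoint is implied), so replace it with a
    # real model directory before running.
    tts = TTSInferencing.from_hparams(source="<path/to/pretrained/tts>")
    mel = tts.encode_text("The quick brown fox jumps over the lazy dog.")
    print(mel.shape)  # expected: (batch=1, n_mels=80, n_frames)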