TTSDemoApp / TTSInferencing.py
myhanhhyugen's picture
initial commits
dc9eaa3 verified
import re
import logging
import torch
import torchaudio
import random
import speechbrain
from speechbrain.inference.interfaces import Pretrained
from speechbrain.inference.text import GraphemeToPhoneme
# Module-level logger for this file (not referenced by the code shown here).
logger = logging.getLogger(__name__)
class TTSInferencing(Pretrained):
    """
    A ready-to-use wrapper for TTS (text -> mel_spec).

    Input text is cleaned (``custom_clean``), converted to phonemes with the
    SoundChoice G2P model, encoded with ``input_encoder`` and decoded
    autoregressively into a mel-spectrogram by a Seq2Seq Transformer.

    Arguments
    ---------
    hparams
        Hyperparameters (from HyperPyYAML). Besides ``HPARAMS_NEEDED``, this
        class reads ``lexicon``, ``lookahead_mask``, ``padding_mask`` and
        ``blank_index`` from it.
    """

    HPARAMS_NEEDED = ["modules", "input_encoder"]
    MODULES_NEEDED = ["encoder_prenet", "pos_emb_enc",
                      "decoder_prenet", "pos_emb_dec",
                      "Seq2SeqTransformer", "mel_lin",
                      "stop_lin", "decoder_postnet"]

    # (compiled pattern, expansion) pairs used by custom_clean. Compiled once
    # at class creation instead of on every call. Each pattern matches the
    # abbreviation at a word boundary followed by a literal period,
    # case-insensitively.
    _ABBREVIATIONS = [
        (re.compile(r"\b%s\." % abbreviation, re.IGNORECASE), expansion)
        for abbreviation, expansion in [
            ("mrs", "missus"),
            ("mr", "mister"),
            ("dr", "doctor"),
            ("st", "saint"),
            ("co", "company"),
            ("jr", "junior"),
            ("maj", "major"),
            ("gen", "general"),
            ("drs", "doctors"),
            ("rev", "reverend"),
            ("lt", "lieutenant"),
            ("hon", "honorable"),
            ("sgt", "sergeant"),
            ("capt", "captain"),
            ("esq", "esquire"),
            ("ltd", "limited"),
            ("col", "colonel"),
            ("ft", "fort"),
        ]
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # "@@" is registered in the encoder before the lexicon's phonemes;
        # its role is not visible here (possibly a reserved symbol) -- TODO
        # confirm against the training recipe.
        lexicon = ["@@"] + self.hparams.lexicon
        self.input_encoder = self.hparams.input_encoder
        self.input_encoder.update_from_iterable(lexicon, sequence_input=False)
        self.input_encoder.add_unk()
        # NOTE(review): if Pretrained is a torch.nn.Module, this attribute
        # shadows nn.Module.modules(); kept because encode_batch (and maybe
        # external callers) index it by module name.
        self.modules = self.hparams.modules
        # Pretrained grapheme-to-phoneme model loaded from the given source.
        self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")

    def _text_to_phonemes(self, text):
        """Returns the flat list of non-whitespace phonemes for one text."""
        words = self.custom_clean(text).upper().split()
        if not words:
            # Nothing to pronounce: skip the G2P call entirely.
            return []
        # For a list input, g2p yields one phoneme sequence per word.
        return [
            phoneme
            for word_phonemes in self.g2p(words)
            for phoneme in word_phonemes
            if not phoneme.isspace()
        ]

    def generate_padded_phonemes(self, texts):
        """Converts a batch of texts into a right-zero-padded phoneme tensor.

        Each text is cleaned with ``custom_clean``, upper-cased, split into
        words, converted to phonemes with the G2P model (whitespace tokens
        are dropped) and encoded with ``input_encoder``.

        Arguments
        ---------
        texts: List[str]
            texts to be converted to phoneme indices

        Returns
        -------
        torch.Tensor
            float tensor of shape (len(texts), longest_phoneme_sequence) on
            ``self.device`` holding the encoded phoneme indices, zero-padded
            on the right. Row order matches ``texts`` (no sorting is done).

        Raises
        ------
        ValueError
            If ``texts`` is empty.
        """
        if not texts:
            raise ValueError("texts must contain at least one string")
        encoded_phonemes = [
            torch.tensor(
                self.input_encoder.encode_sequence(self._text_to_phonemes(text)),
                dtype=torch.long,
                device=self.device,
            )
            for text in texts
        ]
        # Right zero-pad every sequence to the longest one in the batch.
        phoneme_padded = torch.nn.utils.rnn.pad_sequence(
            encoded_phonemes, batch_first=True, padding_value=0
        )
        return phoneme_padded.float()

    def encode_batch(self, texts, max_iter=100, stop_on_eos=False):
        """Computes mel-spectrograms for a list of texts.

        Decoding is autoregressive. Starting from a fixed start frame, each
        step runs the Seq2Seq Transformer on all frames generated so far and
        appends the last predicted frame to the decoder input.

        Arguments
        ---------
        texts: List[str]
            texts to be encoded into spectrograms (any order; no sorting is
            performed or required)
        max_iter: int
            maximum number of decoding steps, i.e. the number of output
            frames when ``stop_on_eos`` is False (default 100)
        stop_on_eos: bool
            if True, stop decoding early once ``check_stop_condition`` has
            flagged every item of the batch. Off by default, which always
            runs ``max_iter`` steps.

        Returns
        -------
        torch.Tensor
            generated mel frames without the start frame, shape
            (batch, 80, num_steps)
        """
        # Mel channels in the decoder start frame; presumably must match the
        # decoder_prenet input size -- TODO confirm against the hparams.
        n_mels = 80
        # Pure inference: keep the encoder side out of autograd as well.
        with torch.no_grad():
            # Generate phonemes and pad the input texts.
            encoded_phoneme_padded = self.generate_padded_phonemes(texts)
            phoneme_prenet_emb = self.modules['encoder_prenet'](encoded_phoneme_padded)
            phoneme_pos_emb = self.modules['pos_emb_enc'](encoded_phoneme_padded)
            # Sum prenet and positional embeddings.
            enc_phoneme_emb = (phoneme_prenet_emb.permute(0, 2, 1) + phoneme_pos_emb).to(self.device)

            # Source-side masks depend only on enc_phoneme_emb: build once.
            src_len = enc_phoneme_emb.shape[1]
            src_mask = torch.zeros(src_len, src_len, device=self.device)
            src_key_padding_mask = self.hparams.padding_mask(
                enc_phoneme_emb, pad_idx=self.hparams.blank_index
            ).to(self.device)

            # Start frame: zeros except channel 1, which is set to 2
            # (presumably the start-of-sequence convention used in training).
            start_token = torch.zeros(n_mels, 1, device=self.device)
            start_token[1] = 2
            decoder_input = start_token.repeat(enc_phoneme_emb.size(0), 1, 1)

            stop_condition = [False] * decoder_input.size(0)
            for _ in range(max_iter):
                if stop_on_eos and all(stop_condition):
                    break
                # Decoder prenet + positional embeddings.
                mel_prenet_emb = self.modules['decoder_prenet'](decoder_input).to(self.device).permute(0, 2, 1)
                mel_pos_emb = self.modules['pos_emb_dec'](mel_prenet_emb).to(self.device)
                dec_mel_spec = mel_prenet_emb + mel_pos_emb
                # Target masks: no look-ahead, plus padding.
                tgt_mask = self.hparams.lookahead_mask(dec_mel_spec).to(self.device)
                tgt_key_padding_mask = self.hparams.padding_mask(
                    dec_mel_spec, pad_idx=self.hparams.blank_index
                ).to(self.device)
                decoder_outputs = self.modules['Seq2SeqTransformer'](
                    src=enc_phoneme_emb,
                    tgt=dec_mel_spec,
                    src_mask=src_mask,
                    tgt_mask=tgt_mask,
                    src_key_padding_mask=src_key_padding_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                )
                # Mel linear projection refined by the postnet (residual).
                mel_linears = self.modules['mel_lin'](decoder_outputs).permute(0, 2, 1)
                mel_postnet = self.modules['decoder_postnet'](mel_linears)
                mel_pred = mel_linears + mel_postnet
                if stop_on_eos:
                    # Stop-token tracking is only needed for early stopping.
                    stop_token_pred = self.modules['stop_lin'](decoder_outputs).squeeze(-1)
                    stop_condition = [
                        already_stopped or stops_now
                        for already_stopped, stops_now in zip(
                            stop_condition, self.check_stop_condition(stop_token_pred)
                        )
                    ]
                # Feed the newest predicted frame back in for the next step.
                decoder_input = torch.cat([decoder_input, mel_pred[:, :, -1:]], dim=2)
        # Drop the artificial start frame.
        return decoder_input[:, :, 1:]

    def encode_text(self, text):
        """Runs inference for a single text str (batch of size one)."""
        return self.encode_batch([text])

    def forward(self, text_list):
        """Encodes the input texts; equivalent to ``encode_batch(text_list)``."""
        return self.encode_batch(text_list)

    def check_stop_condition(self, stop_token_pred, threshold=0.8):
        """
        Checks, per batch item, whether the stop token / EOS has been reached.

        Arguments
        ---------
        stop_token_pred: torch.Tensor
            stop-token logits, presumably of shape (batch, num_frames)
        threshold: float
            probability above which a frame counts as "stop" (default 0.8)

        Returns
        -------
        List[bool]
            one flag per batch item; True only if EVERY frame's stop
            probability exceeds ``threshold``.
        """
        # NOTE(review): this requires all frames to exceed the threshold, not
        # just the last one; confirm this is the intended EOS criterion.
        stop_results = torch.sigmoid(stop_token_pred) > threshold
        return stop_results.all(dim=-1).tolist()

    def custom_clean(self, text):
        """
        Uses custom criteria to clean text.

        Runs of spaces are collapsed into one space. Then common English
        abbreviations followed by a period are expanded case-insensitively
        (e.g. "Dr." -> "doctor").

        Arguments
        ---------
        text : str
            Input text to be cleaned

        Returns
        -------
        text : str
            Cleaned text
        """
        text = re.sub(" +", " ", text)
        for regex, replacement in self._ABBREVIATIONS:
            text = regex.sub(replacement, text)
        return text