import os
import torch
import torchaudio
import whisper
import onnxruntime
import numpy as np
import torchaudio.compliance.kaldi as kaldi

from typing import Callable, List, Union
from functools import partial
from loguru import logger

from VietTTS.utils.frontend_utils import split_text, normalize_text, mel_spectrogram
from VietTTS.tokenizer.tokenizer import get_tokenizer


class TTSFrontEnd:
    """Frontend that turns raw text and 16 kHz prompt audio into model inputs
    for TTS and voice conversion."""

    def __init__(
        self,
        speech_embedding_model: str,
        speech_tokenizer_model: str,
    ):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = get_tokenizer()

        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1

        # The speaker-embedding model runs on CPU; the speech tokenizer uses CUDA when available.
        self.speech_embedding_session = onnxruntime.InferenceSession(
            speech_embedding_model,
            sess_options=option,
            providers=["CPUExecutionProvider"]
        )
        self.speech_tokenizer_session = onnxruntime.InferenceSession(
            speech_tokenizer_model,
            sess_options=option,
            providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"]
        )
        self.spk2info = {}

    def _extract_text_token(self, text: str):
        """Tokenize text and return (token ids, length) as int32 tensors on the target device."""
        text_token = self.tokenizer.encode(text, allowed_special='all')
        text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
        return text_token, text_token_len

    def _extract_speech_token(self, speech: torch.Tensor):
        """Convert 16 kHz speech into discrete speech tokens via the ONNX speech tokenizer."""
        # Only up to 30 s of audio is tokenized; longer prompts are truncated.
        if speech.shape[1] / 16000 > 30:
            speech = speech[:, :int(16000 * 30)]
        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
        speech_token = self.speech_tokenizer_session.run(
            None,
            {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
             self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)}
        )[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
        return speech_token, speech_token_len

    def _extract_spk_embedding(self, speech: torch.Tensor):
        """Compute a speaker embedding from 16 kHz speech using mean-normalized Kaldi fbank features."""
        feat = kaldi.fbank(
            waveform=speech,
            num_mel_bins=80,
            dither=0,
            sample_frequency=16000
        )
        feat = feat - feat.mean(dim=0, keepdim=True)
        embedding = self.speech_embedding_session.run(
            None,
            {self.speech_embedding_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()}
        )[0].flatten().tolist()
        embedding = torch.tensor([embedding]).to(self.device)
        return embedding

    def _extract_speech_feat(self, speech: torch.Tensor):
        """Compute an 80-bin mel spectrogram from 22.05 kHz speech, used as prompt acoustic features."""
        speech_feat = mel_spectrogram(
            y=speech,
            n_fft=1024,
            num_mels=80,
            sampling_rate=22050,
            hop_size=256,
            win_size=1024,
            fmin=0,
            fmax=8000,
            center=False
        ).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
        return speech_feat, speech_feat_len

    def preprocess_text(self, text, split=True) -> Union[str, List[str]]:
        """Normalize text and optionally split it into token-length-bounded chunks."""
        text = normalize_text(text)
        if split:
            text = list(split_text(
                text=text,
                tokenize=partial(self.tokenizer.encode, allowed_special='all'),
                token_max_n=30,
                token_min_n=10,
                merge_len=5,
                comma_split=False
            ))
        return text

    def frontend_tts(
        self,
        text: str,
        prompt_speech_16k: Union[np.ndarray, torch.Tensor]
    ) -> dict:
        """Build the model input dict for TTS conditioned on a 16 kHz speech prompt."""
        if isinstance(prompt_speech_16k, np.ndarray):
            prompt_speech_16k = torch.from_numpy(prompt_speech_16k)
        text_token, text_token_len = self._extract_text_token(text)
        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        model_input = {
            'text': text_token,
            'text_len': text_token_len,
            'flow_prompt_speech_token': speech_token,
            'flow_prompt_speech_token_len': speech_token_len,
            'prompt_speech_feat': speech_feat,
            'prompt_speech_feat_len': speech_feat_len,
            'llm_embedding': embedding,
            'flow_embedding': embedding
        }
        return model_input

    def frontend_vc(
        self,
        source_speech_16k: Union[np.ndarray, torch.Tensor],
        prompt_speech_16k: Union[np.ndarray, torch.Tensor]
    ) -> dict:
        """Build the model input dict for voice conversion from source speech to the prompt speaker."""
        if isinstance(source_speech_16k, np.ndarray):
            source_speech_16k = torch.from_numpy(source_speech_16k)
        if isinstance(prompt_speech_16k, np.ndarray):
            prompt_speech_16k = torch.from_numpy(prompt_speech_16k)
        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
        model_input = {
            'source_speech_token': source_speech_token,
            'source_speech_token_len': source_speech_token_len,
            'flow_prompt_speech_token': prompt_speech_token,
            'flow_prompt_speech_token_len': prompt_speech_token_len,
            'prompt_speech_feat': prompt_speech_feat,
            'prompt_speech_feat_len': prompt_speech_feat_len,
            'flow_embedding': embedding
        }
        return model_input
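

# Minimal usage sketch (illustrative only, not part of the module): the ONNX paths,
# prompt file, and sample sentence below are placeholder assumptions, not assets
# shipped with this frontend.
#
#   frontend = TTSFrontEnd(
#       speech_embedding_model='models/speech_embedding.onnx',
#       speech_tokenizer_model='models/speech_tokenizer.onnx',
#   )
#   prompt_speech_16k, sr = torchaudio.load('prompt.wav')  # assumed 16 kHz mono
#   sentences = frontend.preprocess_text('Xin chào, đây là một ví dụ.', split=True)
#   model_input = frontend.frontend_tts(sentences[0], prompt_speech_16k)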