|
import os
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
from loguru import logger
|
|
from hyperpyyaml import load_hyperpyyaml
|
|
|
|
from VietTTS.model import TTSModel
|
|
from VietTTS.frontend import TTSFrontEnd
|
|
from VietTTS.utils.file_utils import download_model, save_wav
|
|
|
|
|
|
class TTS:
    """High-level text-to-speech / voice-conversion wrapper around VietTTS.

    Loads the frontend (ONNX speech-embedding and speech-tokenizer models) and
    the ``TTSModel`` (``llm`` / ``flow`` / ``hift`` components configured by
    ``config.yaml``) from ``model_dir``, downloading the model files first if
    the directory does not exist.
    """

    def __init__(self, model_dir, load_jit=False, load_onnx=False):
        """Load every model component from ``model_dir``.

        Args:
            model_dir: Directory holding ``config.yaml`` and the model weight
                files. If it does not exist, ``download_model(model_dir)`` is
                called first.
            load_jit: Additionally load the TorchScript ``*.zip`` variants of
                the llm text encoder, the llm and the flow encoder.
            load_onnx: Additionally load the ONNX flow decoder estimator.
        """
        if not os.path.exists(model_dir):
            # NOTE(review): only a *missing* directory triggers a download; a
            # partially populated directory is used as-is.
            logger.info("Downloading model from huggingface [dangvansam/viet-tts]")
            download_model(model_dir)

        # load_hyperpyyaml can import and instantiate arbitrary Python objects
        # named by YAML tags, so only load configs from a trusted source.
        with open(os.path.join(model_dir, "config.yaml"), "r", encoding="utf-8") as f:
            configs = load_hyperpyyaml(f)

        self.frontend = TTSFrontEnd(
            speech_embedding_model=os.path.join(model_dir, "speech_embedding.onnx"),
            speech_tokenizer_model=os.path.join(model_dir, "speech_tokenizer.onnx"),
        )
        self.model = TTSModel(llm=configs["llm"], flow=configs["flow"], hift=configs["hift"])
        self.model.load(
            llm_model=os.path.join(model_dir, "llm.pt"),
            flow_model=os.path.join(model_dir, "flow.pt"),
            hift_model=os.path.join(model_dir, "hift.pt"),
        )

        if load_jit:
            self.model.load_jit(
                os.path.join(model_dir, "llm.text_encoder.fp16.zip"),
                os.path.join(model_dir, "llm.llm.fp16.zip"),
                os.path.join(model_dir, "flow.encoder.fp32.zip"),
            )
            logger.success(f"Loaded jit model from {model_dir}")

        if load_onnx:
            self.model.load_onnx(os.path.join(model_dir, "flow.decoder.estimator.fp32.onnx"))
            logger.success(f"Loaded onnx model from {model_dir}")

        logger.success(f"Loaded model from {model_dir}")
        self.model_dir = model_dir

    def list_avaliable_spks(self):
        """Return the speaker ids known to the frontend (the keys of ``frontend.spk2info``)."""
        return list(self.frontend.spk2info.keys())

    # Correctly spelled alias; the misspelled name above is kept for existing callers.
    list_available_spks = list_avaliable_spks

    def inference_tts(self, tts_text, prompt_speech_16k, stream=False, speed=1.0):
        """Synthesize ``tts_text`` conditioned on the prompt speech.

        The text is split into segments by ``frontend.preprocess_text(split=True)``
        and each segment is fed through ``frontend.frontend_tts`` and
        ``model.tts`` in order (progress is shown with ``tqdm``).

        Args:
            tts_text: Text to synthesize.
            prompt_speech_16k: Reference speech for the target voice
                (presumably 16 kHz audio, per the name — confirm with the frontend).
            stream: Forwarded to ``model.tts``.
            speed: Forwarded to ``model.tts``.

        Yields:
            Every output produced by ``model.tts`` for every segment, in order.
        """
        for segment in tqdm(self.frontend.preprocess_text(tts_text, split=True)):
            model_input = self.frontend.frontend_tts(segment, prompt_speech_16k)
            yield from self.model.tts(**model_input, stream=stream, speed=speed)

    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
        """Convert ``source_speech_16k`` into the voice of ``prompt_speech_16k``.

        Yields:
            Every output produced by ``model.vc`` for the prepared input.
        """
        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k)
        yield from self.model.vc(**model_input, stream=stream, speed=speed)

    def tts_to_wav(self, text, prompt_speech_16k, speed=1.0):
        """Synthesize ``text`` (non-streaming) and return a single waveform.

        Each output's ``"tts_speech"`` entry has its leading axis squeezed and is
        converted to numpy; the pieces are concatenated along axis 0
        (presumably each piece is shaped (1, T) — TODO confirm).

        Raises:
            ValueError: if synthesis produced no audio at all, e.g. when text
                preprocessing yields no segments.
        """
        chunks = [
            output["tts_speech"].squeeze(0).numpy()
            for output in self.inference_tts(text, prompt_speech_16k, stream=False, speed=speed)
        ]
        if not chunks:
            # np.concatenate([]) would raise an opaque "need at least one array" error.
            raise ValueError("No audio was generated: the input text produced no speakable segments.")
        return np.concatenate(chunks, axis=0)

    def tts_to_file(self, text, prompt_speech_16k, speed, output_path, sample_rate=22050):
        """Synthesize ``text`` and write the waveform to ``output_path`` via ``save_wav``.

        Args:
            sample_rate: Sample rate passed to ``save_wav``. Defaults to 22050,
                the previously hard-coded value. NOTE(review): it should match
                the vocoder's output rate — confirm against ``config.yaml``.
        """
        wav = self.tts_to_wav(text, prompt_speech_16k, speed)
        save_wav(wav, sample_rate, output_path)
|
|
|