import os

import numpy as np
from tqdm import tqdm
from loguru import logger
from hyperpyyaml import load_hyperpyyaml

from VietTTS.model import TTSModel
from VietTTS.frontend import TTSFrontEnd
from VietTTS.utils.file_utils import download_model, save_wav

class TTS:
    def __init__(self, model_dir, load_jit=False, load_onnx=False):
        # Fetch pretrained weights from the Hugging Face hub if they are not present locally.
        if not os.path.exists(model_dir):
            logger.info("Downloading model from huggingface [dangvansam/viet-tts]")
            download_model(model_dir)

        with open(f"{model_dir}/config.yaml", "r") as f:
            configs = load_hyperpyyaml(f)

        # Frontend handles text preprocessing, speaker-embedding and speech-token extraction.
        self.frontend = TTSFrontEnd(
            speech_embedding_model=f"{model_dir}/speech_embedding.onnx",
            speech_tokenizer_model=f"{model_dir}/speech_tokenizer.onnx",
        )

        # Core model components defined in the config: LLM, flow decoder, and HiFT vocoder.
        self.model = TTSModel(llm=configs["llm"], flow=configs["flow"], hift=configs["hift"])
        self.model.load(
            llm_model=f"{model_dir}/llm.pt",
            flow_model=f"{model_dir}/flow.pt",
            hift_model=f"{model_dir}/hift.pt",
        )

        if load_jit:
            self.model.load_jit(
                f"{model_dir}/llm.text_encoder.fp16.zip",
                f"{model_dir}/llm.llm.fp16.zip",
                f"{model_dir}/flow.encoder.fp32.zip",
            )
            logger.success(f"Loaded jit model from {model_dir}")

        if load_onnx:
            self.model.load_onnx(f"{model_dir}/flow.decoder.estimator.fp32.onnx")
            logger.success(f"Loaded onnx model from {model_dir}")

        logger.success(f"Loaded model from {model_dir}")
        self.model_dir = model_dir

    def list_avaliable_spks(self):
        # Return the IDs of the built-in speakers known to the frontend.
        return list(self.frontend.spk2info.keys())

    def inference_tts(self, tts_text, prompt_speech_16k, stream=False, speed=1.0):
        # Split long input text into chunks, synthesize each one, and yield the outputs.
        for text_chunk in tqdm(self.frontend.preprocess_text(tts_text, split=True)):
            model_input = self.frontend.frontend_tts(text_chunk, prompt_speech_16k)
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                yield model_output

    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
        # Voice conversion: re-render the source speech with the voice of the prompt speaker.
        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k)
        for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
            yield model_output

    def tts_to_wav(self, text, prompt_speech_16k, speed=1.0):
        # Run non-streaming synthesis and concatenate all chunks into a single waveform array.
        wavs = []
        for output in self.inference_tts(text, prompt_speech_16k, stream=False, speed=speed):
            wavs.append(output["tts_speech"].squeeze(0).numpy())
        return np.concatenate(wavs, axis=0)

    def tts_to_file(self, text, prompt_speech_16k, speed, output_path):
        # Synthesize and write the waveform to disk at the 22050 Hz output sample rate.
        wav = self.tts_to_wav(text, prompt_speech_16k, speed)
        save_wav(wav, 22050, output_path)
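

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file. The model directory name, the
    # prompt WAV path, and the use of torchaudio for loading/resampling are assumptions;
    # the class above only requires `prompt_speech_16k` to be reference audio at 16 kHz.
    import torchaudio

    tts = TTS(model_dir="pretrained-models")  # downloads dangvansam/viet-tts if missing

    # Load a short reference recording and resample it to the 16 kHz the frontend expects.
    prompt_wav, sr = torchaudio.load("prompt.wav")  # hypothetical prompt file
    prompt_speech_16k = torchaudio.functional.resample(prompt_wav, orig_freq=sr, new_freq=16000)

    tts.tts_to_file(
        text="Xin chào, đây là ví dụ tổng hợp giọng nói tiếng Việt.",
        prompt_speech_16k=prompt_speech_16k,
        speed=1.0,
        output_path="output.wav",
    )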