File size: 2,905 Bytes
a257816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import numpy as np
from tqdm import tqdm
from loguru import logger
from hyperpyyaml import load_hyperpyyaml

from VietTTS.model import TTSModel
from VietTTS.frontend import TTSFrontEnd
from VietTTS.utils.file_utils import download_model, save_wav


class TTS:
    """High-level wrapper around the text frontend and the synthesis model.

    An instance owns a ``TTSFrontEnd`` (built from the speech-embedding and
    speech-tokenizer ONNX files) and a ``TTSModel`` (built from the ``llm``,
    ``flow`` and ``hift`` entries of ``config.yaml``), all loaded from one
    model directory.
    """

    # Sample rate passed to ``save_wav`` when writing synthesized audio.
    OUTPUT_SAMPLE_RATE = 22050

    def __init__(self, model_dir, load_jit: bool = False, load_onnx: bool = False):
        """Load the frontend and model weights from ``model_dir``.

        Args:
            model_dir: Directory holding ``config.yaml`` and the model files.
                If the path does not exist, ``download_model(model_dir)`` is
                called first.
            load_jit: When true, additionally load the JIT archives
                (``llm.text_encoder.fp16.zip``, ``llm.llm.fp16.zip``,
                ``flow.encoder.fp32.zip``) via ``TTSModel.load_jit``.
            load_onnx: When true, additionally load
                ``flow.decoder.estimator.fp32.onnx`` via ``TTSModel.load_onnx``.
        """
        if not os.path.exists(model_dir):
            logger.info("Downloading model from huggingface [dangvansam/viet-tts]")
            download_model(model_dir)

        # Explicit encoding: without it the config would be decoded with the
        # platform's locale encoding, which is not UTF-8 everywhere.
        with open(f"{model_dir}/config.yaml", "r", encoding="utf-8") as f:
            configs = load_hyperpyyaml(f)

        self.frontend = TTSFrontEnd(
            speech_embedding_model=f"{model_dir}/speech_embedding.onnx",
            speech_tokenizer_model=f"{model_dir}/speech_tokenizer.onnx",
        )
        self.model = TTSModel(llm=configs["llm"], flow=configs["flow"], hift=configs["hift"])
        self.model.load(
            llm_model=f"{model_dir}/llm.pt",
            flow_model=f"{model_dir}/flow.pt",
            hift_model=f"{model_dir}/hift.pt",
        )

        if load_jit:
            self.model.load_jit(
                f"{model_dir}/llm.text_encoder.fp16.zip",
                f"{model_dir}/llm.llm.fp16.zip",
                f"{model_dir}/flow.encoder.fp32.zip",
            )
            logger.success(f"Loaded jit model from {model_dir}")

        if load_onnx:
            self.model.load_onnx(f"{model_dir}/flow.decoder.estimator.fp32.onnx")
            logger.success(f"Loaded onnx model from {model_dir}")

        logger.success(f"Loaded model from {model_dir}")
        self.model_dir = model_dir

    def list_available_spks(self) -> list:
        """Return the keys of the frontend's ``spk2info`` mapping as a list."""
        return list(self.frontend.spk2info.keys())

    def list_avaliable_spks(self) -> list:
        """Misspelled alias of ``list_available_spks``, kept for existing callers."""
        return self.list_available_spks()

    def inference_tts(self, tts_text, prompt_speech_16k, stream: bool = False, speed: float = 1.0):
        """Yield model outputs for ``tts_text`` conditioned on a speech prompt.

        The text is passed through ``frontend.preprocess_text(..., split=True)``
        and each resulting item is synthesized in turn (with a ``tqdm``
        progress bar over the items).

        Args:
            tts_text: Text to synthesize.
            prompt_speech_16k: Prompt speech forwarded to
                ``frontend.frontend_tts`` (presumably 16 kHz audio, per the
                parameter name - confirm against the frontend).
            stream: Forwarded to ``model.tts``.
            speed: Forwarded to ``model.tts``.

        Yields:
            Whatever ``model.tts`` yields for each text segment.
        """
        for segment in tqdm(self.frontend.preprocess_text(tts_text, split=True)):
            model_input = self.frontend.frontend_tts(segment, prompt_speech_16k)
            yield from self.model.tts(**model_input, stream=stream, speed=speed)

    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream: bool = False, speed: float = 1.0):
        """Yield model outputs converting ``source_speech_16k`` using the prompt.

        Args:
            source_speech_16k: Source speech forwarded to ``frontend.frontend_vc``.
            prompt_speech_16k: Prompt speech forwarded to ``frontend.frontend_vc``.
            stream: Forwarded to ``model.vc``.
            speed: Forwarded to ``model.vc``.

        Yields:
            Whatever ``model.vc`` yields.
        """
        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k)
        yield from self.model.vc(**model_input, stream=stream, speed=speed)

    def tts_to_wav(self, text, prompt_speech_16k, speed: float = 1.0) -> np.ndarray:
        """Synthesize ``text`` (non-streaming) and return one concatenated array.

        Args:
            text: Text to synthesize.
            prompt_speech_16k: Prompt speech forwarded to ``inference_tts``.
            speed: Forwarded to ``inference_tts``.

        Returns:
            The ``"tts_speech"`` entries of all outputs, each squeezed on
            axis 0 and concatenated along axis 0.

        Raises:
            ValueError: If synthesis produced no audio segments.
        """
        wavs = [
            # detach()/cpu() are no-ops for plain CPU tensors, but .numpy()
            # alone raises for tensors that require grad or live on another
            # device.
            output["tts_speech"].squeeze(0).detach().cpu().numpy()
            for output in self.inference_tts(text, prompt_speech_16k, stream=False, speed=speed)
        ]
        if not wavs:
            # np.concatenate would raise ValueError here anyway; keep the
            # exception type but say what actually went wrong.
            raise ValueError("No audio was generated: the input text produced no segments to synthesize")
        return np.concatenate(wavs, axis=0)

    def tts_to_file(self, text, prompt_speech_16k, speed, output_path):
        """Synthesize ``text`` and write the result to ``output_path``.

        Args:
            text: Text to synthesize.
            prompt_speech_16k: Prompt speech forwarded to ``tts_to_wav``.
            speed: Forwarded to ``tts_to_wav``.
            output_path: Destination passed to ``save_wav``.

        Raises:
            ValueError: If synthesis produced no audio segments.
        """
        wav = self.tts_to_wav(text, prompt_speech_16k, speed)
        save_wav(wav, self.OUTPUT_SAMPLE_RATE, output_path)