import os
from pathlib import Path
from loguru import logger
# from app import CONFIG_URL, MODEL_URL
from app.util import get_hparams_from_file, get_paths, time_it
import requests
from tqdm.auto import tqdm
import re
from re import Pattern
import onnxruntime as ort
import threading

# Default model / config used when nothing is found in the local ".model" directory.
MODEL_URL = r"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJmckZWcGdCR0xxLWJmU28/root/content"
CONFIG_URL = r"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJhNEJ3enhhUHpqNE5EZWc/root/content"


class Config:
    """Process-wide holder for the hyper-parameters and the ONNX inference session.

    All state lives on the class itself; call ``Config.init()`` once at start-up.
    ``model_is_ok`` stays ``False`` until ``setup_model`` has completed a warm-up
    run, which may happen later on a background thread when the model has to be
    downloaded first.
    """

    # NOTE(review): annotated as dict, but accessed through attributes
    # (``.symbols``, ``.speakers``) below - it is whatever
    # ``get_hparams_from_file`` returns.
    hps: dict = None
    # Regex matching any single bracket character from the list built in ``init``.
    pattern: Pattern = None
    # Entries of the form "<index>:<speaker name>".
    speaker_choices: list = None
    ort_sess: ort.InferenceSession = None
    # Set to True only after the warm-up inference in ``setup_model`` succeeded.
    model_is_ok: bool = False

    @classmethod
    def init(cls):
        """Locate (or download) the model and config, then initialise the class state.

        If ``get_paths`` does not report both a model and a config inside the
        ``.model`` directory next to this file, the default config is downloaded
        synchronously and the default model is downloaded on a background
        thread, which calls ``setup_model`` when it finishes.

        Raises:
            requests.RequestException: if the default config cannot be fetched
                (including an HTTP error status).
        """
        # NOTE(review): ")" appears twice in this list - possibly a full-width
        # bracket was lost somewhere; the duplicate is harmless for the regex.
        brackets = ['(', '[', '『', '「', '【', ")", "】", "]", "』", "」", ")"]
        cls.pattern = re.compile('|'.join(map(re.escape, brackets)))

        dir_path = Path(__file__).parent.absolute() / ".model"
        dir_path.mkdir(parents=True, exist_ok=True)
        model_path, config_path = get_paths(dir_path)

        if model_path and config_path:
            cls.setup_config(str(config_path))
            cls.setup_model(str(model_path))
            return

        model_path = dir_path / "model.onnx"
        config_path = dir_path / "config.json"
        logger.warning(
            "unable to find model or config, try to download default model and config"
        )

        response = requests.get(CONFIG_URL, timeout=5)
        # Fail loudly instead of saving an HTTP error page as config.json.
        response.raise_for_status()
        with open(str(config_path), 'wb') as f:
            f.write(response.content)
        cls.setup_config(str(config_path))

        # The model is large: fetch it in the background so start-up is not blocked.
        t = threading.Thread(
            target=cls.pdownload, args=(MODEL_URL, str(model_path))
        )
        t.start()

    @classmethod
    @logger.catch
    @time_it
    def setup_model(cls, model_path: str):
        """Create the ONNX Runtime session and run one warm-up inference.

        ``cls.hps`` must already be set (see ``setup_config``) because the size
        of the symbol table bounds the random warm-up input. On success
        ``cls.model_is_ok`` becomes ``True``.

        Args:
            model_path: path of the ``.onnx`` model file to load.
        """
        import numpy as np

        cls.ort_sess = ort.InferenceSession(model_path)

        # Warm-up input: 10 random symbol ids, batch size 1.
        seq = np.random.randint(
            low=0, high=len(cls.hps.symbols), size=(1, 10), dtype=np.int64
        )
        seq_len = np.array([seq.shape[1]], dtype=np.int64)

        # scales = [noise, length, noisew]:
        #   noise  - how much emotion etc. varies
        #   length - overall speaking speed
        #   noisew - how much phoneme duration varies
        # see https://github.com/gbxh/genshinTTS
        # Shape (1, 3): the leading batch axis keeps triton dynamic shapes happy.
        scales = np.array([[0.667, 1.0, 0.8]], dtype=np.float32)
        sid = np.array([0], dtype=np.int64)

        ort_inputs = {
            'input': seq,
            'input_lengths': seq_len,
            'scales': scales,
            'sid': sid
        }
        cls.ort_sess.run(None, ort_inputs)

        cls.model_is_ok = True

        logger.info(
            f"model init done with model path {model_path}"
        )

    @classmethod
    def setup_config(cls, config_path: str):
        """Load hyper-parameters from ``config_path`` and build the speaker list.

        Args:
            config_path: path of the JSON config passed to ``get_hparams_from_file``.
        """
        cls.hps = get_hparams_from_file(config_path)
        cls.speaker_choices = [
            f"{index}:{speaker}" for index, speaker in enumerate(cls.hps.speakers)
        ]

        logger.info(
            f"config init done with config path {config_path}"
        )

    @classmethod
    def pdownload(cls, url: str, save_path: str, chunk_size: int = 8192):
        """Download ``url`` to ``save_path`` with a progress bar, then load the model.

        The data is streamed into ``save_path + ".part"`` and moved into place
        only after the transfer completed, so an interrupted download never
        leaves a truncated model file at ``save_path``.

        Args:
            url: address of the model file.
            save_path: final location of the downloaded file.
            chunk_size: number of bytes requested per streamed chunk.

        Raises:
            requests.RequestException: on connection problems, timeouts or an
                HTTP error status.
            OSError: if the file cannot be written or moved into place.
        """
        # adapted from https://github.com/tqdm/tqdm/blob/master/examples/tqdm_requests.py
        part_path = save_path + ".part"
        try:
            # timeout=(connect, read): the read timeout applies per socket read,
            # not to the whole transfer, so large files are still fine.
            with requests.get(url, stream=True, timeout=(5, 60)) as response:
                response.raise_for_status()

                # Read the size from the GET response itself: requests.head()
                # does not follow redirects by default, so for a redirecting
                # share link its Content-Length is missing or wrong.
                content_length = response.headers.get("Content-Length", "")
                file_size = int(content_length) if content_length.isdigit() else None

                with tqdm(total=file_size, unit='B', unit_scale=True,
                          unit_divisor=1024, miniters=1,
                          desc="model download") as pbar:
                    with open(part_path, 'wb') as f:
                        for chunk in response.iter_content(chunk_size=chunk_size):
                            if chunk:
                                f.write(chunk)
                                # Chunks may be shorter than chunk_size.
                                pbar.update(len(chunk))
            os.replace(part_path, save_path)
        finally:
            # Only present when the download failed; after a successful
            # os.replace there is nothing left to remove.
            if os.path.exists(part_path):
                os.remove(part_path)

        cls.setup_model(save_path)