|
import os |
|
from pathlib import Path |
|
|
|
from loguru import logger |
|
|
|
from app.util import get_hparams_from_file, get_paths, time_it |
|
import requests |
|
from tqdm.auto import tqdm |
|
import re |
|
from re import Pattern |
|
import onnxruntime as ort |
|
import threading |
|
|
|
|
|
# Download location of the default ONNX model. Config.init fetches it in a
# background thread (via Config.pdownload) when no local model/config pair
# is found.
MODEL_URL = r"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJmckZWcGdCR0xxLWJmU28/root/content"

# Download location of the default config file matching the default model.
# Config.init fetches it synchronously and saves it as "config.json".
CONFIG_URL = r"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJhNEJ3enhhUHpqNE5EZWc/root/content"
|
|
|
|
|
|
|
class Config:
    """Process-wide holder for the model, its hyper-parameters and helpers.

    All state lives on the class itself (nothing is instantiated); call
    ``Config.init()`` once at start-up to populate it.

    Attributes:
        hps: Hyper-parameters returned by ``get_hparams_from_file``. The code
            reads ``hps.symbols`` and ``hps.speakers`` by attribute, so the
            loader presumably returns an attribute-style object rather than a
            plain ``dict`` -- TODO confirm and tighten the annotation.
        pattern: Compiled regex matching any single bracket character.
        speaker_choices: ``"<index>:<speaker name>"`` strings, one per entry
            of ``hps.speakers``.
        ort_sess: ONNX Runtime session for the loaded model.
        model_is_ok: ``True`` once the model has been loaded and a warm-up
            inference has completed.
    """

    hps: dict = None
    pattern: Pattern = None

    speaker_choices: list = None
    ort_sess: ort.InferenceSession = None
    model_is_ok: bool = False

    @classmethod
    def init(cls):
        """Compile the bracket pattern and load (or download) model + config.

        Looks for a model/config pair in the ``.model`` directory next to
        this file. If either is missing, the default config is downloaded
        synchronously and the default model is downloaded in a background
        thread; ``model_is_ok`` stays ``False`` until that thread finishes.

        Raises:
            requests.RequestException: If the default config cannot be
                downloaded (connection error, timeout or HTTP error status).
        """
        # NOTE(review): ')' is listed twice and '(' only once -- possibly a
        # full-width variant was intended; harmless for the regex as written.
        brackets = ['(', '[', '『', '「', '【', ")", "】", "]", "』", "」", ")"]
        cls.pattern = re.compile('|'.join(map(re.escape, brackets)))

        dir_path = Path(__file__).parent.absolute() / ".model"
        dir_path.mkdir(parents=True, exist_ok=True)
        model_path, config_path = get_paths(dir_path)

        if model_path and config_path:
            cls.setup_config(str(config_path))
            cls.setup_model(str(model_path))
            return

        # Either file missing: fall back to the default pair.
        model_path = dir_path / "model.onnx"
        config_path = dir_path / "config.json"
        logger.warning(
            "unable to find model or config, try to download default model and config"
        )

        response = requests.get(CONFIG_URL, timeout=5)
        # Fail loudly rather than saving an HTTP error page as the config.
        response.raise_for_status()
        config_path.write_bytes(response.content)
        cls.setup_config(str(config_path))

        # The model is large; fetch it off the calling thread.
        downloader = threading.Thread(
            target=cls.pdownload,
            args=(MODEL_URL, str(model_path)),
        )
        downloader.start()

    @classmethod
    @logger.catch
    @time_it
    def setup_model(cls, model_path: str):
        """Create the ONNX Runtime session and run one warm-up inference.

        Requires ``cls.hps`` to be set already (``len(cls.hps.symbols)``
        bounds the random warm-up token ids). Sets ``model_is_ok`` to
        ``True`` only after the warm-up run succeeds.

        Args:
            model_path: Filesystem path of the ``.onnx`` model file.
        """
        import numpy as np

        cls.ort_sess = ort.InferenceSession(model_path)

        # Dummy batch of 10 random symbol ids, shape (1, 10).
        seq = np.random.randint(
            low=0, high=len(cls.hps.symbols), size=(1, 10), dtype=np.int64
        )
        seq_len = np.array([seq.shape[1]], dtype=np.int64)

        # Shape (1, 3) float32 array for the model's 'scales' input.
        # NOTE(review): meaning of the three factors is defined by the
        # exported model -- confirm against the export script.
        scales = np.array([[0.667, 1.0, 0.8]], dtype=np.float32)

        sid = np.array([0], dtype=np.int64)

        ort_inputs = {
            'input': seq,
            'input_lengths': seq_len,
            'scales': scales,
            'sid': sid,
        }
        # Output is discarded; the call only exercises the session once.
        cls.ort_sess.run(None, ort_inputs)

        cls.model_is_ok = True

        logger.info(
            f"model init done with model path {model_path}"
        )

    @classmethod
    def setup_config(cls, config_path: str):
        """Load hyper-parameters and derive the speaker choice labels.

        Args:
            config_path: Filesystem path of the config file passed to
                ``get_hparams_from_file``.
        """
        cls.hps = get_hparams_from_file(config_path)
        cls.speaker_choices = [
            f"{index}:{name}" for index, name in enumerate(cls.hps.speakers)
        ]

        logger.info(
            f"config init done with config path {config_path}"
        )

    @classmethod
    def pdownload(cls, url: str, save_path: str, chunk_size: int = 8192):
        """Download ``url`` to ``save_path`` with a progress bar, then load it.

        The payload is streamed to ``save_path + ".part"`` and moved into
        place only after the download completes, so an interrupted transfer
        never leaves a truncated file at ``save_path``. On success
        ``setup_model(save_path)`` is called.

        Args:
            url: Source URL of the model file.
            save_path: Destination path of the model file.
            chunk_size: Number of bytes requested per streamed chunk.

        Raises:
            requests.RequestException: On connection errors, timeouts or an
                HTTP error status.
            OSError: If the file cannot be written or moved into place.
        """
        partial_path = save_path + ".part"

        # (connect timeout, read timeout) so a stalled server cannot hang the
        # download thread forever.
        with requests.get(url, stream=True, timeout=(5, 60)) as response:
            response.raise_for_status()

            # Take the size from the final (post-redirect) GET response; a
            # separate HEAD does not follow redirects by default. A missing
            # header yields total=None, i.e. a bar without a known total.
            file_size = int(response.headers.get("Content-Length", 0)) or None

            with tqdm(total=file_size, unit='B', unit_scale=True,
                      unit_divisor=1024, miniters=1,
                      desc="model download") as pbar, \
                    open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:
                        f.write(chunk)
                        # Advance by bytes actually received; the final chunk
                        # is usually shorter than chunk_size.
                        pbar.update(len(chunk))

        # Publish the finished file in one step.
        os.replace(partial_path, save_path)
        cls.setup_model(save_path)
|
|