|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
from text import text_to_sequence
|
|
import numpy as np
|
|
from scipy.io import wavfile
|
|
import torch
|
|
import json
|
|
import commons
|
|
import utils
|
|
import sys
|
|
import pathlib
|
|
|
|
try:
|
|
import onnxruntime as ort
|
|
except ImportError:
|
|
print('Please install onnxruntime!')
|
|
sys.exit(1)
|
|
|
|
|
|
def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a torch tensor to a NumPy array.

    The tensor is detached from the autograd graph and moved to host memory
    before conversion, so the call works whether or not the tensor requires
    gradients and regardless of the device it lives on.

    Args:
        tensor: Tensor to convert.

    Returns:
        NumPy array holding the tensor's data.
    """
    # ``detach()`` is a no-op for tensors without grad and ``cpu()`` is a
    # no-op for tensors already on the CPU, so one path covers every case.
    # The previous non-grad branch skipped ``cpu()`` and therefore failed
    # for tensors stored on an accelerator.
    return tensor.detach().cpu().numpy()
|
|
|
|
|
|
def get_args():
    """Parse the command-line options of the inference script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``, carrying the
        ``onnx_model``, ``cfg``, ``outdir`` and ``test_file`` attributes.
    """
    cli = argparse.ArgumentParser(description='inference')

    # (flag, add_argument keyword arguments), kept in registration order so
    # the generated --help output lists them in the same sequence.
    options = (
        ('--onnx_model', dict(required=True, help='onnx model')),
        ('--cfg', dict(required=True, help='config file')),
        ('--outdir', dict(default="onnx_output", help='ouput directory')),
        ('--test_file', dict(required=True, help='test file')),
    )
    for flag, settings in options:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
|
|
|
|
|
|
def get_symbols_from_json(path):
    """Load the value stored under the ``symbols`` key of a JSON file.

    Args:
        path: Path of the JSON file to read.

    Returns:
        Whatever the JSON document holds under its top-level ``"symbols"``
        key.

    Raises:
        FileNotFoundError: If ``path`` is not an existing regular file.
        KeyError: If the JSON document has no ``"symbols"`` key.
    """
    # Raise explicitly instead of using ``assert``: assertions are stripped
    # when Python runs with ``-O``, which would silently drop the check.
    if not pathlib.Path(path).is_file():
        raise FileNotFoundError("config file not found: {!r}".format(path))

    # Decode as UTF-8 explicitly so non-ASCII symbols do not depend on the
    # platform's default text encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)['symbols']
|
|
|
|
|
|
def main():
    """Synthesize one wav file per line of the test file with an ONNX model.

    Each non-empty line of ``--test_file`` is split on ``|``: the first field
    names the output wav (only its final path component is used) and the
    second field is the text to synthesize. The text is converted to a symbol
    id sequence, run through the ONNX session, and the returned waveform is
    scaled so its peak sits at 60% of the int16 range before being written as
    16-bit samples into ``--outdir``.

    Raises:
        ValueError: If a non-empty line of the test file has fewer than two
            ``|``-separated fields.
    """
    # ``time`` is not imported at module level; import it once here instead
    # of on every loop iteration.
    import time

    args = get_args()
    print(args)

    # ``exist_ok=True`` already tolerates an existing directory, so no
    # separate existence check is needed.
    out_dir = pathlib.Path(args.outdir)
    out_dir.mkdir(exist_ok=True, parents=True)

    hps = utils.get_hparams_from_file(args.cfg)
    ort_sess = ort.InferenceSession(args.onnx_model)

    # Loop-invariant model inputs, built once.
    # NOTE(review): the speaker id 8 is hard-coded — presumably it matches
    # the exported model; confirm before reusing with another checkpoint.
    sid = np.array([8], dtype=np.int64)
    # NOTE(review): the three scale values are presumably noise scale,
    # length scale and noise-w scale — confirm against the export script.
    scales = np.array([[0.667, 0.8, 1]], dtype=np.float32)

    # Read as UTF-8 explicitly: the text column is non-ASCII for the
    # Japanese cleaner used below and must not depend on the locale default.
    with open(args.test_file, encoding="utf-8") as fin:
        for line_no, line in enumerate(fin, start=1):
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no utterance.
                continue

            arr = line.split("|")
            if len(arr) < 2:
                raise ValueError(
                    "{}:{}: expected 'audio_path|text', got {!r}".format(
                        args.test_file, line_no, line))
            audio_path = arr[0]
            text = arr[1]

            seq = text_to_sequence(text, symbols=hps.symbols,
                                   cleaner_names=["japanese_cleaners2"])
            if hps.data.add_blank:
                seq = commons.intersperse(seq, 0)

            # Batch of one: ids shaped (1, len(seq)) plus that length.
            x = np.array([seq], dtype=np.int64)
            x_len = np.array([x.shape[1]], dtype=np.int64)
            ort_inputs = {
                'input': x,
                'input_lengths': x_len,
                'scales': scales,
                'sid': sid,
            }

            start_time = time.perf_counter()
            audio = np.squeeze(ort_sess.run(None, ort_inputs))
            # Scale so the peak lands at 60% of int16 full scale; the 0.01
            # floor avoids blowing up (or dividing by zero on) near-silence.
            audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
            audio = np.clip(audio, -32767.0, 32767.0)
            end_time = time.perf_counter()

            print("infer time cost: ", end_time - start_time, "s")

            # Keep only the file name of the listed path so every output
            # lands directly inside the output directory.
            out_path = out_dir / pathlib.Path(audio_path).name
            wavfile.write(str(out_path), hps.data.sampling_rate,
                          audio.astype(np.int16))
|
|
|
|
|
|
# Run the inference pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|