wetdog commited on
Commit
35cd28c
·
1 Parent(s): daa90f5

fix config path

Browse files
Files changed (1) hide show
  1. infer_onnx.py +3 -3
infer_onnx.py CHANGED
@@ -30,14 +30,14 @@ def process_text(i: int, text: str, device: torch.device):
30
  MODEL_PATH_MATCHA_MEL="matcha_multispeaker_cat_opset_15.onnx"
31
  MODEL_PATH_MATCHA="matcha_hifigan_multispeaker_cat.onnx"
32
  MODEL_PATH_VOCOS="mel_spec_22khz.onnx"
33
- CONFIG_PATH="/home/jgiraldo/projects/tts-onnx-comparison/config_22khz.yaml"
34
 
35
  sess_options = onnxruntime.SessionOptions()
36
  model_matcha_mel= onnxruntime.InferenceSession(str(MODEL_PATH_MATCHA_MEL), sess_options=sess_options, providers=["CPUExecutionProvider"])
37
  model_vocos = onnxruntime.InferenceSession(str(MODEL_PATH_VOCOS), sess_options=sess_options, providers=["CPUExecutionProvider"])
38
  model_matcha = onnxruntime.InferenceSession(str(MODEL_PATH_MATCHA), sess_options=sess_options, providers=["CPUExecutionProvider"])
39
 
40
- def vocos_inference(mel: torch.Tensor, config):
41
 
42
  with open(CONFIG_PATH, "r") as f:
43
  config = yaml.safe_load(f)
@@ -102,7 +102,7 @@ def tts(text:str, spk_id:int):
102
 
103
  mel, mel_lengths = model_matcha_mel.run(None, inputs)
104
  # vocos inference
105
- wavs_vocos = vocos_inference(mel, CONFIG_PATH)
106
 
107
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp_matcha_vocos:
108
  sf.write(fp_matcha_vocos.name, wavs_vocos.squeeze(0), 22050, "PCM_24")
 
30
  MODEL_PATH_MATCHA_MEL="matcha_multispeaker_cat_opset_15.onnx"
31
  MODEL_PATH_MATCHA="matcha_hifigan_multispeaker_cat.onnx"
32
  MODEL_PATH_VOCOS="mel_spec_22khz.onnx"
33
+ CONFIG_PATH="config_22khz.yaml"
34
 
35
  sess_options = onnxruntime.SessionOptions()
36
  model_matcha_mel= onnxruntime.InferenceSession(str(MODEL_PATH_MATCHA_MEL), sess_options=sess_options, providers=["CPUExecutionProvider"])
37
  model_vocos = onnxruntime.InferenceSession(str(MODEL_PATH_VOCOS), sess_options=sess_options, providers=["CPUExecutionProvider"])
38
  model_matcha = onnxruntime.InferenceSession(str(MODEL_PATH_MATCHA), sess_options=sess_options, providers=["CPUExecutionProvider"])
39
 
40
+ def vocos_inference(mel: torch.Tensor):
41
 
42
  with open(CONFIG_PATH, "r") as f:
43
  config = yaml.safe_load(f)
 
102
 
103
  mel, mel_lengths = model_matcha_mel.run(None, inputs)
104
  # vocos inference
105
+ wavs_vocos = vocos_inference(mel)
106
 
107
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp_matcha_vocos:
108
  sf.write(fp_matcha_vocos.name, wavs_vocos.squeeze(0), 22050, "PCM_24")