TeraSpace committed on
Commit
9f30535
·
1 Parent(s): b21421b

Create infer_onnx.py

Browse files
Files changed (1) hide show
  1. infer_onnx.py +91 -0
infer_onnx.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import onnxruntime
4
+ import numpy as np
5
+ from huggingface_hub import snapshot_download
6
+ from gruut import sentences
7
+ import time
8
+ import scipy.io.wavfile
9
+
10
+
11
+
12
class TTS:
    """Text-to-speech front end for an ONNX model hosted on the Hugging Face Hub.

    On construction the model repository is downloaded (if its files are not
    already cached locally), an ONNX Runtime session is opened on
    ``exported/model.onnx`` and the symbol table is read from
    ``exported/vocab.txt``. Calling the instance with a text string
    phonemizes it with gruut (``lang="ru"``), maps the phonemes to ids and
    runs the model.
    """

    def __init__(self, model_name: str, save_path: str = "./model", add_time_to_end: float = 0.8) -> None:
        """Download the model if needed and load the session and vocabulary.

        Args:
            model_name: Hugging Face repository id; also used as the
                sub-directory name under ``save_path``.
            save_path: Local directory under which model files are stored.
            add_time_to_end: Stored on the instance as ``add_time_to_end``;
                nothing in this class reads it yet.
        """
        # makedirs(exist_ok=True) handles nested paths and avoids the
        # check-then-create race of exists() + mkdir().
        os.makedirs(save_path, exist_ok=True)

        model_dir = os.path.join(save_path, model_name)
        model_path = os.path.join(model_dir, "exported/model.onnx")
        vocab_path = os.path.join(model_dir, "exported/vocab.txt")

        # Test for the files themselves rather than the directory, so that an
        # interrupted earlier download is repaired instead of failing later.
        if not (os.path.exists(model_path) and os.path.exists(vocab_path)):
            snapshot_download(
                repo_id=model_name,
                allow_patterns=["*.txt", "*.onnx"],
                local_dir=model_dir,
                local_dir_use_symlinks=False,
            )

        sess_options = onnxruntime.SessionOptions()
        self.model = onnxruntime.InferenceSession(model_path, sess_options=sess_options)

        with open(vocab_path, "r", encoding="utf-8") as vocab_file:
            self.symbols = self._parse_vocab(vocab_file.read())

        # The position of a symbol in the vocabulary is its model input id.
        self.symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
        self.add_time_to_end = add_time_to_end

    @staticmethod
    def _parse_vocab(vocab_text: str) -> list[str]:
        """Turn the vocab file content into a list of one-character symbols.

        Each non-blank line holds the integer code point of one symbol.
        Blank lines (typically a trailing newline at end of file) are
        skipped instead of raising ``ValueError`` from ``int("")``.
        """
        return [chr(int(line)) for line in vocab_text.splitlines() if line.strip()]

    def _ru_phonems(self, text: str) -> str:
        """Return the phoneme string gruut produces for lower-cased ``text``.

        Phonemes of all words are concatenated without a separator, runs of
        whitespace are collapsed to a single space and the result is stripped.
        """
        parts = []
        for sent in sentences(text.lower(), lang="ru"):
            for word in sent:
                if word.phonemes:
                    parts.extend(word.phonemes)
        return re.sub(r"\s+", " ", "".join(parts)).strip()

    def _text_to_sequence(self, text: str) -> list[int]:
        """Convert ``text`` to a list of symbol ids.

        Raises:
            KeyError: If a phoneme character is not in the vocabulary.
        """
        clean_text = self._ru_phonems(text)
        return [self.symbol_to_id[symbol] for symbol in clean_text]

    def _intersperse(self, lst, item):
        """Return ``lst`` with ``item`` before, between and after its elements."""
        result = [item] * (len(lst) * 2 + 1)
        result[1::2] = lst
        return result

    def _get_text(self, text: str) -> list[int]:
        """Return the id sequence for ``text`` with id 0 interspersed."""
        text_norm = self._text_to_sequence(text)
        return self._intersperse(text_norm, 0)

    def save_wav(self, audio, path: str, sample_rate: int = 22050):
        """Write ``audio`` to ``path`` as a WAV file.

        Args:
            audio: Sample array accepted by ``scipy.io.wavfile.write``.
            path: Destination file path.
            sample_rate: Sample rate written to the file header; defaults to
                22050, the value that was previously hard-coded.
        """
        scipy.io.wavfile.write(path, sample_rate, audio)

    def __call__(self, text: str, play=False):
        """Synthesize ``text`` and return the selected model output.

        Args:
            text: Input text to speak.
            play: Accepted for interface compatibility; not used here.

        Returns:
            The element ``[0][0, 0][0]`` of the model outputs.
            NOTE(review): the exact output shape depends on the exported
            model and is not visible here - confirm this yields the full
            waveform.
        """
        phoneme_ids = self._get_text(text)
        # Add a leading batch axis of size 1.
        text = np.expand_dims(np.array(phoneme_ids, dtype=np.int64), 0)
        text_lengths = np.array([text.shape[1]], dtype=np.int64)
        # NOTE(review): presumably noise scale, length scale and noise-w as in
        # VITS/Piper-style exports - confirm against the export script.
        scales = np.array(
            [0.667, 1, 0.8],
            dtype=np.float32,
        )
        audio = self.model.run(
            None,
            {
                "input": text,
                "input_lengths": text_lengths,
                "scales": scales,
                "sid": None,
            },
        )[0][0, 0][0]

        return audio