Spaces:
Running
Running
Update text/g2pw/onnx_api.py
Browse files- text/g2pw/onnx_api.py +12 -16
text/g2pw/onnx_api.py
CHANGED
@@ -81,26 +81,22 @@ class G2PWOnnxConverter:
|
|
81 |
model_source: str=None,
|
82 |
enable_non_tradional_chinese: bool=False):
|
83 |
uncompress_path = download_and_decompress(model_dir)
|
84 |
-
|
85 |
sess_options = onnxruntime.SessionOptions()
|
86 |
-
print(":::2")
|
87 |
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
88 |
-
print(":::3")
|
89 |
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
90 |
-
print(":::4")
|
91 |
sess_options.intra_op_num_threads = 2
|
92 |
-
|
93 |
-
|
94 |
-
print(":::6")
|
95 |
self.config = load_config(
|
96 |
config_path=os.path.join(uncompress_path, 'config.py'),
|
97 |
use_default=True)
|
98 |
|
99 |
self.model_source = model_source if model_source else self.config.model_source
|
100 |
self.enable_opencc = enable_non_tradional_chinese
|
101 |
-
|
102 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
|
103 |
-
|
104 |
polyphonic_chars_path = os.path.join(uncompress_path,
|
105 |
'POLYPHONIC_CHARS.txt')
|
106 |
monophonic_chars_path = os.path.join(uncompress_path,
|
@@ -124,14 +120,14 @@ class G2PWOnnxConverter:
|
|
124 |
polyphonic_chars=self.polyphonic_chars
|
125 |
) if self.config.use_char_phoneme else get_phoneme_labels(
|
126 |
polyphonic_chars=self.polyphonic_chars)
|
127 |
-
|
128 |
self.chars = sorted(list(self.char2phonemes.keys()))
|
129 |
|
130 |
self.polyphonic_chars_new = set(self.chars)
|
131 |
for char in self.non_polyphonic:
|
132 |
if char in self.polyphonic_chars_new:
|
133 |
self.polyphonic_chars_new.remove(char)
|
134 |
-
|
135 |
self.monophonic_chars_dict = {
|
136 |
char: phoneme
|
137 |
for char, phoneme in self.monophonic_chars
|
@@ -139,11 +135,11 @@ class G2PWOnnxConverter:
|
|
139 |
for char in self.non_monophonic:
|
140 |
if char in self.monophonic_chars_dict:
|
141 |
self.monophonic_chars_dict.pop(char)
|
142 |
-
|
143 |
self.pos_tags = [
|
144 |
'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
|
145 |
]
|
146 |
-
|
147 |
with open(
|
148 |
os.path.join(uncompress_path,
|
149 |
'bopomofo_to_pinyin_wo_tune_dict.json'),
|
@@ -154,16 +150,16 @@ class G2PWOnnxConverter:
|
|
154 |
'bopomofo': lambda x: x,
|
155 |
'pinyin': self._convert_bopomofo_to_pinyin,
|
156 |
}[style]
|
157 |
-
|
158 |
with open(
|
159 |
os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
|
160 |
'r',
|
161 |
encoding='utf-8') as fr:
|
162 |
self.char_bopomofo_dict = json.load(fr)
|
163 |
-
|
164 |
if self.enable_opencc:
|
165 |
self.cc = OpenCC('s2tw')
|
166 |
-
|
167 |
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
|
168 |
tone = bopomofo[-1]
|
169 |
assert tone in '12345'
|
|
|
81 |
model_source: str=None,
|
82 |
enable_non_tradional_chinese: bool=False):
|
83 |
uncompress_path = download_and_decompress(model_dir)
|
84 |
+
|
85 |
sess_options = onnxruntime.SessionOptions()
|
|
|
86 |
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
|
|
87 |
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
|
|
88 |
sess_options.intra_op_num_threads = 2
|
89 |
+
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'), sess_options=sess_options, providers=['CPUExecutionProvider'])
|
90 |
+
|
|
|
91 |
self.config = load_config(
|
92 |
config_path=os.path.join(uncompress_path, 'config.py'),
|
93 |
use_default=True)
|
94 |
|
95 |
self.model_source = model_source if model_source else self.config.model_source
|
96 |
self.enable_opencc = enable_non_tradional_chinese
|
97 |
+
|
98 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
|
99 |
+
|
100 |
polyphonic_chars_path = os.path.join(uncompress_path,
|
101 |
'POLYPHONIC_CHARS.txt')
|
102 |
monophonic_chars_path = os.path.join(uncompress_path,
|
|
|
120 |
polyphonic_chars=self.polyphonic_chars
|
121 |
) if self.config.use_char_phoneme else get_phoneme_labels(
|
122 |
polyphonic_chars=self.polyphonic_chars)
|
123 |
+
|
124 |
self.chars = sorted(list(self.char2phonemes.keys()))
|
125 |
|
126 |
self.polyphonic_chars_new = set(self.chars)
|
127 |
for char in self.non_polyphonic:
|
128 |
if char in self.polyphonic_chars_new:
|
129 |
self.polyphonic_chars_new.remove(char)
|
130 |
+
|
131 |
self.monophonic_chars_dict = {
|
132 |
char: phoneme
|
133 |
for char, phoneme in self.monophonic_chars
|
|
|
135 |
for char in self.non_monophonic:
|
136 |
if char in self.monophonic_chars_dict:
|
137 |
self.monophonic_chars_dict.pop(char)
|
138 |
+
|
139 |
self.pos_tags = [
|
140 |
'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
|
141 |
]
|
142 |
+
|
143 |
with open(
|
144 |
os.path.join(uncompress_path,
|
145 |
'bopomofo_to_pinyin_wo_tune_dict.json'),
|
|
|
150 |
'bopomofo': lambda x: x,
|
151 |
'pinyin': self._convert_bopomofo_to_pinyin,
|
152 |
}[style]
|
153 |
+
|
154 |
with open(
|
155 |
os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
|
156 |
'r',
|
157 |
encoding='utf-8') as fr:
|
158 |
self.char_bopomofo_dict = json.load(fr)
|
159 |
+
|
160 |
if self.enable_opencc:
|
161 |
self.cc = OpenCC('s2tw')
|
162 |
+
|
163 |
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
|
164 |
tone = bopomofo[-1]
|
165 |
assert tone in '12345'
|