lj1995 commited on
Commit
34eca10
·
verified ·
1 Parent(s): 167b457

Update text/g2pw/onnx_api.py

Browse files
Files changed (1) hide show
  1. text/g2pw/onnx_api.py +12 -16
text/g2pw/onnx_api.py CHANGED
@@ -81,26 +81,22 @@ class G2PWOnnxConverter:
81
  model_source: str=None,
82
  enable_non_tradional_chinese: bool=False):
83
  uncompress_path = download_and_decompress(model_dir)
84
- print(":::1")
85
  sess_options = onnxruntime.SessionOptions()
86
- print(":::2")
87
  sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
88
- print(":::3")
89
  sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
90
- print(":::4")
91
  sess_options.intra_op_num_threads = 2
92
- print(":::5")
93
- self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CPUExecutionProvider'])
94
- print(":::6")
95
  self.config = load_config(
96
  config_path=os.path.join(uncompress_path, 'config.py'),
97
  use_default=True)
98
 
99
  self.model_source = model_source if model_source else self.config.model_source
100
  self.enable_opencc = enable_non_tradional_chinese
101
- print(":::7")
102
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
103
- print(":::8")
104
  polyphonic_chars_path = os.path.join(uncompress_path,
105
  'POLYPHONIC_CHARS.txt')
106
  monophonic_chars_path = os.path.join(uncompress_path,
@@ -124,14 +120,14 @@ class G2PWOnnxConverter:
124
  polyphonic_chars=self.polyphonic_chars
125
  ) if self.config.use_char_phoneme else get_phoneme_labels(
126
  polyphonic_chars=self.polyphonic_chars)
127
- print(":::9")
128
  self.chars = sorted(list(self.char2phonemes.keys()))
129
 
130
  self.polyphonic_chars_new = set(self.chars)
131
  for char in self.non_polyphonic:
132
  if char in self.polyphonic_chars_new:
133
  self.polyphonic_chars_new.remove(char)
134
- print(":::10")
135
  self.monophonic_chars_dict = {
136
  char: phoneme
137
  for char, phoneme in self.monophonic_chars
@@ -139,11 +135,11 @@ class G2PWOnnxConverter:
139
  for char in self.non_monophonic:
140
  if char in self.monophonic_chars_dict:
141
  self.monophonic_chars_dict.pop(char)
142
- print(":::11")
143
  self.pos_tags = [
144
  'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
145
  ]
146
- print(":::12")
147
  with open(
148
  os.path.join(uncompress_path,
149
  'bopomofo_to_pinyin_wo_tune_dict.json'),
@@ -154,16 +150,16 @@ class G2PWOnnxConverter:
154
  'bopomofo': lambda x: x,
155
  'pinyin': self._convert_bopomofo_to_pinyin,
156
  }[style]
157
- print(":::13")
158
  with open(
159
  os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
160
  'r',
161
  encoding='utf-8') as fr:
162
  self.char_bopomofo_dict = json.load(fr)
163
- print(":::14")
164
  if self.enable_opencc:
165
  self.cc = OpenCC('s2tw')
166
- print(":::15")
167
  def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
168
  tone = bopomofo[-1]
169
  assert tone in '12345'
 
81
  model_source: str=None,
82
  enable_non_tradional_chinese: bool=False):
83
  uncompress_path = download_and_decompress(model_dir)
84
+
85
  sess_options = onnxruntime.SessionOptions()
 
86
  sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
 
87
  sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
 
88
  sess_options.intra_op_num_threads = 2
89
+ self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'), sess_options=sess_options, providers=['CPUExecutionProvider'])
90
+
 
91
  self.config = load_config(
92
  config_path=os.path.join(uncompress_path, 'config.py'),
93
  use_default=True)
94
 
95
  self.model_source = model_source if model_source else self.config.model_source
96
  self.enable_opencc = enable_non_tradional_chinese
97
+
98
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
99
+
100
  polyphonic_chars_path = os.path.join(uncompress_path,
101
  'POLYPHONIC_CHARS.txt')
102
  monophonic_chars_path = os.path.join(uncompress_path,
 
120
  polyphonic_chars=self.polyphonic_chars
121
  ) if self.config.use_char_phoneme else get_phoneme_labels(
122
  polyphonic_chars=self.polyphonic_chars)
123
+
124
  self.chars = sorted(list(self.char2phonemes.keys()))
125
 
126
  self.polyphonic_chars_new = set(self.chars)
127
  for char in self.non_polyphonic:
128
  if char in self.polyphonic_chars_new:
129
  self.polyphonic_chars_new.remove(char)
130
+
131
  self.monophonic_chars_dict = {
132
  char: phoneme
133
  for char, phoneme in self.monophonic_chars
 
135
  for char in self.non_monophonic:
136
  if char in self.monophonic_chars_dict:
137
  self.monophonic_chars_dict.pop(char)
138
+
139
  self.pos_tags = [
140
  'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
141
  ]
142
+
143
  with open(
144
  os.path.join(uncompress_path,
145
  'bopomofo_to_pinyin_wo_tune_dict.json'),
 
150
  'bopomofo': lambda x: x,
151
  'pinyin': self._convert_bopomofo_to_pinyin,
152
  }[style]
153
+
154
  with open(
155
  os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
156
  'r',
157
  encoding='utf-8') as fr:
158
  self.char_bopomofo_dict = json.load(fr)
159
+
160
  if self.enable_opencc:
161
  self.cc = OpenCC('s2tw')
162
+
163
  def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
164
  tone = bopomofo[-1]
165
  assert tone in '12345'