# wuxulong19950206
# First model version
# 14d1720
import argparse
import copy
import os
from typing import List
import jieba
import pypinyin
# Punctuation characters on which TextNormal._split2sent splits the input
# text into sentences; the same set is used to collect the punctuation
# tokens that occur in the text.
SPECIAL_NOTES = '。?!?!.;;:,,:'
def read_vocab(file: os.PathLike) -> List[str]:
    """Read a vocabulary file and return its non-empty lines.

    Args:
        file: Path of a text file holding one vocabulary entry per line.

    Returns:
        The entries in file order with empty lines dropped. Lines are not
        stripped, so surrounding whitespace (if any) is kept.
    """
    # Vocabulary files may contain Chinese characters, so decode explicitly
    # as UTF-8 instead of relying on the platform's default encoding.
    with open(file, encoding='utf-8') as f:
        return [entry for entry in f.read().split('\n') if entry]
class TextNormal:
    """Convert Chinese text into per-sentence pinyin and character strings.

    The text is split into sentences at punctuation, each sentence is
    segmented with jieba and transcribed with pypinyin (TONE3 style), and a
    few tone corrections are applied to the resulting pinyin tokens.
    """

    def __init__(self,
                 gp_vocab_file: os.PathLike,
                 py_vocab_file: os.PathLike,
                 add_sp1=False,
                 fix_er=False,
                 add_sil=True):
        """Load the vocabularies and store the conversion options.

        Args:
            gp_vocab_file: Path of the character vocabulary, or None.
            py_vocab_file: Path of the pinyin vocabulary, or None.
            add_sp1: Replace comma-like tokens in the pinyin with 'sp1'.
            fix_er: Merge a trailing 'er2' (儿) into the previous syllable
                when the merged syllable exists in the pinyin vocabulary.
            add_sil: Wrap every sentence in 'sil' tokens.
        """
        # Always define the vocabulary attributes so that passing None for a
        # file does not leave the instance half-initialised (and make
        # _convert_er2 fail later on a missing attribute).
        self.gp_vocab = (read_vocab(gp_vocab_file)
                         if gp_vocab_file is not None else [])
        self.py_vocab = (read_vocab(py_vocab_file)
                         if py_vocab_file is not None else [])
        # Membership table for fast "is this syllable known" checks.
        self.in_py_vocab = {p: True for p in self.py_vocab}
        self.add_sp1 = add_sp1
        self.add_sil = add_sil
        self.fix_er = fix_er

    def _split2sent(self, text, delimiters=None):
        """Split text into sentences at punctuation characters.

        Args:
            text: The raw input text.
            delimiters: Characters to split on. Defaults to SPECIAL_NOTES.

        Returns:
            A pair ``(sentences, tokens)``: the non-empty pieces of ``text``
            between delimiters, in order, and the delimiter characters that
            occurred in ``text``, in order.
        """
        if delimiters is None:
            delimiters = SPECIAL_NOTES
        sentences = []
        tokens = []
        current = []
        # A single left-to-right scan splits on every delimiter; empty
        # pieces (leading, trailing or adjacent delimiters) are skipped.
        for ch in text:
            if ch in delimiters:
                tokens.append(ch)
                if current:
                    sentences.append(''.join(current))
                    current = []
            else:
                current.append(ch)
        if current:
            sentences.append(''.join(current))
        return sentences, tokens

    def _correct_tone3(self, pys: List[str]) -> List[str]:
        """Fix the pronunciation of consecutive third tones.

        In a run of three tone-3 syllables the middle one becomes tone 2;
        afterwards, in every remaining pair of adjacent tone-3 syllables the
        first one becomes tone 2. ``pys`` is modified in place and returned.
        """
        for i in range(2, len(pys)):
            if (pys[i].endswith('3') and pys[i - 1].endswith('3')
                    and pys[i - 2].endswith('3')):
                pys[i - 1] = pys[i - 1][:-1] + '2'  # change the middle one
        for i in range(1, len(pys)):
            if pys[i].endswith('3') and pys[i - 1].endswith('3'):
                pys[i - 1] = pys[i - 1][:-1] + '2'
        return pys

    def _correct_tone4(self, pys: List[str]) -> List[str]:
        """Fix the pronunciation of 不: 'bu4' before a tone-4 syllable
        becomes 'bu2' (bu2 yao4 / bu4 neng2).

        ``pys`` is modified in place and returned.
        """
        for i in range(len(pys) - 1):
            if pys[i] == 'bu4' and pys[i + 1].endswith('4'):
                pys[i] = 'bu2'
        return pys

    def _replace_with_sp(self, pys: List[str]) -> List[str]:
        """Replace comma-like tokens with the pause token 'sp1'.

        ``pys`` is modified in place and returned.
        """
        for i, p in enumerate(pys):
            if p in ',,、':
                pys[i] = 'sp1'
        return pys

    def _correct_tone5(self, pys: List[str]) -> List[str]:
        """Append '5' to every token that does not end in a tone digit 1-4.

        Empty tokens are left untouched. ``pys`` is modified in place and
        returned.
        """
        for i, p in enumerate(pys):
            if p and p[-1] not in '1234':
                pys[i] = p + '5'
        return pys

    def gp2py(self, gp_text: str) -> tuple:
        """Convert text into pinyin and character strings, one per sentence.

        Args:
            gp_text: The input text.

        Returns:
            A pair ``(py_sent_list, gp_sent_list)`` of equally long lists.
            Each element is a space-separated string of pinyin tokens or
            characters respectively, wrapped in 'sil' when ``add_sil`` is set.
        """
        gp_sent_list, _ = self._split2sent(gp_text)
        py_sent_list = []
        for sent in gp_sent_list:
            pys = []
            # Transcribe word by word so pypinyin sees segmented words.
            for word in jieba.cut(sent):
                pys += [p[0] for p in pypinyin.pinyin(word, pypinyin.TONE3)]
            if self.add_sp1:
                pys = self._replace_with_sp(pys)
            pys = self._correct_tone3(pys)
            pys = self._correct_tone4(pys)
            pys = self._correct_tone5(pys)
            if self.add_sil:
                pys = ['sil'] + pys + ['sil']
            py_sent_list.append(' '.join(pys))
        if self.add_sil:
            gp_sent_list = ['sil ' + ' '.join(gp) + ' sil'
                            for gp in gp_sent_list]
        else:
            gp_sent_list = [' '.join(gp) for gp in gp_sent_list]
        if self.fix_er:
            py_sent_list = [self._convert_er2(py, gp)
                            for py, gp in zip(py_sent_list, gp_sent_list)]
        return py_sent_list, gp_sent_list

    def _convert_er2(self, py, gp):
        """Merge 'er2' tokens that spell 儿 into the preceding syllable.

        Args:
            py: Space-separated pinyin tokens of one sentence.
            gp: Space-separated characters of the same sentence.

        Returns:
            The pinyin string in which, for every qualifying 'er2', the
            previous syllable gets an 'r' inserted before its tone digit and
            the 'er2' token itself becomes 'r'.
        """
        py_list = py.split()
        hz_list = gp.split()
        for i, p in enumerate(py_list):
            if i <= 1 or p != 'er2':
                continue
            # Look the character up by position: a pinyin->character dict
            # would let a later 'er2' (e.g. 而) overwrite an earlier one.
            # NOTE(review): assumes py and gp tokens are aligned one-to-one.
            if i >= len(hz_list) or hz_list[i] != '儿':
                continue
            prev = py_list[i - 1]
            if len(prev) <= 2 or prev[-1] not in '1234':
                continue
            py_er = prev[:-1] + 'r' + prev[-1]
            if self.in_py_vocab.get(py_er, False):  # must be in vocab
                py_list[i - 1] = py_er
                py_list[i] = 'r'
        return ' '.join(py_list)
if __name__ == '__main__':
    # Command-line entry point: convert the text given with -t/--text and
    # print one "pinyin|characters" line per sentence.
    parser = argparse.ArgumentParser()
    # The text is mandatory: without it gp2py would be handed None and fail
    # with an unhelpful TypeError instead of a usage message.
    parser.add_argument('-t', '--text', type=str, required=True,
                        help='text to convert to pinyin')
    args = parser.parse_args()
    # Vocabulary files are looked up in the current working directory.
    tn = TextNormal('gp.vocab', 'py.vocab', add_sp1=True, fix_er=True)
    py_list, gp_list = tn.gp2py(args.text)
    for py, gp in zip(py_list, gp_list):
        print(py + '|' + gp)