import json
import logging
import re
import sys
from typing import Callable, Dict, List, Literal, Optional, Tuple

import numpy as np
from numba import jit

from .utils import del_all


@jit
def _find_index(table: np.ndarray, val: np.uint16):
    # Linear scan over one lookup row; returns -1 when val is absent.
    for i in range(table.size):
        if table[i] == val:
            return i
    return -1


@jit
def _fast_replace(
    table: np.ndarray, text: bytes
) -> Tuple[np.ndarray, List[Tuple[str, str]]]:
    # text is UTF-16 in native byte order, so every BMP character occupies
    # exactly one uint16 code unit and can be swapped in place.
    result = np.frombuffer(text, dtype=np.uint16).copy()
    replaced_words = []
    for i in range(result.size):
        ch = result[i]
        p = _find_index(table[0], ch)
        if p >= 0:
            repl_char = table[1][p]
            result[i] = repl_char
            replaced_words.append((chr(ch), chr(repl_char)))
    return result, replaced_words
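
# Illustrative round trip (a sketch, not part of the module API; the
# single-entry table and the characters are hypothetical):
#
#     table = np.array([[ord("著")], [ord("注")]], dtype=np.uint32)
#     arr, replaced = _fast_replace(table, "原著".encode("utf-16-le"))
#     arr.tobytes().decode("utf-16-le")  # -> "原注"
#     replaced                           # -> [('著', '注')]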


@jit
def _split_tags(text: str) -> Tuple[List[str], List[str]]:
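    """Split text into plain runs and bracketed [tag] runs, preserving order.

    A doctest-style sketch; the tag name below is a placeholder chosen for
    illustration:

    >>> _split_tags("你好[tag_name]世界")
    (['你好', '世界'], ['[tag_name]'])
    """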
    texts: List[str] = []
    tags: List[str] = []
    current_text = ""
    current_tag = ""
    for c in text:
        if c == "[":
            # A "[" closes the current plain-text run and opens a tag.
            texts.append(current_text)
            current_text = ""
            current_tag = c
        elif current_tag != "":
            current_tag += c
        else:
            current_text += c
        if c == "]":
            tags.append(current_tag)
            current_tag = ""
    if current_text != "":
        texts.append(current_text)
    return texts, tags


@jit
def _combine_tags(texts: List[str], tags: List[str]) -> str:
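    """Interleave texts with tags: texts[0] + tags[0] + texts[1] + ...

    A doctest-style sketch (placeholder tag; note that tags is consumed in
    place by pop):

    >>> _combine_tags(["你好", "世界"], ["[tag_name]"])
    '你好[tag_name]世界'
    """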
text = "" |
|
for t in texts: |
|
tg = "" |
|
if len(tags) > 0: |
|
tg = tags.pop(0) |
|
text += t + tg |
|
return text |


class Normalizer:
    def __init__(self, map_file_path: str, logger=logging.getLogger(__name__)):
        self.logger = logger
        self.normalizers: Dict[str, Callable[[str], str]] = {}
        self.homophones_map = self._load_homophones_map(map_file_path)
        """
        homophones_map

        Replaces mispronounced characters with homophones that are read
        correctly.

        Creation process of homophones_map.json:

        1. Build a word corpus from the [Tencent AI Lab Embedding Corpora v0.2.0 large] with 12 million entries; after cleaning, approximately 1.8 million entries remain. Use ChatTTS to infer the text.
        2. Record discrepancies between the inferred and the input text, identifying about 180,000 misread words.
        3. Create a pinyin-to-common-characters mapping from the characters ChatTTS reads correctly.
        4. For each discrepancy, extract the correct pinyin using [python-pinyin] and find a homophone with the correct pronunciation in the mapping.

        Thanks to:
        [Tencent AI Lab Embedding Corpora for Chinese and English Words and Phrases](https://ai.tencent.com/ailab/nlp/en/embedding.html)
        [python-pinyin](https://github.com/mozillazg/python-pinyin)
        """
        self.coding = "utf-16-le" if sys.byteorder == "little" else "utf-16-be"
        self.reject_pattern = re.compile(r"[^\u4e00-\u9fffA-Za-z,。、,\. ]")
        self.sub_pattern = re.compile(r"\[[\w_]+\]")
        self.chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]")
        self.english_word_pattern = re.compile(r"\b[A-Za-z]+\b")
        # Collapse punctuation that cannot be voiced into pause-like commas
        # and full stops (fullwidth forms first, then ASCII counterparts).
        self.character_simplifier = str.maketrans(
            {
                ":": ",",
                ";": ",",
                "!": "。",
                "(": ",",
                ")": ",",
                "【": ",",
                "】": ",",
                "『": ",",
                "』": ",",
                "「": ",",
                "」": ",",
                "《": ",",
                "》": ",",
                "-": ",",
                ":": ",",
                ";": ",",
                "!": ".",
                "(": ",",
                ")": ",",
                ">": ",",
                "<": ",",
                "-": ",",
            }
        )
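        # A quick illustration of the table above (hypothetical input):
        # "你好!(测试)".translate(self.character_simplifier) -> "你好。,测试,"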
        # Map halfwidth (ASCII) punctuation onto its fullwidth equivalent
        # for Chinese text.
        self.halfwidth_2_fullwidth = str.maketrans(
            {
                "!": "!",
                '"': "“",
                "'": "‘",
                "#": "#",
                "$": "$",
                "%": "%",
                "&": "&",
                "(": "(",
                ")": ")",
                ",": ",",
                "-": "-",
                "*": "*",
                "+": "+",
                ".": "。",
                "/": "/",
                ":": ":",
                ";": ";",
                "<": "<",
                "=": "=",
                ">": ">",
                "?": "?",
                "@": "@",
                "\\": "\",
                "^": "^",
                "`": "`",
                "{": "{",
                "|": "|",
                "}": "}",
                "~": "~",
            }
        )
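        # A quick illustration (hypothetical input):
        # "Hi, there!".translate(self.halfwidth_2_fullwidth) -> "Hi， there！"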

    def __call__(
        self,
        text: str,
        do_text_normalization=True,
        do_homophone_replacement=True,
        lang: Optional[Literal["zh", "en"]] = None,
    ) -> str:
        if do_text_normalization:
            _lang = self._detect_language(text) if lang is None else lang
            if _lang in self.normalizers:
                # Normalize only the plain-text runs; bracketed [tags] must
                # survive untouched.
                texts, tags = _split_tags(text)
                self.logger.debug("split texts %s, tags %s", str(texts), str(tags))
                texts = [self.normalizers[_lang](t) for t in texts]
                self.logger.debug("normed texts %s", str(texts))
                text = _combine_tags(texts, tags) if len(tags) > 0 else texts[0]
                self.logger.debug("combined text %s", text)
            if _lang == "zh":
                text = self._apply_half2full_map(text)
        invalid_characters = self._count_invalid_characters(text)
        if len(invalid_characters):
            self.logger.warning(f"found invalid characters: {invalid_characters}")
            text = self._apply_character_map(text)
        if do_homophone_replacement:
            arr, replaced_words = _fast_replace(
                self.homophones_map,
                text.encode(self.coding),
            )
            if replaced_words:
                text = arr.tobytes().decode(self.coding)
                repl_res = ", ".join([f"{_[0]}->{_[1]}" for _ in replaced_words])
                self.logger.info(f"replace homophones: {repl_res}")
        if len(invalid_characters):
            # Re-split so that the reject pattern never strips tag brackets.
            texts, tags = _split_tags(text)
            self.logger.debug("split texts %s, tags %s", str(texts), str(tags))
            texts = [self.reject_pattern.sub("", t) for t in texts]
            self.logger.debug("normed texts %s", str(texts))
            text = _combine_tags(texts, tags) if len(tags) > 0 else texts[0]
            self.logger.debug("combined text %s", text)
        return text
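
    # A minimal usage sketch (illustrative; the map path and the tag are
    # hypothetical):
    #
    #     normalizer = Normalizer("homophones_map.json")
    #     text = normalizer("你好[tag_name]世界", lang="zh")
    #
    # Even with no "zh" normalizer registered, this applies the
    # halfwidth-to-fullwidth map and the homophone replacement.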

    def register(self, name: str, normalizer: Callable[[str], str]) -> bool:
        if name in self.normalizers:
            self.logger.warning(f"name {name} has already been registered")
            return False
        try:
            # Probe the callable once to verify it maps str to str.
            val = normalizer("test string 测试字符串")
            if not isinstance(val, str):
                self.logger.warning(
                    "normalizer must be a callable of type (str) -> str"
                )
                return False
        except Exception as e:
            self.logger.warning(e)
            return False
        self.normalizers[name] = normalizer
        return True
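
    # Registration sketch (hypothetical normalizer on an instance named
    # normalizer_instance; any (str) -> str callable qualifies, e.g. a
    # wrapper around an external text-normalization library):
    #
    #     ok = normalizer_instance.register("en", lambda s: s.lower())
    #     assert ok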

    def unregister(self, name: str):
        if name in self.normalizers:
            del self.normalizers[name]

    def destroy(self):
        del_all(self.normalizers)
        del self.homophones_map

    def _load_homophones_map(self, map_file_path: str) -> np.ndarray:
        with open(map_file_path, "r", encoding="utf-8") as f:
            homophones_map: Dict[str, str] = json.load(f)
        # Row 0 holds the code points to replace, row 1 their replacements.
        table = np.empty((2, len(homophones_map)), dtype=np.uint32)
        for i, k in enumerate(homophones_map.keys()):
            table[:, i] = (ord(k), ord(homophones_map[k]))
        del homophones_map
        return table

    def _count_invalid_characters(self, s: str):
        # Strip [tags] first so their brackets are not counted as invalid.
        s = self.sub_pattern.sub("", s)
        non_alphabetic_chinese_chars = self.reject_pattern.findall(s)
        return set(non_alphabetic_chinese_chars)

    def _apply_half2full_map(self, text: str) -> str:
        return text.translate(self.halfwidth_2_fullwidth)

    def _apply_character_map(self, text: str) -> str:
        return text.translate(self.character_simplifier)

    def _detect_language(self, sentence: str) -> Literal["zh", "en"]:
        chinese_chars = self.chinese_char_pattern.findall(sentence)
        english_words = self.english_word_pattern.findall(sentence)

        if len(chinese_chars) > len(english_words):
            return "zh"
        else:
            return "en"