# Path: ctt_punctuator/cttpunctuator/src/utils/text_post_process.py
# Last commit: ec97ce5 "Upload 14 files" (lovemefan)
# -*- coding:utf-8 -*-
# @FileName :text_post_process.py
# @Time :2023/4/13 15:09
# @Author :lovemefan
# @Email :[email protected]
from pathlib import Path
from typing import Dict, Iterable, List, Union
import numpy as np
import yaml
from typeguard import check_argument_types
class TokenIDConverterError(Exception):
    """Error type for failures while converting between tokens and IDs."""
class TokenIDConverter:
    """Two-way mapping between vocabulary tokens and their integer IDs.

    A token's ID is its index in ``token_list``. The last entry of the list
    is treated as the unknown symbol: ``tokens2ids`` returns its ID for any
    token that is not in the vocabulary.
    """

    def __init__(
        self,
        token_list: Union[List, str],
    ):
        """Store the vocabulary and build the token -> ID lookup table.

        Args:
            token_list: Vocabulary ordered so that index == ID. Its last
                element is taken as the unknown symbol.
        """
        check_argument_types()

        self.token_list = token_list
        self.unk_symbol = token_list[-1]

        # If a token occurs more than once, the later index overwrites the
        # earlier one.
        index_by_token = {}
        for index, token in enumerate(self.token_list):
            index_by_token[token] = index
        self.token2id = index_by_token

        self.unk_id = self.token2id[self.unk_symbol]

    def get_num_vocabulary_size(self) -> int:
        """Return the number of entries in the vocabulary."""
        vocabulary = self.token_list
        return len(vocabulary)

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Translate a sequence of IDs into the corresponding tokens.

        Args:
            integers: IDs to look up. A NumPy array must be one-dimensional.

        Returns:
            The tokens found at each ID, in input order.

        Raises:
            TokenIDConverterError: If ``integers`` is an ndarray whose
                number of dimensions is not 1.
        """
        if isinstance(integers, np.ndarray):
            if integers.ndim != 1:
                raise TokenIDConverterError(
                    f"Must be 1 dim ndarray, but got {integers.ndim}"
                )

        vocabulary = self.token_list
        tokens = []
        for token_id in integers:
            tokens.append(vocabulary[token_id])
        return tokens

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        """Translate tokens into IDs, using the unknown ID for unseen tokens.

        Args:
            tokens: Tokens to look up.

        Returns:
            One ID per input token, in input order.
        """
        ids = []
        for token in tokens:
            ids.append(self.token2id.get(token, self.unk_id))
        return ids
def split_to_mini_sentence(words: list, word_limit: int = 20):
    """Cut ``words`` into consecutive chunks of at most ``word_limit`` items.

    Args:
        words: Sequence to split.
        word_limit: Maximum chunk length; must be greater than 1.

    Returns:
        A list of chunks. When ``words`` already fits within ``word_limit``
        it is returned un-copied as the only chunk (so empty input yields
        ``[words]``, not ``[]``). Otherwise every chunk holds exactly
        ``word_limit`` items except possibly the last, which holds the
        remainder.
    """
    assert word_limit > 1

    total = len(words)
    if total <= word_limit:
        return [words]

    # Stepping the start offset by word_limit yields the full chunks followed
    # by the shorter tail chunk, if there is one.
    return [
        words[start : start + word_limit]
        for start in range(0, total, word_limit)
    ]
def code_mix_split_words(text: str):
    """Split mixed-script text into words.

    The text is first split on whitespace. Inside each resulting segment,
    runs of single-byte (ASCII) characters are kept together as one word,
    while every multi-byte character (e.g. a Chinese character) becomes a
    word of its own.

    Args:
        text: Input string.

    Returns:
        The list of words in their original order.
    """
    words = []
    for seg in text.split():
        # Characters of the ASCII run currently being collected.
        ascii_run = []
        for c in seg:
            # A character that encodes to more than one byte is non-ASCII.
            if len(c.encode()) != 1:
                # Flush the pending ASCII word before emitting this character.
                if ascii_run:
                    words.append("".join(ascii_run))
                    ascii_run = []
                words.append(c)
                continue
            ascii_run.append(c)
        # A segment may end in the middle of an ASCII run.
        if ascii_run:
            words.append("".join(ascii_run))
    return words
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    """Load a YAML file and return its parsed content.

    Args:
        yaml_path: Location of the YAML file.

    Returns:
        Whatever ``yaml.load`` produces for the file's content.

    Raises:
        FileExistsError: If ``yaml_path`` does not exist.
    """
    # NOTE(review): a missing file is reported with FileExistsError rather
    # than FileNotFoundError; kept unchanged so existing callers that catch
    # it keep working.
    is_missing = not Path(yaml_path).exists()
    if is_missing:
        raise FileExistsError(f"The {yaml_path} does not exist.")

    # NOTE(review): yaml.Loader can construct arbitrary Python objects, so
    # this must only be used on trusted files; consider yaml.SafeLoader.
    with open(str(yaml_path), "rb") as stream:
        return yaml.load(stream, Loader=yaml.Loader)