|
|
|
|
|
|
|
|
|
|
|
from pathlib import Path |
|
from typing import Dict, Iterable, List, Union |
|
|
|
import numpy as np |
|
import yaml |
|
from typeguard import check_argument_types |
|
|
|
|
|
class TokenIDConverterError(Exception):
    """Raised by :class:`TokenIDConverter` when it receives input it cannot convert."""
|
|
|
|
|
class TokenIDConverter:
    """Bidirectional mapping between tokens and integer IDs.

    A token's ID is its position in ``token_list``. The last entry of
    ``token_list`` is treated as the unknown symbol: :meth:`tokens2ids` maps
    any token missing from the vocabulary to that entry's ID.
    """

    def __init__(
        self,
        token_list: Union[List, str],
    ):
        """Build the token -> ID lookup table.

        Args:
            token_list: Vocabulary in ID order; its final element is used as
                the unknown symbol.

        Raises:
            TokenIDConverterError: If ``token_list`` is empty, since there is
                then no element to use as the unknown symbol.
        """
        check_argument_types()

        # NOTE(review): the annotation accepts a ``str``, which is indexed
        # character by character below (it is not read as a file path) —
        # confirm that is the intent for callers passing a string.
        if not token_list:
            # Previously this surfaced as a bare IndexError from token_list[-1].
            raise TokenIDConverterError(
                "token_list must contain at least one token (the unknown symbol)"
            )

        self.token_list = token_list
        # The final vocabulary entry is treated as the unknown token.
        self.unk_symbol = token_list[-1]
        # If a token occurs more than once, the dict keeps its last index.
        self.token2id = {v: i for i, v in enumerate(self.token_list)}
        self.unk_id = self.token2id[self.unk_symbol]

    def get_num_vocabulary_size(self) -> int:
        """Return the number of entries in ``token_list`` (unknown symbol included)."""
        return len(self.token_list)

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Convert a sequence of IDs to their tokens.

        IDs index ``token_list`` directly, so negative IDs count from the end
        and out-of-range IDs raise ``IndexError``.

        Raises:
            TokenIDConverterError: If ``integers`` is an ndarray that is not 1-D.
        """
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise TokenIDConverterError(
                f"Must be 1 dim ndarray, but got {integers.ndim}"
            )
        return [self.token_list[i] for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        """Convert tokens to IDs, mapping out-of-vocabulary tokens to ``unk_id``."""
        return [self.token2id.get(i, self.unk_id) for i in tokens]
|
|
|
|
|
def split_to_mini_sentence(words: list, word_limit: int = 20):
    """Split ``words`` into consecutive chunks of at most ``word_limit`` items.

    Args:
        words: Sliceable sequence of words.
        word_limit: Maximum chunk length; must be greater than 1.

    Returns:
        A list of chunks in order. If ``len(words) <= word_limit`` the result
        is ``[words]`` (the input object itself, not a copy), so an empty
        input yields a one-element list containing it. Otherwise every chunk
        except possibly the last has exactly ``word_limit`` items.

    Raises:
        ValueError: If ``word_limit`` is not greater than 1.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``
    # (without it, word_limit=0 would fail later with ZeroDivisionError).
    if not word_limit > 1:
        raise ValueError(f"word_limit must be greater than 1, got {word_limit!r}")

    if len(words) <= word_limit:
        return [words]

    # Slicing past the end is safe, so the last chunk simply holds the remainder.
    return [words[start : start + word_limit] for start in range(0, len(words), word_limit)]
|
|
|
|
|
def code_mix_split_words(text: str): |
|
words = [] |
|
segs = text.split() |
|
for seg in segs: |
|
|
|
current_word = "" |
|
for c in seg: |
|
if len(c.encode()) == 1: |
|
|
|
current_word += c |
|
else: |
|
|
|
if len(current_word) > 0: |
|
words.append(current_word) |
|
current_word = "" |
|
words.append(c) |
|
if len(current_word) > 0: |
|
words.append(current_word) |
|
return words |
|
|
|
|
|
class _YAMLNotFoundError(FileNotFoundError, FileExistsError):
    """Missing-file error that is also a ``FileExistsError``.

    ``read_yaml`` historically raised ``FileExistsError`` for a missing file.
    Inheriting from both keeps existing ``except FileExistsError`` handlers
    working while letting callers catch the correct ``FileNotFoundError``.
    """


def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    """Load and return the contents of a YAML file.

    Args:
        yaml_path: Path to the YAML file.

    Returns:
        The parsed document as returned by ``yaml.load`` (annotated as ``Dict``).

    Raises:
        FileNotFoundError: If ``yaml_path`` does not exist. For backward
            compatibility the exception is also a ``FileExistsError``.
    """
    path = Path(yaml_path)
    if not path.exists():
        raise _YAMLNotFoundError(f"The {yaml_path} does not exist.")

    # SECURITY NOTE(review): ``yaml.Loader`` can construct arbitrary Python
    # objects from tags in the file. Only use it on trusted files; consider
    # ``yaml.SafeLoader`` if configs never rely on Python-specific tags.
    with open(path, "rb") as f:
        data = yaml.load(f, Loader=yaml.Loader)
    return data
|
|