import json import numpy as np from abc import ABC from pathlib import Path from abc import abstractmethod from typing import Union, Iterable, List, Dict class AbsTokenizer(ABC): @abstractmethod def text2tokens(self, line: str) -> List[str]: raise NotImplementedError @abstractmethod def tokens2text(self, tokens: Iterable[str]) -> str: raise NotImplementedError class BaseTokenizer(ABC): def __init__( self, token_list: Union[Path, str, Iterable[str]] = None, unk_symbol: str = "", **kwargs, ): if token_list is not None: if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"): token_list = Path(token_list) self.token_list_repr = str(token_list) self.token_list: List[str] = [] with token_list.open("r", encoding="utf-8") as f: for idx, line in enumerate(f): line = line.rstrip() self.token_list.append(line) elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"): token_list = Path(token_list) self.token_list_repr = str(token_list) self.token_list: List[str] = [] with open(token_list, "r", encoding="utf-8") as f: self.token_list = json.load(f) else: self.token_list: List[str] = list(token_list) self.token_list_repr = "" for i, t in enumerate(self.token_list): if i == 3: break self.token_list_repr += f"{t}, " self.token_list_repr += f"... (NVocab={(len(self.token_list))})" self.token2id: Dict[str, int] = {} for i, t in enumerate(self.token_list): if t in self.token2id: raise RuntimeError(f'Symbol "{t}" is duplicated') self.token2id[t] = i self.unk_symbol = unk_symbol if self.unk_symbol not in self.token2id: raise RuntimeError( f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list" ) self.unk_id = self.token2id[self.unk_symbol] def encode(self, text): tokens = self.text2tokens(text) text_ints = self.tokens2ids(tokens) return text_ints def decode(self, text_ints): token = self.ids2tokens(text_ints) text = self.tokens2text(token) return text def get_num_vocabulary_size(self) -> int: return len(self.token_list) def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]: if isinstance(integers, np.ndarray) and integers.ndim != 1: raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}") return [self.token_list[i] for i in integers] def tokens2ids(self, tokens: Iterable[str]) -> List[int]: return [self.token2id.get(i, self.unk_id) for i in tokens] @abstractmethod def text2tokens(self, line: str) -> List[str]: raise NotImplementedError @abstractmethod def tokens2text(self, tokens: Iterable[str]) -> str: raise NotImplementedError