|
from pathlib import Path |
|
from typing import Iterable |
|
from typing import List |
|
from typing import Union |
|
|
|
from typeguard import check_argument_types |
|
|
|
from espnet2.text.abs_tokenizer import AbsTokenizer |
|
|
|
|
|
class CharTokenizer(AbsTokenizer):
    """Character-level tokenizer.

    Splits a text line into single-character tokens, mapping the literal
    space character to ``space_symbol``.  Substrings listed in
    ``non_linguistic_symbols`` (e.g. "<noise>") are matched greedily at the
    current position and kept as single tokens, or dropped entirely when
    ``remove_non_linguistic_symbols`` is True.
    """

    def __init__(
        self,
        non_linguistic_symbols: Union[Path, str, Iterable[str], None] = None,
        space_symbol: str = "<space>",
        remove_non_linguistic_symbols: bool = False,
    ):
        """Initialize the tokenizer.

        Args:
            non_linguistic_symbols: Either an iterable of symbol strings, or
                a path (``Path`` or ``str``) to a UTF-8 file with one symbol
                per line.  ``None`` means no non-linguistic symbols.
            space_symbol: Token used to represent a space character.
            remove_non_linguistic_symbols: If True, drop matched
                non-linguistic symbols instead of emitting them as tokens.
        """
        assert check_argument_types()
        self.space_symbol = space_symbol
        if non_linguistic_symbols is None:
            self.non_linguistic_symbols = set()
        elif isinstance(non_linguistic_symbols, (Path, str)):
            non_linguistic_symbols = Path(non_linguistic_symbols)
            with non_linguistic_symbols.open("r", encoding="utf-8") as f:
                # One symbol per line; strip the trailing newline/whitespace.
                self.non_linguistic_symbols = {line.rstrip() for line in f}
        else:
            self.non_linguistic_symbols = set(non_linguistic_symbols)
        self.remove_non_linguistic_symbols = remove_non_linguistic_symbols

    def __repr__(self):
        # BUG FIX: the original concatenated the two fields with no
        # separator, producing e.g.
        # 'CharTokenizer(space_symbol="<space>"non_linguistic_symbols="...")'.
        return (
            f"{self.__class__.__name__}("
            f'space_symbol="{self.space_symbol}", '
            f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
            f")"
        )

    def text2tokens(self, line: str) -> List[str]:
        """Convert a text line into a list of character tokens.

        Non-linguistic symbols are consumed as whole tokens (or skipped when
        ``remove_non_linguistic_symbols`` is set); every other character
        becomes one token, with " " replaced by ``self.space_symbol``.
        """
        tokens = []
        while len(line) != 0:
            # Try to match a non-linguistic symbol at the current position.
            for w in self.non_linguistic_symbols:
                # Guard against an empty symbol (e.g. a blank line in the
                # symbol file): startswith("") is always True and would
                # consume nothing, causing an infinite loop.
                if w and line.startswith(w):
                    if not self.remove_non_linguistic_symbols:
                        tokens.append(line[: len(w)])
                    line = line[len(w) :]
                    break
            else:
                t = line[0]
                if t == " ":
                    # BUG FIX: use the configured space symbol instead of a
                    # hard-coded "<space>", so tokens round-trip correctly
                    # through tokens2text() with a custom space_symbol.
                    t = self.space_symbol
                tokens.append(t)
                line = line[1:]
        return tokens

    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join tokens back into text, mapping ``space_symbol`` to " "."""
        tokens = [t if t != self.space_symbol else " " for t in tokens]
        return "".join(tokens)
|
|