from pathlib import Path
from typing import Iterable
from typing import List
from typing import Optional
from typing import Union
import re
import warnings

from funasr_detach.tokenizer.abs_tokenizer import BaseTokenizer
from funasr_detach.register import tables


@tables.register("tokenizer_classes", "CharTokenizer")
class CharTokenizer(BaseTokenizer):
    """Character-level tokenizer with optional handling of non-linguistic
    symbols and seg-dict based word segmentation."""

    def __init__(
        self,
        non_linguistic_symbols: Optional[Union[Path, str, Iterable[str]]] = None,
        space_symbol: str = "<space>",
        remove_non_linguistic_symbols: bool = False,
        split_with_space: bool = False,
        seg_dict: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.space_symbol = space_symbol
        if non_linguistic_symbols is None:
            self.non_linguistic_symbols = set()
        elif isinstance(non_linguistic_symbols, (Path, str)):
            non_linguistic_symbols = Path(non_linguistic_symbols)
            try:
                with non_linguistic_symbols.open("r", encoding="utf-8") as f:
                    self.non_linguistic_symbols = set(line.rstrip() for line in f)
            except FileNotFoundError:
                warnings.warn(f"{non_linguistic_symbols} doesn't exist.")
                self.non_linguistic_symbols = set()
        else:
            self.non_linguistic_symbols = set(non_linguistic_symbols)
        self.remove_non_linguistic_symbols = remove_non_linguistic_symbols
        self.split_with_space = split_with_space
        self.seg_dict = None
        if seg_dict is not None:
            self.seg_dict = load_seg_dict(seg_dict)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f'space_symbol="{self.space_symbol}", '
            f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
            f")"
        )

    def text2tokens(self, line: Union[str, list]) -> List[str]:
        if self.seg_dict is not None:
            # Word mode: split on spaces, then expand each word through the
            # seg dict (see seg_tokenize below).
            tokens = line.strip().split(" ")
            tokens = seg_tokenize(tokens, self.seg_dict)
        else:
            # Character mode: consume the line left to right, matching
            # registered non-linguistic symbols before single characters.
            tokens = []
            while len(line) != 0:
                for w in self.non_linguistic_symbols:
                    if line.startswith(w):
                        if not self.remove_non_linguistic_symbols:
                            tokens.append(w)
                        line = line[len(w) :]
                        break
                else:
                    t = line[0]
                    if t == " ":
                        # Plain spaces are dropped, not mapped to
                        # self.space_symbol.
                        line = line[1:]
                        continue
                    tokens.append(t)
                    line = line[1:]
        return tokens

    def tokens2text(self, tokens: Iterable[str]) -> str:
        tokens = [t if t != self.space_symbol else " " for t in tokens]
        return "".join(tokens)


def load_seg_dict(seg_dict_file):
    """Load a seg dict mapping each key to its space-joined pieces."""
    assert isinstance(seg_dict_file, str)
    seg_dict = {}
    with open(seg_dict_file, "r", encoding="utf8") as f:
        for line in f:
            s = line.strip().split()
            if not s:
                # Skip blank lines, which would otherwise raise IndexError.
                continue
            key = s[0]
            value = s[1:]
            seg_dict[key] = " ".join(value)
    return seg_dict
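
# For reference, a seg_dict file is whitespace-separated plain text with one
# mapping per line: the first field is the lookup key, the remaining fields
# are the pieces it expands to. A hypothetical three-line file
#
#     hello   he llo
#     world   wor ld
#     好      好
#
# would load as {"hello": "he llo", "world": "wor ld", "好": "好"}.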


def seg_tokenize(txt, seg_dict):
    # Matches strings consisting solely of CJK ideographs and ASCII digits.
    pattern = re.compile(r"^[\u4E00-\u9FA50-9]+$")
    out_txt = ""
    for word in txt:
        word = word.lower()
        if word in seg_dict:
            out_txt += seg_dict[word] + " "
        elif pattern.match(word):
            # Unknown CJK/digit strings fall back to per-character lookup,
            # with <unk> for characters missing from the seg dict.
            for char in word:
                if char in seg_dict:
                    out_txt += seg_dict[char] + " "
                else:
                    out_txt += "<unk>" + " "
        else:
            out_txt += "<unk>" + " "
    return out_txt.strip().split()
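

if __name__ == "__main__":
    # Minimal usage sketch; the inputs below are hypothetical, and assigning
    # seg_dict directly is a demonstration shortcut that bypasses
    # load_seg_dict (which expects a file path).
    tokenizer = CharTokenizer(non_linguistic_symbols=["<noise>"])
    tokens = tokenizer.text2tokens("<noise>你好 ok")
    print(tokens)  # ['<noise>', '你', '好', 'o', 'k']
    print(tokenizer.tokens2text(tokens))  # <noise>你好ok

    tokenizer.seg_dict = {"hello": "he llo", "你": "你", "好": "好"}
    print(tokenizer.text2tokens("hello 你好"))  # ['he', 'llo', '你', '好']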