File size: 1,825 Bytes
568e264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from os import PathLike
from typing import Dict, List, Optional, Union
from wenet.paraformer.search import paraformer_beautify_result
from wenet.text.char_tokenizer import CharTokenizer
from wenet.text.tokenize_utils import tokenize_by_seg_dict


def read_seg_dict(path):
    """Load a tab-separated segmentation dictionary from ``path``.

    Each non-blank line must hold exactly two tab-separated fields once
    surrounding whitespace has been stripped: the key and its value.
    Blank (whitespace-only) lines are skipped. If a key occurs more than
    once, the last occurrence wins.

    Args:
        path: Location of the UTF-8 encoded dictionary file; anything
            accepted by the built-in ``open``.

    Returns:
        A ``dict`` mapping the first field of every line to the second.

    Raises:
        ValueError: If a non-blank line does not contain exactly two
            tab-separated fields. Raised explicitly (rather than via
            ``assert``) so the check still runs under ``python -O``.
    """
    seg_table = {}
    with open(path, 'r', encoding='utf8') as fin:
        for lineno, line in enumerate(fin, start=1):
            stripped = line.strip()
            if not stripped:
                # Tolerate empty lines, e.g. a trailing newline at EOF.
                continue
            arr = stripped.split('\t')
            if len(arr) != 2:
                raise ValueError(
                    f'{path}:{lineno}: expected 2 tab-separated fields, '
                    f'got {len(arr)}: {stripped!r}')
            seg_table[arr[0]] = arr[1]
    return seg_table


class ParaformerTokenizer(CharTokenizer):
    """Char tokenizer variant that splits text through a segmentation dict.

    Text-to-token conversion is delegated to ``tokenize_by_seg_dict`` and
    token-to-text conversion to ``paraformer_beautify_result``; everything
    else is inherited from ``CharTokenizer``.
    """

    def __init__(self,
                 symbol_table: Union[str, PathLike, Dict],
                 seg_dict: Optional[Union[str, PathLike, Dict]] = None,
                 split_with_space: bool = False,
                 connect_symbol: str = '',
                 unk='<unk>') -> None:
        """Set up the base tokenizer and resolve the segmentation dict.

        Args:
            symbol_table: Forwarded unchanged to ``CharTokenizer``.
            seg_dict: ``None``, an already-built dict, or a path that is
                loaded with ``read_seg_dict``.
            split_with_space: Forwarded unchanged to ``CharTokenizer``.
            connect_symbol: Forwarded unchanged to ``CharTokenizer``.
            unk: Forwarded unchanged to ``CharTokenizer``.
        """
        # The second positional argument of the base initializer is always
        # None here (presumably the non-language symbols -- confirm against
        # CharTokenizer).
        super().__init__(symbol_table, None, split_with_space, connect_symbol,
                         unk)
        if seg_dict is None or isinstance(seg_dict, Dict):
            # Nothing to load: keep None or the mapping as given.
            self.seg_dict = seg_dict
        else:
            # Anything else is treated as a path to a dictionary file.
            self.seg_dict = read_seg_dict(seg_dict)

    def text2tokens(self, line: str) -> List[str]:
        """Split ``line`` into tokens using the segmentation dict.

        Pieces that are members of ``self.non_lang_syms`` are emitted as-is;
        every other piece is expanded by ``tokenize_by_seg_dict``.

        Raises:
            AssertionError: If no segmentation dict was configured.
        """
        assert self.seg_dict is not None

        # TODO(Mddct): duplicated here, refine later
        text = line.strip()
        pattern = self.non_lang_syms_pattern
        if pattern is None:
            pieces = [text]
        else:
            # Drop pieces that are empty or whitespace-only after the split.
            pieces = [
                piece for piece in pattern.split(text)
                if len(piece.strip()) > 0
            ]

        result = []
        for piece in pieces:
            if piece in self.non_lang_syms:
                result.append(piece)
                continue
            result.extend(tokenize_by_seg_dict(self.seg_dict, piece))
        return result

    def tokens2text(self, tokens: List[str]) -> str:
        """Join ``tokens`` back into text via ``paraformer_beautify_result``."""
        text = paraformer_beautify_result(tokens)
        return text