from typing import List, Optional, Union

import pypinyin
import torch

from transformers import BertTokenizerFast


class Pinyin2(object):
    """Maps tokens to sequences of phonetic (pinyin) symbol ids.

    The 33-symbol vocabulary is: 'P' (padding), the tone digits '1'-'5',
    the letters 'a'-'z', and 'U' (unknown / non-Chinese token).
    """

    def __init__(self):
        super(Pinyin2, self).__init__()
        pho_vocab = ['P']  # padding symbol, id 0
        pho_vocab += [chr(x) for x in range(ord('1'), ord('5') + 1)]  # tone digits
        pho_vocab += [chr(x) for x in range(ord('a'), ord('z') + 1)]  # letters
        pho_vocab += ['U']  # unknown / non-Chinese
        assert len(pho_vocab) == 33
        self.pho_vocab_size = len(pho_vocab)
        self.pho_vocab = {c: idx for idx, c in enumerate(pho_vocab)}

    def get_pho_size(self):
        return self.pho_vocab_size

    @staticmethod
    def get_pinyin(c):
        """Return the pinyin of a single Chinese character with its tone digit
        moved to the front (e.g. '中' -> 'zhong1' -> '1zhong'), or 'U' for
        anything else."""
        if len(c) > 1:
            # Multi-character tokens ([CLS], [SEP], [PAD], subwords, ...) are unknown.
            return 'U'
        s = pypinyin.pinyin(
            c,
            style=pypinyin.Style.TONE3,
            neutral_tone_with_five=True,
            errors=lambda x: ['U' for _ in x],  # non-Chinese characters map to 'U'
        )[0][0]
        if s == 'U':
            return s
        assert isinstance(s, str)
        assert s[-1] in '12345'
        s = s[-1] + s[:-1]  # move the trailing tone digit to the front
        return s

    def convert(self, chars):
        """Convert a list of tokens to a (len(chars), max_pinyin_len) LongTensor
        of pinyin symbol ids, right-padded with 'P', plus the unpadded lengths."""
        pinyins = list(map(self.get_pinyin, chars))
        pinyin_ids = [list(map(self.pho_vocab.get, pinyin)) for pinyin in pinyins]
        pinyin_lens = [len(pinyin) for pinyin in pinyins]
        pinyin_ids = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(x) for x in pinyin_ids],
            batch_first=True,
            padding_value=0,  # 0 is the 'P' padding symbol
        )
        return pinyin_ids, pinyin_lens
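
# Usage sketch for Pinyin2 (illustrative; the values assume pypinyin's
# TONE3 output for the characters shown):
#
#   >>> pho2 = Pinyin2()
#   >>> pho2.get_pinyin('中')       # 'zhong1' with the tone digit moved up front
#   '1zhong'
#   >>> ids, lens = pho2.convert(['中', '[CLS]'])
#   >>> lens                        # non-Chinese tokens collapse to the single symbol 'U'
#   [6, 1]
#   >>> ids.shape                   # right-padded with 'P' (id 0)
#   torch.Size([2, 6])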


class ReaLiSeTokenizer(BertTokenizerFast):
    """BertTokenizerFast that additionally returns per-token pinyin ids
    ('pho_idx') and their lengths ('pho_lens'), as used by ReaLiSe."""

    def __init__(self, **kwargs):
        super(ReaLiSeTokenizer, self).__init__(**kwargs)

        self.pho2_convertor = Pinyin2()

    def __call__(self,
                 text: Union[str, List[str], List[List[str]]] = None,
                 text_pair: Optional[Union[str, List[str], List[List[str]]]] = None,
                 text_target: Union[str, List[str], List[List[str]]] = None,
                 text_pair_target: Optional[Union[str, List[str], List[List[str]]]] = None,
                 add_special_tokens: bool = True,
                 padding=False,
                 truncation=None,
                 max_length: Optional[int] = None,
                 stride: int = 0,
                 is_split_into_words: bool = False,
                 pad_to_multiple_of: Optional[int] = None,
                 return_tensors=None,
                 return_token_type_ids: Optional[bool] = None,
                 return_attention_mask: Optional[bool] = None,
                 return_overflowing_tokens: bool = False,
                 return_special_tokens_mask: bool = False,
                 return_offsets_mapping: bool = False,
                 return_length: bool = False,
                 verbose: bool = True, **kwargs):
        encoding = super(ReaLiSeTokenizer, self).__call__(
            text=text,
            text_pair=text_pair,
            text_target=text_target,
            text_pair_target=text_pair_target,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )

        input_ids = encoding['input_ids']
        if isinstance(text, str) and return_tensors is None:
            # A single un-batched input: wrap it so the loop below sees a batch.
            input_ids = [input_ids]

        pho_idx_list = []
        pho_lens_list = []
        for ids in input_ids:
            # Recover token strings so each token can be mapped to its pinyin.
            chars = self.convert_ids_to_tokens(ids)
            pho_idx, pho_lens = self.pho2_convertor.convert(chars)
            if return_tensors is None:
                pho_idx = pho_idx.tolist()
            pho_idx_list.append(pho_idx)
            pho_lens_list += pho_lens

        pho_idx = pho_idx_list
        pho_lens = pho_lens_list
        if return_tensors == 'pt':
            # Sequences can differ in their longest pinyin, so right-pad each
            # per-sequence tensor to a common width before stacking; plain
            # torch.vstack would fail on mismatched widths.
            max_width = max(t.size(1) for t in pho_idx)
            pho_idx = torch.vstack([
                torch.nn.functional.pad(t, (0, max_width - t.size(1)))
                for t in pho_idx
            ])
            pho_lens = torch.LongTensor(pho_lens)

        if isinstance(text, str) and return_tensors is None:
            pho_idx = pho_idx[0]

        encoding['pho_idx'] = pho_idx
        encoding['pho_lens'] = pho_lens

        return encoding
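

if __name__ == '__main__':
    # Minimal usage sketch. 'bert-base-chinese' is an assumed checkpoint name;
    # any BERT-style vocabulary loadable via from_pretrained should work.
    tokenizer = ReaLiSeTokenizer.from_pretrained('bert-base-chinese')
    batch = tokenizer(['这是一个例子', '另一句话'], padding=True, return_tensors='pt')
    print(batch['input_ids'].shape)  # (batch, max_seq_len)
    print(batch['pho_idx'].shape)    # (batch * max_seq_len, max_pinyin_len)
    print(batch['pho_lens'].shape)   # (batch * max_seq_len,)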