File size: 3,499 Bytes
50e5a6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104

import torch
import sentencepiece
import jieba
import numpy as np

from transformers.tokenization_utils import PreTrainedTokenizer


class GPTPanguTokenizer(PreTrainedTokenizer):
    """Tokenizer that segments text with jieba and encodes it with SentencePiece.

    Spaces and newlines are mapped to the placeholder characters U+2582 and
    U+2583 before encoding (see ``self.translator``) and mapped back in
    :meth:`decode`.

    Ref: https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha/src/branch/master/tokenization_jieba.py
    """

    vocab_files_names = {
        "model_file": "vocab.model"
    }

    def __init__(
            self,
            model_file,
            **kwargs
    ):
        """
        Load the SentencePiece model and initialise the base tokenizer.

        Args:
            model_file (`str`):
                Path to the SentencePiece model file.
            **kwargs:
                Forwarded unchanged to `PreTrainedTokenizer.__init__` (this is where special tokens such as
                bos/eos/sep are expected to come from).
        """
        # The SentencePiece model must exist before the base constructor runs: newer
        # transformers releases query the vocabulary (get_vocab / vocab_size) from
        # inside PreTrainedTokenizer.__init__, which would otherwise fail on the
        # missing `self.sp` attribute.
        self.sp = sentencepiece.SentencePieceProcessor()
        self.sp.Load(model_file=model_file)
        # Maps " " -> U+2582 and "\n" -> U+2583; decode() applies the inverse mapping.
        self.translator = str.maketrans(" \n", "\u2582\u2583")

        super().__init__(**kwargs)

    def get_vocab(self):
        """
        Return the vocabulary as a dict mapping token string to id.

        Returns:
            `Dict[str, int]`: Every SentencePiece piece with its id, updated with the added tokens.
        """
        vocab = {self.sp.id_to_piece(index): index for index in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences by concatenating and adding special tokens.
        The resulting layout is:

        - single sequence: `[bos] X eos`
        - pair of sequences: `[bos] A sep B eos`

        where `bos` is only prepended when `self.bos_token_id` is not `None`.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        prefix = [self.bos_token_id] if self.bos_token_id is not None else []
        eos = [self.eos_token_id]

        if token_ids_1 is None:
            return prefix + token_ids_0 + eos

        sep = [self.sep_token_id]
        return prefix + token_ids_0 + sep + token_ids_1 + eos

    def tokenize(self, text, **kwargs):
        """
        Tokenize a string.

        The text is cut into segments with `jieba.cut(text, cut_all=False)`; inside every segment spaces and
        newlines are replaced by their placeholder characters.

        Args:
            text (`str`):
                The text to segment.

        Returns:
            `List[str]`: The jieba segments with whitespace placeholders applied.
        """
        return [segment.translate(self.translator) for segment in jieba.cut(text, cut_all=False)]

    def convert_tokens_to_ids(self, tokens):
        """
        Convert a token or a list of tokens (as produced by `tokenize`) to ids.

        A single string is looked up directly. A list is joined with spaces and encoded as one string by
        SentencePiece, so the number of returned ids is presumably not tied to `len(tokens)`.

        Args:
            tokens (`str` or `List[str]`, *optional*):
                Token(s) to convert.

        Returns:
            `int`, `List[int]` or `None`: `None` when `tokens` is `None`, a single id for a string, otherwise the
            ids produced by SentencePiece for the joined tokens.
        """
        if tokens is None:
            return None

        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)

        return self.sp.encode(" ".join(tokens))

    def _convert_token_to_id(self, token):
        """Return the SentencePiece id of the piece `token`."""
        return self.sp.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Return the SentencePiece piece stored under the id `index`."""
        return self.sp.id_to_piece(index)

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """
        Convert ids back to text.

        NOTE: this delegates to `decode`, so it returns the decoded string rather than a list of token strings.

        Args:
            ids:
                Ids in any form accepted by `decode`.
            skip_special_tokens (`bool`, *optional*, defaults to `False`):
                Forwarded to `decode`.

        Returns:
            `str`: The decoded text.
        """
        return self.decode(ids, skip_special_tokens=skip_special_tokens)

    def decode(self, tokens, **kwargs):
        """
        Decode ids into text and restore the original whitespace.

        Args:
            tokens (`int`, `List[int]`, `torch.Tensor` or `np.ndarray`):
                The ids to decode. Tensors and arrays are converted with `tolist()`.
            **kwargs:
                Only `skip_special_tokens` is read; when it is `True`, ids contained in `self.all_special_ids`
                are dropped before decoding. All other keyword arguments are ignored.

        Returns:
            `str`: The decoded text. If SentencePiece returns a list of strings, only its first entry is used.
        """
        if isinstance(tokens, (torch.Tensor, np.ndarray)):
            tokens = tokens.tolist()

        # A lone id (e.g. from a 0-d tensor) is wrapped so the special-token filter can iterate over it.
        if isinstance(tokens, int):
            tokens = [tokens]

        if kwargs.get('skip_special_tokens', None) is True:
            # Read the property once instead of once per token. It is kept as a list
            # (not a set) so that nested, unhashable entries can still be compared.
            special_ids = self.all_special_ids
            tokens = [token for token in tokens if token not in special_ids]

        text = self.sp.decode(tokens)
        if isinstance(text, list):
            text = text[0]

        # Drop the plain spaces from the SentencePiece output, then map the placeholders back to whitespace.
        text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
        return text

    @property
    def vocab_size(self) -> int:
        """
        `int`: Size of the base vocabulary (without the added tokens).
        """
        return len(self.sp)