Update tokenization_chatglm.py
tokenization_chatglm.py  +30 -5
@@ -1,11 +1,13 @@
+import json
 import os
-import torch
+import re
 from typing import List, Optional, Union, Dict
 from sentencepiece import SentencePieceProcessor
 from transformers import PreTrainedTokenizer
 from transformers.utils import logging, PaddingStrategy
 from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
 
+
 class SPTokenizer:
     def __init__(self, model_path: str):
         # reload tokenizer
@@ -30,6 +32,7 @@ class SPTokenizer:
     def tokenize(self, s: str):
         return self.sp_model.EncodeAsPieces(s)
 
+
     def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
         assert type(s) is str
         t = self.sp_model.encode(s)
@@ -40,7 +43,18 @@ class SPTokenizer:
         return t
 
     def decode(self, t: List[int]) -> str:
-        return self.sp_model.decode(t)
+        text, buffer = "", []
+        for token in t:
+            if token in self.index_special_tokens:
+                if buffer:
+                    text += self.sp_model.decode(buffer)
+                    buffer = []
+                text += self.index_special_tokens[token]
+            else:
+                buffer.append(token)
+        if buffer:
+            text += self.sp_model.decode(buffer)
+        return text
 
     def decode_tokens(self, tokens: List[str]) -> str:
         text = self.sp_model.DecodePieces(tokens)
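Note on the decode() change: instead of handing every id straight to sp_model.decode, the new version buffers ordinary ids, flushes the buffer whenever a special-token id shows up, and splices that token's surface string into the output. A minimal sketch of the same control flow with stand-in names (decode_with_specials, the fake sp_decode callable and the id 7 are illustrative, not the real ChatGLM vocabulary):

from typing import Callable, Dict, List

def decode_with_specials(ids: List[int],
                         sp_decode: Callable[[List[int]], str],
                         index_special_tokens: Dict[int, str]) -> str:
    # Buffer ordinary ids; flush them through the SentencePiece decoder whenever
    # a special id appears, then splice in that token's string verbatim.
    text, buffer = "", []
    for token in ids:
        if token in index_special_tokens:
            if buffer:
                text += sp_decode(buffer)
                buffer = []
            text += index_special_tokens[token]
        else:
            buffer.append(token)
    if buffer:
        text += sp_decode(buffer)
    return text

# Stand-in decoder that just brackets ids, so the splicing is visible.
print(decode_with_specials([1, 2, 7, 3],
                           sp_decode=lambda b: "".join(f"[{i}]" for i in b),
                           index_special_tokens={7: "<|user|>"}))
# -> [1][2]<|user|>[3]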
@@ -54,7 +68,9 @@ class SPTokenizer:
 
     def convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
-        if index in self.index_special_tokens or index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
+        if index in self.index_special_tokens:
+            return self.index_special_tokens[index]
+        if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0 or index > self.sp_model.vocab_size():
             return ""
         return self.sp_model.IdToPiece(index)
 
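Note on convert_id_to_token(): an id registered in index_special_tokens now round-trips to its own string instead of the empty string, and ids above sp_model.vocab_size() are treated as out of range. A small sketch of the new lookup order; the table, the vocab size and the piece names are made up for illustration, and the real method additionally screens the eos/bos/pad ids:

# Illustrative lookup order only; ids and pieces are not the real vocabulary.
index_special_tokens = {65000: "<|user|>"}

def convert_id_to_token(index: int, vocab_size: int = 64999) -> str:
    if index in index_special_tokens:
        return index_special_tokens[index]   # special ids keep their string
    if index < 0 or index > vocab_size:
        return ""                            # out-of-range ids decode to nothing
    return f"piece_{index}"                  # stand-in for sp_model.IdToPiece

print(convert_id_to_token(65000))  # <|user|>
print(convert_id_to_token(12))     # piece_12
print(convert_id_to_token(-1))     # (empty string)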
@@ -64,8 +80,8 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
 
     model_input_names = ["input_ids", "attention_mask", "position_ids"]
 
-    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
-        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
+    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
+                 **kwargs):
         self.name = "GLMTokenizer"
 
         self.vocab_file = vocab_file
@@ -75,6 +91,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             "<eos>": self.tokenizer.eos_id,
             "<pad>": self.tokenizer.pad_id
         }
+        self.encode_special_tokens = encode_special_tokens
+        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+                         encode_special_tokens=encode_special_tokens,
+                         **kwargs)
 
     def get_command(self, token):
         if token in self.special_tokens:
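Note on the constructor change: super().__init__ is now called only after the SentencePiece model and the special-token table are in place, and the new encode_special_tokens flag is stored and forwarded. A plausible motivation, though not stated in the commit, is that newer transformers base constructors touch properties such as unk_token and pad_token, which here need self.tokenizer to exist already. A hypothetical usage sketch, assuming a local directory that holds this tokenization_chatglm.py together with its tokenizer.model (the path is illustrative):

from transformers import AutoTokenizer

# "./chatglm-local" is an assumed local checkout, not an official model id.
tok = AutoTokenizer.from_pretrained(
    "./chatglm-local",
    trust_remote_code=True,
    encode_special_tokens=True,   # the flag introduced by this change
)
print(tok.unk_token, tok.pad_token)  # both resolve to "<unk>"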
@@ -82,6 +102,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
         return self.tokenizer.special_tokens[token]
 
+    @property
+    def unk_token(self) -> str:
+        return "<unk>"
+
     @property
     def pad_token(self) -> str:
         return "<unk>"
@@ -163,6 +187,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
         return prompt
 
+
     def build_inputs_with_special_tokens(
             self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]: