KaleiNeely commited on
Commit
2d52c88
·
1 Parent(s): 9fe2c59

Upload 2 files

Browse files
rwkv_vocab_v20230424.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenization_rwkv_world.py CHANGED
@@ -52,186 +52,52 @@ if TYPE_CHECKING:
52
  logger = logging.get_logger(__name__)
53
 
54
  VOCAB_FILES_NAMES = {
55
- "vocab_file": "rwkv_vocab_v20230424.json",
56
  }
57
 
58
-
59
- class DATrie:
60
- class Node:
61
- def __init__(self, is_leaf=False, leaf_data=None, tail=""):
62
- self._is_leaf = is_leaf
63
- self._leaf_data = leaf_data
64
- self._tail = tail
65
- self._next_map = {}
66
-
67
- def is_leaf(self):
68
- return self._is_leaf
69
-
70
- def set_leaf(self):
71
- self._is_leaf = True
72
-
73
- def has_next(self, w):
74
- if w in self._next_map:
75
- return True
76
- return False
77
-
78
- def add_node(self, w, node):
79
- self._next_map[w] = node
80
-
81
- def get_node(self, w):
82
- if w in self._next_map:
83
- return self._next_map[w]
84
- return None
85
-
86
- def get_tail(self):
87
- return self._tail
88
-
89
- def get_data(self):
90
- return self._leaf_data
91
-
92
- def set_data(self, data):
93
- self._leaf_data = data
94
-
95
- def __init__(self, special_ids):
96
- self.root = self.Node()
97
- self.data = {}
98
- self.r_data = {}
99
- self.special_ids = special_ids
100
-
101
- def insert(self, word, data):
102
- self.data[word] = data
103
- self.r_data[data] = word
104
- idx = 0
105
- node = self.root
106
- while idx < len(word):
107
- w = word[idx]
108
- is_leaf = (idx == (len(word) - 1))
109
- leaf_data = (data if is_leaf else None)
110
- # 不存在则插入
111
- if not node.has_next(w):
112
- node.add_node(w, self.Node(is_leaf=is_leaf, leaf_data=leaf_data))
113
- # last word
114
- node = node.get_node(w)
115
- idx += 1
116
- if not node.is_leaf():
117
- node.set_leaf()
118
- node.set_data(data)
119
-
120
- def findStrict(self, word):
121
- idx = 0
122
- node = self.root
123
- while node is not None and idx < len(word):
124
- w = word[idx]
125
- if not node.has_next(w):
126
- return None
127
- # last word
128
- node = node.get_node(w)
129
- idx += 1
130
- if node.is_leaf():
131
- return node.get_data()
132
- return None
133
-
134
- def prefix(self, word):
135
- idx = 0
136
- node = self.root
137
- result = []
138
- while node is not None and idx < len(word):
139
- w = word[idx]
140
- if not node.has_next(w):
141
- return result
142
- # last word
143
- node = node.get_node(w)
144
- if node.is_leaf():
145
- result.append([word[:idx + 1], node.get_data()])
146
- idx += 1
147
- return result
148
-
149
- def max_prefix(self, content, start_idx):
150
- idx = start_idx
151
- node = self.root
152
- l = len(content)
153
- result = [["", ], ]
154
- while node is not None and idx < l:
155
- w = content[idx]
156
- if not node.has_next(w):
157
- return result[-1]
158
- # last word
159
- node = node.get_node(w)
160
- if node.is_leaf():
161
- result.append([content[start_idx:idx + 1], node.get_data()])
162
  idx += 1
163
- return result[-1]
164
-
165
- def max_score(self, content, start_idx):
166
- idx = start_idx
167
- node = self.root
168
- l = len(content)
169
- result = [["", (3, 0)], ]
170
- while node is not None and idx < l:
171
- w = content[idx]
172
- if not node.has_next(w):
173
- break
174
- # last word
175
- node = node.get_node(w)
176
- if node.is_leaf():
177
- result.append([content[start_idx:idx + 1], node.get_data()])
178
- idx += 1
179
- if len(result) > 1:
180
- result = sorted(result, key=lambda x: x[1][1])
181
- return result[-1]
182
-
183
- def match(self, content, add_unk=True, unk_id=-1, **kwargs):
184
- # length
185
- l = len(content)
186
- i = 0
187
- result_list = []
188
- while i < l:
189
- match_word = self.max_prefix(content=content, start_idx=i)
190
- # print(match_word)
191
- w = match_word[0]
192
- if len(w) > 0:
193
- result_list.append(match_word[1])
194
- i += len(w)
195
- else:
196
- if add_unk:
197
- result_list.append(unk_id)
198
- i += 1
199
- return result_list
200
-
201
- def id2str(self, ids, escape_special_ids=True, end_ids=[], **kwargs):
202
- res_str = ""
203
- for rid in ids:
204
- if rid in self.r_data:
205
- if rid in end_ids:
206
- break
207
- if escape_special_ids and rid in self.special_ids:
208
- continue
209
- rstr = self.r_data[rid]
210
- res_str += rstr
211
- elif rid == 0:
212
  break
213
- else:
214
- print("ERROR unknown id %d" % rid)
215
- res_str += "UNK"
216
- return res_str
217
-
218
- def id2str_v2(self, ids, escape_special_ids=True, end_ids=[], **kwargs):
219
- res_str = ""
220
- for rid in ids:
221
- if rid in self.r_data:
222
- if rid in end_ids:
223
- break
224
- rstr = self.r_data[rid]
225
- if escape_special_ids and rid in self.special_ids:
226
- continue
227
- res_str += rstr
228
- elif rid == 0:
229
- break
230
- else:
231
- print("ERROR unknown id %d" % rid)
232
- res_str += "UNK"
233
- return res_str
234
-
235
 
236
  class RWKVWorldTokenizer(PreTrainedTokenizer):
237
  vocab_files_names = VOCAB_FILES_NAMES
@@ -244,17 +110,30 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
244
  **kwargs
245
  ):
246
  self.add_bos_token = False
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
- with open(vocab_file, encoding="utf-8") as vocab_handle:
249
- self.encoder = json.load(vocab_handle)
250
  super().__init__(
251
  errors=errors,
252
  **kwargs,
253
  )
254
- self.decoder = {v: k for k, v in self.encoder.items()}
255
- self.trie = DATrie(self.all_special_ids)
256
- for k, v in self.encoder.items():
257
- self.trie.insert(k, v)
 
 
 
258
  self.errors = errors # how to handle errors in decoding
259
  self.cache = {}
260
 
@@ -311,9 +190,23 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
311
  return [1] + ([0] * len(token_ids_0))
312
  return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  def _tokenize(self, text, **kwargs):
315
  """Tokenize a string."""
316
- return self.trie.match(text, unk_id=self.unk_token_id, **kwargs)
317
 
318
  def _decode(self,
319
  token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
@@ -326,13 +219,9 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
326
  if isinstance(token_ids, int):
327
  if token_ids in self.all_special_ids and skip_special_tokens:
328
  return ""
329
- return self.decoder.get(token_ids, self.unk_token)
330
  elif isinstance(token_ids, list):
331
- return self.trie.id2str(
332
- token_ids,
333
- escape_special_ids=skip_special_tokens,
334
- **kwargs
335
- )
336
  else:
337
  return token_ids
338
 
@@ -383,10 +272,10 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
383
  ) -> BatchEncoding:
384
  def get_input_ids(text):
385
  if isinstance(text, str):
386
- text_id = self.trie.match(text, unk_id=self.unk_token_id)
387
  return text_id
388
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
389
- return [self.trie.match(t, unk_id=self.unk_token_id) for t in text]
390
  elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
391
  return text
392
  else:
@@ -448,10 +337,10 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
448
  ) -> BatchEncoding:
449
  def get_input_ids(text):
450
  if isinstance(text, str):
451
- text_id = self.trie.match(text, unk_id=self.unk_token_id)
452
  return text_id
453
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
454
- return [self.trie.match(t, unk_id=self.unk_token_id) for t in text]
455
  elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
456
  return text
457
  else:
 
52
  logger = logging.get_logger(__name__)
53
 
54
  VOCAB_FILES_NAMES = {
55
+ "vocab_file": "rwkv_vocab_v20230424.txt",
56
  }
57
 
58
+ class TRIE:
59
+ __slots__ = tuple("ch,to,values,front".split(","))
60
+ to:list
61
+ values:set
62
+ def __init__(self, front=None, ch=None):
63
+ self.ch = ch
64
+ self.to = [None for ch in range(256)]
65
+ self.values = set()
66
+ self.front = front
67
+
68
+ def __repr__(self):
69
+ fr = self
70
+ ret = []
71
+ while(fr!=None):
72
+ if(fr.ch!=None):
73
+ ret.append(fr.ch)
74
+ fr = fr.front
75
+ return "<TRIE %s %s>"%(ret[::-1], self.values)
76
+
77
+ def add(self, key:bytes, idx:int=0, val=None):
78
+ if(idx == len(key)):
79
+ if(val is None):
80
+ val = key
81
+ self.values.add(val)
82
+ return self
83
+ ch = key[idx]
84
+ if(self.to[ch] is None):
85
+ self.to[ch] = TRIE(front=self, ch=ch)
86
+ return self.to[ch].add(key, idx=idx+1, val=val)
87
+
88
+ def find_longest(self, key:bytes, idx:int=0):
89
+ u:TRIE = self
90
+ ch:int = key[idx]
91
+
92
+ while(u.to[ch] is not None):
93
+ u = u.to[ch]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  idx += 1
95
+ if(u.values):
96
+ ret = idx, u, u.values
97
+ if(idx==len(key)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  break
99
+ ch = key[idx]
100
+ return ret
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  class RWKVWorldTokenizer(PreTrainedTokenizer):
103
  vocab_files_names = VOCAB_FILES_NAMES
 
110
  **kwargs
111
  ):
112
  self.add_bos_token = False
113
+ self.encoder = {}
114
+ sorted = [] # must be already sorted
115
+ with open(vocab_file, "r", encoding="utf-8") as f:
116
+ lines = f.readlines()
117
+ for l in lines:
118
+ idx = int(l[:l.index(' ')])
119
+ x = eval(l[l.index(' '):l.rindex(' ')])
120
+ x = x.encode("utf-8") if isinstance(x, str) else x
121
+ assert isinstance(x, bytes)
122
+ assert len(x) == int(l[l.rindex(' '):])
123
+ sorted += [x]
124
+ self.encoder[idx] = x
125
 
 
 
126
  super().__init__(
127
  errors=errors,
128
  **kwargs,
129
  )
130
+ self.decoder = {}
131
+ for k,v in self.encoder.items():
132
+ self.decoder[v] = int(k)
133
+
134
+ self.trie = TRIE()
135
+ for t, i in self.decoder.items():
136
+ _ = self.trie.add(t, val=(t, i))
137
  self.errors = errors # how to handle errors in decoding
138
  self.cache = {}
139
 
 
190
  return [1] + ([0] * len(token_ids_0))
191
  return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
192
 
193
+ def encodeBytes(self, src:bytes):
194
+ idx:int = 0
195
+ tokens = []
196
+ while (idx < len(src)):
197
+ _idx:int = idx
198
+ idx, _, values = self.trie.find_longest(src, idx)
199
+ assert(idx != _idx)
200
+ _, token = next(iter(values))
201
+ tokens.append(token)
202
+ return tokens
203
+
204
+ def decodeBytes(self, tokens):
205
+ return b''.join(map(lambda i: self.encoder[i], tokens))
206
+
207
  def _tokenize(self, text, **kwargs):
208
  """Tokenize a string."""
209
+ return self.encodeBytes(text.encode("utf-8"))
210
 
211
  def _decode(self,
212
  token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
 
219
  if isinstance(token_ids, int):
220
  if token_ids in self.all_special_ids and skip_special_tokens:
221
  return ""
222
+ return self.encoder.get(token_ids, self.unk_token)
223
  elif isinstance(token_ids, list):
224
+ return self.decodeBytes(token_ids).decode('utf-8')
 
 
 
 
225
  else:
226
  return token_ids
227
 
 
272
  ) -> BatchEncoding:
273
  def get_input_ids(text):
274
  if isinstance(text, str):
275
+ text_id = self._tokenize(text)
276
  return text_id
277
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
278
+ return [self._tokenize(t) for t in text]
279
  elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
280
  return text
281
  else:
 
337
  ) -> BatchEncoding:
338
  def get_input_ids(text):
339
  if isinstance(text, str):
340
+ text_id = self._tokenize(text)
341
  return text_id
342
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
343
+ return [self._tokenize(t) for t in text]
344
  elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
345
  return text
346
  else: