Curt-Park committed on
Commit a6635ad · 1 Parent(s): 0606cc6

Fix postprocess with autotokenizer

model_repository/postprocessing/1/gpt2-merges.txt DELETED
The diff for this file is too large to render. See raw diff
 
model_repository/postprocessing/1/gpt2-vocab.json DELETED
The diff for this file is too large to render. See raw diff
 
model_repository/postprocessing/1/model.py CHANGED
@@ -5,15 +5,7 @@ from typing import Any, Dict, List
 
 import numpy as np
 import triton_python_backend_utils as pb_utils
-import utils.gpt_token_encoder as encoder
-
-# GPT3 Related variables
-# Reference:
-# https://github.com/NVIDIA/FasterTransformer/blob/main/sample/pytorch/gpt_sample.py
-MERGES_FILE = "gpt2-merges.txt"
-VOCAB_FILE = "gpt2-vocab.json"
-
-MAX_BATCH_SIZE = 8
+from transformers import AutoTokenizer
 
 
 class TritonPythonModel:
@@ -24,8 +16,6 @@ class TritonPythonModel:
 
         Implementing `initialize` function is optional. This function allows
         the model to initialize any state associated with this model.
-        Parameters
-
         Args:
             Both keys and values are strings. The dictionary keys and values are:
             * model_config: A JSON string containing the model configuration
@@ -44,6 +34,13 @@ class TritonPythonModel:
         # Convert Triton types to numpy types
         self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
 
+        # Init a tokenizer for postprocessing.
+        cur_folder = Path(__file__).parent
+        cache_dir = cur_folder / ".cache"
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            "Salesforce/codegen-350M-mono", cache_dir=cache_dir
+        )
+
     def execute(
         self, requests: List["pb_utils.InferenceRequest"]
     ) -> List["pb_utils.InferenceResponse"]:
@@ -115,14 +112,8 @@ class TritonPythonModel:
 
     def _postprocessing(self, tokens_batch: np.ndarray) -> List[bytes]:
         """Postprocess."""
-        cur_folder = Path(__file__).parent
-        enc = encoder.get_encoder(
-            str(cur_folder / VOCAB_FILE), str(cur_folder / MERGES_FILE)
-        )
-
        outputs = []
        for beam_tokens in tokens_batch:
            for tokens in beam_tokens:
-                output = enc.decode(tokens)
-                outputs.append(output.encode("utf8"))
+                outputs.append(self.tokenizer.decode(tokens))
        return outputs
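
Taken together, the change loads the tokenizer once in initialize() and reuses it for every request instead of rebuilding a GPT-2 BPE encoder on each call. A minimal standalone sketch of the resulting decode path, outside the Triton backend (the class name and the commented shapes are illustrative, not part of the commit):

from pathlib import Path
from typing import List

import numpy as np
from transformers import AutoTokenizer


class PostprocessorSketch:
    """Illustrative stand-in for the postprocessing TritonPythonModel."""

    def __init__(self) -> None:
        # Cache tokenizer files next to this module, mirroring the committed initialize().
        cache_dir = Path(__file__).parent / ".cache"
        self.tokenizer = AutoTokenizer.from_pretrained(
            "Salesforce/codegen-350M-mono", cache_dir=cache_dir
        )

    def postprocess(self, tokens_batch: np.ndarray) -> List[str]:
        # tokens_batch holds token ids with shape [batch, beam, seq_len].
        outputs = []
        for beam_tokens in tokens_batch:
            for tokens in beam_tokens:
                outputs.append(self.tokenizer.decode(tokens))
        return outputs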
model_repository/postprocessing/1/utils/__init__.py DELETED
@@ -1,13 +0,0 @@
-# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
model_repository/postprocessing/1/utils/gpt_token_encoder.py DELETED
@@ -1,172 +0,0 @@
-"""Byte pair encoding utilities"""
-
-# Modified MIT License
-
-# Software Copyright (c) 2019 OpenAI
-
-# We don’t claim ownership of the content you create with GPT-2, so it is yours to do with as you please.
-# We only ask that you use GPT-2 responsibly and clearly indicate your content was created using GPT-2.
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
-# associated documentation files (the "Software"), to deal in the Software without restriction,
-# including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
-# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
-# subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included
-# in all copies or substantial portions of the Software.
-# The above copyright notice and this permission notice need not be included
-# with content created by the Software.
-
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
-# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
-# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
-# OR OTHER DEALINGS IN THE SOFTWARE.
-# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import os
-from functools import lru_cache
-
-import regex as re
-
-
-@lru_cache()
-def bytes_to_unicode():
-    """
-    Returns list of utf-8 byte and a corresponding list of unicode strings.
-    The reversible bpe codes work on unicode strings.
-    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
-    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-    This is a signficant percentage of your normal, say, 32K bpe vocab.
-    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
-    And avoids mapping to whitespace/control characters the bpe code barfs on.
-    """
-    bs = (
-        list(range(ord("!"), ord("~") + 1))
-        + list(range(ord("¡"), ord("¬") + 1))
-        + list(range(ord("®"), ord("ÿ") + 1))
-    )
-    cs = bs[:]
-    n = 0
-    for b in range(2**8):
-        if b not in bs:
-            bs.append(b)
-            cs.append(2**8 + n)
-            n += 1
-    cs = [chr(n) for n in cs]
-    return dict(zip(bs, cs))
-
-
-def get_pairs(word):
-    """Return set of symbol pairs in a word.
-
-    Word is represented as tuple of symbols (symbols being variable-length strings).
-    """
-    pairs = set()
-    prev_char = word[0]
-    for char in word[1:]:
-        pairs.add((prev_char, char))
-        prev_char = char
-    return pairs
-
-
-class Encoder:
-    def __init__(self, encoder, bpe_merges, errors="replace"):
-        self.encoder = encoder
-        self.decoder = {v: k for k, v in self.encoder.items()}
-        self.errors = errors  # how to handle errors in decoding
-        self.byte_encoder = bytes_to_unicode()
-        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
-        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
-        self.cache = {}
-
-        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
-        self.pat = re.compile(
-            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
-        )
-
-    def bpe(self, token):
-        if token in self.cache:
-            return self.cache[token]
-        word = tuple(token)
-        pairs = get_pairs(word)
-
-        if not pairs:
-            return token
-
-        while True:
-            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
-            if bigram not in self.bpe_ranks:
-                break
-            first, second = bigram
-            new_word = []
-            i = 0
-            while i < len(word):
-                try:
-                    j = word.index(first, i)
-                    new_word.extend(word[i:j])
-                    i = j
-                except:
-                    new_word.extend(word[i:])
-                    break
-
-                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
-                    new_word.append(first + second)
-                    i += 2
-                else:
-                    new_word.append(word[i])
-                    i += 1
-            new_word = tuple(new_word)
-            word = new_word
-            if len(word) == 1:
-                break
-            else:
-                pairs = get_pairs(word)
-        word = " ".join(word)
-        self.cache[token] = word
-        return word
-
-    def encode(self, text):
-        bpe_tokens = []
-        for token in re.findall(self.pat, text):
-            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
-            bpe_tokens.extend(
-                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
-            )
-        return bpe_tokens
-
-    def decode(self, tokens):
-        text = "".join(
-            [self.decoder[min(token, 50256)] for token in tokens]
-        )
-        text = bytearray([self.byte_decoder[c] for c in text]).decode(
-            "utf-8", errors=self.errors
-        )
-        return text
-
-
-def get_encoder(vocab_file, bpe_file):
-    with open(vocab_file, "r") as f:
-        encoder = json.load(f)
-    with open(bpe_file, "r", encoding="utf-8") as f:
-        bpe_data = f.read()
-    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
-    return Encoder(
-        encoder=encoder,
-        bpe_merges=bpe_merges,
-    )
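
The module removed above reimplemented GPT-2's byte-level BPE solely so the server could turn output ids back into text from local gpt2-vocab.json/gpt2-merges.txt files. As a hedged illustration (assuming transformers is installed and the Salesforce/codegen-350M-mono checkpoint is reachable; the sample string is illustrative), the pretrained tokenizer now covers the same encode/decode step without shipping vocab files in the model repository:

from transformers import AutoTokenizer

# Sketch: AutoTokenizer stands in for get_encoder(vocab_file, bpe_file) + Encoder.decode().
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")

ids = tokenizer.encode("def add(a, b):\n    return a + b")  # byte-level BPE token ids
print(ids)
print(tokenizer.decode(ids))  # recovers the source string for this example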
model_repository/preprocessing/1/model.py CHANGED
@@ -9,12 +9,8 @@ import torch
 import triton_python_backend_utils as pb_utils
 from torch.nn.utils.rnn import pad_sequence
 from transformers import AutoTokenizer
-from word_list import to_word_list_format
 
-
-START_ID = 50256
 END_ID = 50256
-MAX_BATCH_SIZE = 8
 
 
 class TritonPythonModel:
@@ -102,8 +98,8 @@ class TritonPythonModel:
 
         # Preprocessing input data.
         input_id, request_input_len = self._create_request(query)
-        bad_words = to_word_list_format(bad_words_dict)
-        stop_words = to_word_list_format(stop_words_dict)
+        bad_words = self._create_word_list(bad_words_dict)
+        stop_words = self._create_word_list(stop_words_dict)
 
         # Create output tensors. You need pb_utils.Tensor
         # objects to create pb_utils.InferenceResponse.
@@ -165,7 +161,8 @@ class TritonPythonModel:
 
         return start_ids, start_lengths
 
-    def _create_word_list(self, word_dict: Dict[str, Any]) -> np.ndarray:
+    def _create_word_list(self, word_dict: np.ndarray) -> np.ndarray:
+        """Encode the word list."""
         flat_ids = []
         offsets = []
         for word_dict_item in word_dict:
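
The hunk above shows only the head of the new _create_word_list; its body presumably follows the to_word_list_format helper deleted later in this commit. A hedged reconstruction is sketched below as a free function with an explicit tokenizer argument (both the function name and that argument are illustrative; the committed method would use the tokenizer already built in initialize()):

import csv

import numpy as np


def create_word_list(tokenizer, word_dict):
    """Sketch of the bad/stop word-list encoding, adapted from the removed to_word_list_format."""
    flat_ids = []
    offsets = []
    for word_dict_item in word_dict:
        item_flat_ids = []
        item_offsets = []

        # Each request carries its word list as one CSV-formatted byte string.
        if isinstance(word_dict_item[0], bytes):
            word_dict_item = [word_dict_item[0].decode()]

        words = list(csv.reader(word_dict_item))[0]
        for word in words:
            ids = tokenizer.encode(word)
            if len(ids) == 0:
                continue
            item_flat_ids += ids
            item_offsets.append(len(ids))

        flat_ids.append(np.array(item_flat_ids))
        offsets.append(np.cumsum(np.array(item_offsets)))

    # Pad every row to the longest id list; unused offset slots are marked with -1.
    pad_to = max(1, max(len(ids) for ids in flat_ids))
    for i, (ids, offs) in enumerate(zip(flat_ids, offsets)):
        flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0)
        offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1)

    # Shape [batch, 2, pad_to]: row 0 holds flattened token ids, row 1 cumulative offsets.
    return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2))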
model_repository/preprocessing/1/utils/__init__.py DELETED
@@ -1,13 +0,0 @@
-# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
model_repository/preprocessing/1/utils/gpt_token_encoder.py DELETED
@@ -1,170 +0,0 @@
-"""Byte pair encoding utilities"""
-
-# Modified MIT License
-
-# Software Copyright (c) 2019 OpenAI
-
-# We don’t claim ownership of the content you create with GPT-2, so it is yours to do with as you please.
-# We only ask that you use GPT-2 responsibly and clearly indicate your content was created using GPT-2.
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
-# associated documentation files (the "Software"), to deal in the Software without restriction,
-# including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
-# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
-# subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included
-# in all copies or substantial portions of the Software.
-# The above copyright notice and this permission notice need not be included
-# with content created by the Software.
-
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
-# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
-# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
-# OR OTHER DEALINGS IN THE SOFTWARE.
-# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import os
-from functools import lru_cache
-
-import regex as re
-
-
-@lru_cache()
-def bytes_to_unicode():
-    """
-    Returns list of utf-8 byte and a corresponding list of unicode strings.
-    The reversible bpe codes work on unicode strings.
-    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
-    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-    This is a signficant percentage of your normal, say, 32K bpe vocab.
-    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
-    And avoids mapping to whitespace/control characters the bpe code barfs on.
-    """
-    bs = (
-        list(range(ord("!"), ord("~") + 1))
-        + list(range(ord("¡"), ord("¬") + 1))
-        + list(range(ord("®"), ord("ÿ") + 1))
-    )
-    cs = bs[:]
-    n = 0
-    for b in range(2**8):
-        if b not in bs:
-            bs.append(b)
-            cs.append(2**8 + n)
-            n += 1
-    cs = [chr(n) for n in cs]
-    return dict(zip(bs, cs))
-
-
-def get_pairs(word):
-    """Return set of symbol pairs in a word.
-
-    Word is represented as tuple of symbols (symbols being variable-length strings).
-    """
-    pairs = set()
-    prev_char = word[0]
-    for char in word[1:]:
-        pairs.add((prev_char, char))
-        prev_char = char
-    return pairs
-
-
-class Encoder:
-    def __init__(self, encoder, bpe_merges, errors="replace"):
-        self.encoder = encoder
-        self.decoder = {v: k for k, v in self.encoder.items()}
-        self.errors = errors  # how to handle errors in decoding
-        self.byte_encoder = bytes_to_unicode()
-        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
-        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
-        self.cache = {}
-
-        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
-        self.pat = re.compile(
-            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
-        )
-
-    def bpe(self, token):
-        if token in self.cache:
-            return self.cache[token]
-        word = tuple(token)
-        pairs = get_pairs(word)
-
-        if not pairs:
-            return token
-
-        while True:
-            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
-            if bigram not in self.bpe_ranks:
-                break
-            first, second = bigram
-            new_word = []
-            i = 0
-            while i < len(word):
-                try:
-                    j = word.index(first, i)
-                    new_word.extend(word[i:j])
-                    i = j
-                except:
-                    new_word.extend(word[i:])
-                    break
-
-                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
-                    new_word.append(first + second)
-                    i += 2
-                else:
-                    new_word.append(word[i])
-                    i += 1
-            new_word = tuple(new_word)
-            word = new_word
-            if len(word) == 1:
-                break
-            else:
-                pairs = get_pairs(word)
-        word = " ".join(word)
-        self.cache[token] = word
-        return word
-
-    def encode(self, text):
-        bpe_tokens = []
-        for token in re.findall(self.pat, text):
-            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
-            bpe_tokens.extend(
-                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
-            )
-        return bpe_tokens
-
-    def decode(self, tokens):
-        text = "".join([self.decoder[token] for token in tokens])
-        text = bytearray([self.byte_decoder[c] for c in text]).decode(
-            "utf-8", errors=self.errors
-        )
-        return text
-
-
-def get_encoder(vocab_file, bpe_file):
-    with open(vocab_file, "r") as f:
-        encoder = json.load(f)
-    with open(bpe_file, "r", encoding="utf-8") as f:
-        bpe_data = f.read()
-    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
-    return Encoder(
-        encoder=encoder,
-        bpe_merges=bpe_merges,
-    )
model_repository/preprocessing/1/word_list.py DELETED
@@ -1,56 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import csv
-from pathlib import Path
-
-import numpy as np
-from transformers import AutoTokenizer
-
-
-def to_word_list_format(word_dict):
-    cache_dir = Path(__file__).parent / ".cache"
-    tokenizer = AutoTokenizer.from_pretrained(
-        "Salesforce/codegen-350M-mono", cache_dir=cache_dir
-    )
-
-    flat_ids = []
-    offsets = []
-    for word_dict_item in word_dict:
-        item_flat_ids = []
-        item_offsets = []
-
-        if isinstance(word_dict_item[0], bytes):
-            word_dict_item = [word_dict_item[0].decode()]
-
-        words = list(csv.reader(word_dict_item))[0]
-        for word in words:
-            ids = tokenizer.encode(word)
-
-            if len(ids) == 0:
-                continue
-
-            item_flat_ids += ids
-            item_offsets.append(len(ids))
-
-        flat_ids.append(np.array(item_flat_ids))
-        offsets.append(np.cumsum(np.array(item_offsets)))
-
-    pad_to = max(1, max(len(ids) for ids in flat_ids))
-
-    for i, (ids, offs) in enumerate(zip(flat_ids, offsets)):
-        flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0)
-        offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1)
-
-    return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2))
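
For reference, a hedged usage sketch of the removed helper (the input string and the printed comments are illustrative only) shows the [batch, 2, pad_to] layout FasterTransformer expects for bad-word and stop-word lists:

import numpy as np

# One request whose bad-word list arrives as a single CSV-formatted byte string.
bad_words_dict = np.array([[b"import os,exec"]], dtype=object)

word_list = to_word_list_format(bad_words_dict)
print(word_list.shape)  # (1, 2, N): row 0 = flattened token ids, row 1 = cumulative offsets
print(word_list[0, 0])  # ids for "import os" then "exec", zero-padded to N
print(word_list[0, 1])  # cumulative offsets, padded with -1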