Update tokenizer_script.py
Browse files- tokenizer_script.py +5 -5
tokenizer_script.py
CHANGED
@@ -76,19 +76,19 @@ class CharacterTokenizer(PreTrainedTokenizer):
|
|
76 |
|
77 |
return (vocab_file,)
|
78 |
|
79 |
-
def batch_encode(
|
80 |
-
encoded_texts = [
|
81 |
# Handle max_length (truncation)
|
82 |
if max_length is not None:
|
83 |
encoded_texts = [ids[:max_length] for ids in encoded_texts]
|
84 |
if add_special_tokens:
|
85 |
-
bos_token_id =
|
86 |
-
eos_token_id =
|
87 |
encoded_texts = [[bos_token_id] + ids + [eos_token_id] for ids in encoded_texts]
|
88 |
# Handle padding
|
89 |
if padding:
|
90 |
# properly handle padding side
|
91 |
-
pad_id =
|
92 |
max_len = max(len(ids) for ids in encoded_texts) if max_length is None else max_length
|
93 |
if tokenizer.padding_side == "right":
|
94 |
encoded_texts = [ids + [pad_id] * (max_len - len(ids)) for ids in encoded_texts]
|
|
|
76 |
|
77 |
return (vocab_file,)
|
78 |
|
79 |
+
def batch_encode(self, texts, add_special_tokens=False, padding=False, truncation=True, max_length=None):
|
80 |
+
encoded_texts = [self.encode(text) for text in texts]
|
81 |
# Handle max_length (truncation)
|
82 |
if max_length is not None:
|
83 |
encoded_texts = [ids[:max_length] for ids in encoded_texts]
|
84 |
if add_special_tokens:
|
85 |
+
bos_token_id = self.convert_tokens_to_ids(tokenizer.bos_token)
|
86 |
+
eos_token_id = self.convert_tokens_to_ids(tokenizer.eos_token)
|
87 |
encoded_texts = [[bos_token_id] + ids + [eos_token_id] for ids in encoded_texts]
|
88 |
# Handle padding
|
89 |
if padding:
|
90 |
# properly handle padding side
|
91 |
+
pad_id = self.vocab.get(tokenizer.pad_token, 0)
|
92 |
max_len = max(len(ids) for ids in encoded_texts) if max_length is None else max_length
|
93 |
if tokenizer.padding_side == "right":
|
94 |
encoded_texts = [ids + [pad_id] * (max_len - len(ids)) for ids in encoded_texts]
|