Make _decode compatible with PreTrainedTokenizerBase (#8)
Browse files
- Make _decode compatible with PreTrainedTokenizerBase (1d07fb39189bd0bb45500bca7dfce4107eda74fb)
Co-authored-by: Vincent Brouwers <[email protected]>
tokenization_codegen25.py
CHANGED
@@ -4,7 +4,7 @@
|
|
4 |
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0
|
5 |
"""Tokenization classes for CodeGen2.5."""
|
6 |
|
7 |
-
from typing import List, Optional
|
8 |
|
9 |
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
10 |
from transformers.utils import logging
|
@@ -168,7 +168,9 @@ class CodeGen25Tokenizer(PreTrainedTokenizer):
|
|
168 |
"""Converts an index (integer) in a token (str) using the vocab."""
|
169 |
return self.encoder.decode_single_token_bytes(index).decode("utf-8")
|
170 |
|
171 |
-
def _decode(self, token_ids: List[int], skip_special_tokens: bool = False, **kwargs):
|
|
|
|
|
172 |
if skip_special_tokens:
|
173 |
token_ids = [t for t in token_ids if t not in self.all_special_ids]
|
174 |
return self.encoder.decode(token_ids)
|
|
|
4 |
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0
|
5 |
"""Tokenization classes for CodeGen2.5."""
|
6 |
|
7 |
+
from typing import List, Optional, Union
|
8 |
|
9 |
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
10 |
from transformers.utils import logging
|
|
|
168 |
"""Converts an index (integer) in a token (str) using the vocab."""
|
169 |
return self.encoder.decode_single_token_bytes(index).decode("utf-8")
|
170 |
|
171 |
+
def _decode(self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False, **kwargs):
|
172 |
+
if isinstance(token_ids, int):
|
173 |
+
token_ids = [token_ids]
|
174 |
if skip_special_tokens:
|
175 |
token_ids = [t for t in token_ids if t not in self.all_special_ids]
|
176 |
return self.encoder.decode(token_ids)
|