update tokenizer code for compatibility with latest transformers

#12
by katuni4ka - opened
Files changed (1) hide show
  1. tokenization_codegen25.py +6 -3
tokenization_codegen25.py CHANGED
@@ -133,15 +133,14 @@ class CodeGen25Tokenizer(PreTrainedTokenizer):
133
  ):
134
  pad_token_added = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
135
  eos_token_added = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
 
136
  super().__init__(
137
  pad_token=pad_token_added,
138
  eos_token=eos_token_added,
139
  add_eos_token=add_eos_token,
140
- add_special_tokens=add_special_tokens,
141
  **kwargs,
142
  )
143
  self.add_eos_token = add_eos_token
144
- self.encoder = tiktoken_tokenizer(base="gpt2", pad_token=pad_token, add_special=add_special_tokens)
145
 
146
  @property
147
  def vocab_size(self):
@@ -166,7 +165,11 @@ class CodeGen25Tokenizer(PreTrainedTokenizer):
166
 
167
  def _convert_id_to_token(self, index):
168
  """Converts an index (integer) in a token (str) using the vocab."""
169
- return self.encoder.decode_single_token_bytes(index).decode("utf-8")
 
 
 
 
170
 
171
  def _decode(self, token_ids: List[int], skip_special_tokens: bool = False, **kwargs):
172
  if skip_special_tokens:
 
133
  ):
134
  pad_token_added = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
135
  eos_token_added = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
136
+ self.encoder = tiktoken_tokenizer(base="gpt2", pad_token=pad_token, add_special=add_special_tokens)
137
  super().__init__(
138
  pad_token=pad_token_added,
139
  eos_token=eos_token_added,
140
  add_eos_token=add_eos_token,
 
141
  **kwargs,
142
  )
143
  self.add_eos_token = add_eos_token
 
144
 
145
  @property
146
  def vocab_size(self):
 
165
 
166
  def _convert_id_to_token(self, index):
167
  """Converts an index (integer) in a token (str) using the vocab."""
168
+ try:
169
+ token = self.encoder.decode_single_token_bytes(index).decode("utf-8")
170
+ except Exception:
171
+ token = ""
172
+ return token
173
 
174
  def _decode(self, token_ids: List[int], skip_special_tokens: bool = False, **kwargs):
175
  if skip_special_tokens: