File size: 2,219 Bytes
6c63bd9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import os
from typing import Dict, List, Optional, Tuple, Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
class SentencePieceJA(PreTrainedTokenizer):
    """Slow ``PreTrainedTokenizer`` adapter around a ``tokenizers.Tokenizer`` file.

    Unusually, this class uses token *ids* as its "tokens": ``_tokenize``
    returns ids directly, and ``_convert_token_to_id`` /
    ``convert_tokens_to_ids`` are identity functions. The id -> "token"
    direction goes through ``Tokenizer.decode``, so it yields decoded text
    rather than raw vocabulary pieces.
    """

    def __init__(self, model_path, **kwargs):
        """Load the backing tokenizer and resolve the special-token ids.

        Args:
            model_path: Path to a serialized tokenizer, as accepted by
                ``tokenizers.Tokenizer.from_file``.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.
        """
        from tokenizers import Tokenizer

        # Load the backing tokenizer *before* the base constructor runs, so any
        # vocabulary lookups it performs (e.g. get_vocab() in recent
        # transformers releases) find self._tokenizer already set.
        self._tokenizer = Tokenizer.from_file(model_path)
        super().__init__(**kwargs)
        # Each special id is the first id produced by encoding the literal
        # marker string. NOTE(review): assumes each marker encodes to a single
        # leading token with no prepended post-processor token — confirm.
        self.__pad_id = self._tokenize("<PAD>")[0]
        self.__bos_id = self._tokenize("<BOS>")[0]
        self.__eos_id = self._tokenize("<EOS>")[0]
        self.__unk_id = self._tokenize("<UNK>")[0]
        self.__mask_id = self._tokenize("<MASK>")[0]

    def get_vocab(self) -> Dict[str, int]:
        """Return the backing tokenizer's token -> id mapping."""
        return self._tokenizer.get_vocab()

    def vocab_size(self) -> int:
        """Return the number of entries in the backing tokenizer's vocabulary."""
        # NOTE(review): PreTrainedTokenizer declares ``vocab_size`` as a
        # property, but this override is a plain method, so base-class code
        # reading ``self.vocab_size`` gets a bound method rather than an int.
        # It is kept as a method so existing ``vocab_size()`` callers still
        # work; consider migrating to a property.
        return self._tokenizer.get_vocab_size()

    def _tokenize(self, text, **kwargs):
        """Encode *text* and return its ids, which serve as this class's tokens."""
        return self._tokenizer.encode(text).ids

    def _convert_token_to_id(self, token):
        """Return *token* unchanged: ``_tokenize`` already produces ids."""
        return token

    def _convert_id_to_token(self, index: int) -> str:
        """Decode a single id to its text form.

        ``Tokenizer.decode`` takes a sequence of ids, so the id is wrapped in
        a one-element list; passing a bare int raises TypeError.
        (``Tokenizer.id_to_token`` would instead return the raw vocabulary
        piece.)
        """
        return self._tokenizer.decode([index])

    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        """Return *tokens* unchanged: this class's tokens are already ids."""
        return tokens

    def convert_ids_to_tokens(
        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
    ) -> Union[str, List[str]]:
        """Decode *ids* into a single text string.

        Args:
            ids: A single id or a sequence of ids.
            skip_special_tokens: Accepted for signature compatibility with the
                base class but not forwarded; ``Tokenizer.decode`` is called
                with its own default.

        Returns:
            The decoded text as one string (not a per-id list).
        """
        if isinstance(ids, int):
            # Tokenizer.decode only accepts a sequence of ids.
            ids = [ids]
        return self._tokenizer.decode(ids)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocabulary to a text file, one token per line in id order.

        Args:
            save_directory: If an existing directory, the file
                ``[prefix-]vocab.txt`` is written inside it; otherwise this
                value is treated as the output file path.
            filename_prefix: Optional prefix, joined with ``"-"``.

        Returns:
            A 1-tuple holding the path that was written.
        """
        prefix = filename_prefix + "-" if filename_prefix else ""
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(save_directory, prefix + "vocab.txt")
        else:
            # NOTE(review): the prefix is prepended to the whole path string,
            # not just the file name.
            vocab_file = prefix + save_directory
        ordered = sorted(self.get_vocab().items(), key=lambda kv: kv[1])
        with open(vocab_file, "w", encoding="utf-8") as writer:
            writer.writelines(token + "\n" for token, _ in ordered)
        return (vocab_file,)