|
import os |
|
import json |
|
|
|
import torch |
|
from transformers import ( |
|
AutoTokenizer, |
|
PreTrainedTokenizer |
|
) |
|
|
|
from .configuration_keeper import KeeperConfig |
|
|
|
from typing import Optional, List, Union |
|
|
|
|
|
class KeeperTokenizer(PreTrainedTokenizer):
    """Composite tokenizer wrapping two independent HuggingFace tokenizers.

    - ``tokenizer_retriever``: tokenizes text for the retriever component.
    - ``tokenizer_model``: tokenizes and decodes text for the generator model.

    Both sub-tokenizers are loaded from subfolders (``tokenizer-retriever`` and
    ``tokenizer-model``) of a single pretrained checkpoint via
    :meth:`from_pretrained`, and saved back under the same subfolder layout by
    :meth:`save_vocabulary`.
    """

    config_class = KeeperConfig

    def __init__(self, cfg=None):
        # NOTE(review): super().__init__() is deliberately not called —
        # PreTrainedTokenizer.__init__ queries the vocabulary, which does not
        # exist until from_pretrained() populates the two sub-tokenizers.
        # `cfg` is currently unused; kept for interface compatibility.
        self.tokenizer_retriever = None
        self.tokenizer_model = None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load both sub-tokenizers from their checkpoint subfolders.

        Args:
            pretrained_model_name_or_path: Hub model id or local directory
                containing ``tokenizer-retriever/`` and ``tokenizer-model/``.
            **kwargs: Accepted for interface compatibility (currently unused).

        Returns:
            A ``KeeperTokenizer`` with both sub-tokenizers populated.
        """
        instance = cls()

        print("Loading tokenizer_retriever from", pretrained_model_name_or_path)
        instance.tokenizer_retriever = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, subfolder='tokenizer-retriever'
        )

        print("Loading tokenizer_model from", pretrained_model_name_or_path)
        instance.tokenizer_model = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, subfolder='tokenizer-model'
        )

        return instance

    @property
    def vocab_size(self):
        """Number of distinct token strings across both vocabularies.

        Token strings present in both vocabularies are counted once
        (dict-merge semantics), so this is generally smaller than the sum
        of the two individual vocab sizes.
        """
        combined_vocab = {
            **self.tokenizer_retriever.get_vocab(),
            **self.tokenizer_model.get_vocab(),
        }
        return len(combined_vocab)

    def get_vocab(self):
        """Return both vocabularies kept separate under explicit keys.

        Returns:
            dict with keys ``'vocab_retriever'`` and ``'vocab_model'``, each
            mapping token string -> id for the respective sub-tokenizer.
        """
        return {
            'vocab_retriever': self.tokenizer_retriever.get_vocab(),
            'vocab_model': self.tokenizer_model.get_vocab(),
        }

    def _tokenize(self, text, **kwargs):
        # Intentionally unimplemented: tokenization is delegated to the two
        # sub-tokenizers via encode(); this base-class hook is never used.
        pass

    def encode(self, text, **kwargs):
        """Tokenize *text* with both sub-tokenizers.

        Args:
            text: Input string (or batch) to tokenize.
            **kwargs: Forwarded to both underlying tokenizer calls.

        Returns:
            dict with keys ``'tokens_retriever'`` and ``'tokens_model'``,
            each a ``BatchEncoding`` of PyTorch tensors (``return_tensors='pt'``).
        """
        return {
            'tokens_retriever': self.tokenizer_retriever(text, return_tensors='pt', **kwargs),
            'tokens_model': self.tokenizer_model(text, return_tensors='pt', **kwargs),
        }

    def decode(
        self,
        token_ids: Union[int, List[int], "torch.Tensor"],
        skip_special_tokens: bool = False,
        **kwargs,
    ) -> str:
        """Decode *token_ids* with the model tokenizer.

        Only ``tokenizer_model`` output is ever decoded; retriever ids are
        not meant to be turned back into text here.
        """
        # Pass skip_special_tokens by keyword so it cannot be misbound to a
        # different positional parameter of the underlying decode().
        return self.tokenizer_model.decode(
            token_ids, skip_special_tokens=skip_special_tokens, **kwargs
        )

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Save both sub-tokenizers under subfolders of *save_directory*.

        Writes ``tokenizer-retriever/`` and ``tokenizer-model/`` subfolders
        and returns the tuple of file paths actually written. (Previously a
        hard-coded list was returned that misspelled the subfolder as
        'tokenizer-retriver' and assumed file names — e.g. ``vocab.json`` —
        that a given tokenizer may not produce, so the returned paths could
        point at files that do not exist.)

        Args:
            save_directory: Destination directory; created if missing.
            filename_prefix: Accepted for interface compatibility (unused).

        Returns:
            Tuple of paths to every file saved by the two sub-tokenizers.
        """
        os.makedirs(save_directory, exist_ok=True)

        retriever_save_directory = os.path.join(save_directory, "tokenizer-retriever")
        os.makedirs(retriever_save_directory, exist_ok=True)
        # save_pretrained returns the tuple of file paths it wrote.
        retriever_files = self.tokenizer_retriever.save_pretrained(retriever_save_directory)

        model_save_directory = os.path.join(save_directory, "tokenizer-model")
        os.makedirs(model_save_directory, exist_ok=True)
        model_files = self.tokenizer_model.save_pretrained(model_save_directory)

        return tuple(retriever_files) + tuple(model_files)
|