# BarcodeBERT / tokenizer.py
from transformers import PreTrainedTokenizer
from huggingface_hub import hf_hub_download
import torch
import json
import os
from itertools import product
class KmerTokenizer(PreTrainedTokenizer):
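    """k-mer tokenizer for DNA barcode sequences.

    Sequences are truncated/padded to ``max_len`` bases, split into k-mers of
    length ``k`` taken every ``stride`` bases, and each k-mer is mapped to an
    id from a vocabulary of [MASK], [UNK] and all 4**k k-mers over ACGT.
    """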
def __init__(self, vocab_dict=None, k=4, stride=4, max_len=660, **kwargs):
self.k = k
self.stride = stride
self.max_len = max_len
self.special_tokens = ["[MASK]", "[UNK]"]
if vocab_dict is None:
kmers = ["".join(kmer) for kmer in product('ACGT', repeat=self.k)]
self.vocab = self.special_tokens + kmers
self.vocab_dict = {word: idx for idx, word in enumerate(self.vocab)}
else:
self.vocab = list(vocab_dict.keys())
self.vocab_dict = vocab_dict
super().__init__(**kwargs)
self.mask_token = "[MASK]"
self.unk_token = "[UNK]"
# self.pad_token = "[PAD]"
def tokenize(self, text, **kwargs):
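        """Truncate the sequence to ``max_len`` bases, optionally right-pad it with 'N', and split it into k-mers."""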
if len(text) > self.max_len:
text = text[:self.max_len]
if kwargs.get('padding'):
if len(text) < self.max_len:
text = text + 'N' * (self.max_len - len(text))
splits = [text[i:i + self.k] for i in range(0, len(text) - self.k + 1, self.stride)]
return splits
def encode(self, text, **kwargs):
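        """Tokenize ``text`` and convert the k-mers to ids, optionally returning a PyTorch tensor."""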
tokens = self.tokenize(text, **kwargs)
token_ids = self.convert_tokens_to_ids(tokens)
if kwargs.get('return_tensors') == 'pt':
return torch.tensor(token_ids)
return token_ids
    def convert_tokens_to_ids(self, tokens):
        # Out-of-vocabulary k-mers (e.g. those containing 'N') fall back to the [UNK] id.
        unk_id = self.vocab_dict.get(self.unk_token)
        return [self.vocab_dict.get(token, unk_id) for token in tokens]
def convert_ids_to_tokens(self, ids, **kwargs):
id_to_token = {idx: token for token, idx in self.vocab_dict.items()}
return [id_to_token.get(id_, self.unk_token) for id_ in ids]
# def build_inputs_with_special_tokens(self, token_ids):
# return [self.vocab_dict.get(self.cls_token)] + token_ids + [self.vocab_dict.get(self.sep_token)]
def get_vocab(self):
return self.vocab_dict
def save_vocabulary(self, save_directory, **kwargs):
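        """Write ``tokenizer.json`` (vocabulary plus k-mer settings) and ``tokenizer_config.json`` to ``save_directory``."""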
vocab_file = os.path.join(save_directory, "tokenizer.json")
with open(vocab_file, "w", encoding="utf-8") as f:
            # tokenizer.json content, loosely following the Hugging Face `tokenizers` JSON layout
vocab_content = {
"version": "1.0",
"added_tokens": [
{"id": self.vocab_dict[self.mask_token], "content": self.mask_token, "special": True},
{"id": self.vocab_dict[self.unk_token], "content": self.unk_token, "special": True}
],
"pre_tokenizer": {
"type": "KmerSplitter",
"k": self.k,
"stride": self.stride,
"max_length": self.max_len
},
"model": {
"type": "KmerTokenizer",
"unk_token": self.unk_token,
"vocab": self.vocab_dict
},
}
json.dump(vocab_content, f, ensure_ascii=False, indent=2)
# vocab_file = os.path.join(save_directory, "tokenizer.json")
# with open(vocab_file, "w", encoding="utf-8") as f:
# json.dump(self.vocab_dict, f, ensure_ascii=False, indent=2)
tokenizer_config = {
"added_tokens_decoder": {
"0": {"content": "[MASK]", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False,
"special": True},
"1": {"content": "[UNK]", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False,
"special": True}
},
"auto_map": {
"AutoTokenizer": [
"tokenizer.KmerTokenizer",
None
]
},
"clean_up_tokenization_spaces": True,
"mask_token": "[MASK]",
"model_max_length": 1e12, # Set a high number, or adjust as needed
"tokenizer_class": "KmerTokenizer", # Set your tokenizer class name
"unk_token": "[UNK]"
}
tokenizer_config_file = os.path.join(save_directory, "tokenizer_config.json")
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
json.dump(tokenizer_config, f, ensure_ascii=False, indent=2)
return vocab_file, tokenizer_config_file
@classmethod
def from_pretrained(cls, pretrained_dir, **kwargs):
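        """Download ``tokenizer.json`` and ``tokenizer_config.json`` from the Hub repo ``pretrained_dir`` and rebuild the tokenizer."""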
# Load vocabulary
# vocab_file = os.path.join(pretrained_dir, "tokenizer.json")
vocab_file = hf_hub_download(repo_id=pretrained_dir, filename="tokenizer.json")
if os.path.exists(vocab_file):
with open(vocab_file, "r", encoding="utf-8") as f:
vocab_content = json.load(f)
vocab = vocab_content["model"]["vocab"]
k = vocab_content["pre_tokenizer"]["k"]
stride = vocab_content["pre_tokenizer"]["stride"]
max_len = vocab_content["pre_tokenizer"]["max_length"]
else:
raise ValueError(f"Vocabulary file not found at {vocab_file}")
# Check for the existence of tokenizer_config.json
# tokenizer_config_file = os.path.join(pretrained_dir, "tokenizer_config.json")
tokenizer_config_file = hf_hub_download(repo_id=pretrained_dir, filename="tokenizer_config.json")
if os.path.exists(tokenizer_config_file):
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
tokenizer_config = json.load(f)
else:
raise ValueError(f"Tokenizer config file not found at {tokenizer_config_file}")
        # Instantiate the tokenizer with the vocabulary and k-mer settings loaded from tokenizer.json
        return cls(vocab_dict=vocab, k=k, stride=stride, max_len=max_len, **kwargs)
def __call__(self, text, padding=False, **kwargs):
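        """Encode ``text`` and return ``input_ids``, ``token_type_ids`` and an attention mask over non-[UNK] positions."""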
token_ids = self.encode(text, padding=padding, **kwargs)
        # Padding with 'N' yields k-mers outside the vocabulary, which encode to [UNK],
        # so masking [UNK] ids also masks the padded positions.
        unk_token_id = self.vocab_dict.get(self.unk_token)
        attention_mask = [1 if id_ != unk_token_id else 0 for id_ in token_ids]
token_type_ids = [0] * len(token_ids)
# Convert to the specified tensor format
if kwargs.get('return_tensors') == 'pt':
attention_mask = torch.tensor(attention_mask)
token_type_ids = torch.tensor(token_type_ids)
# Return the output dictionary
return {
"input_ids": token_ids,
"token_type_ids": token_type_ids,
"attention_mask": attention_mask
}
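

# --- Minimal usage sketch (not part of the original module) ---
# Assumes the defaults above (k=4, stride=4, max_len=660) and that `torch` and
# `transformers` are installed; the sequence below is a made-up example and the
# tokenizer is instantiated directly rather than via `from_pretrained`.
if __name__ == "__main__":
    tokenizer = KmerTokenizer()
    seq = "AACGCTAGGTACGTTAGC"
    out = tokenizer(seq, padding=True, return_tensors='pt')
    # With padding to 660 bases and stride 4, this yields (660 - 4) // 4 + 1 = 165 token ids.
    print(out["input_ids"].shape, out["attention_mask"].sum())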