import json
import os
from itertools import product

from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizer


class KmerTokenizer(PreTrainedTokenizer):
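    """K-mer tokenizer for nucleotide sequences.

    Splits a sequence into substrings of length ``k`` taken every ``stride``
    characters (non-overlapping when ``stride == k``, e.g. "ACGTACGT" ->
    ["ACGT", "ACGT"]). The default vocabulary is the special tokens [MASK]
    and [UNK] followed by all 4**k k-mers over the alphabet ACGT; k-mers
    outside the vocabulary are mapped to [UNK].
    """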

    def __init__(self, vocab_dict=None, k=4, stride=4, **kwargs):
        self.k = k
        self.stride = stride
        self.special_tokens = ["[MASK]", "[UNK]"]

        if vocab_dict is None:
            # Default vocabulary: special tokens followed by every k-mer over ACGT.
            kmers = ["".join(kmer) for kmer in product("ACGT", repeat=self.k)]
            self.vocab = self.special_tokens + kmers
            self.vocab_dict = {word: idx for idx, word in enumerate(self.vocab)}
        else:
            self.vocab = list(vocab_dict.keys())
            self.vocab_dict = vocab_dict

        # The vocabulary must exist before the base initializer runs, since
        # PreTrainedTokenizer may call get_vocab() during __init__.
        super().__init__(**kwargs)

        self.mask_token = "[MASK]"
        self.unk_token = "[UNK]"

    def _tokenize(self, text):
        # Return the k-mer strings; PreTrainedTokenizer maps them to ids via
        # convert_tokens_to_ids() when the tokenizer is called on text.
        return [text[i:i + self.k] for i in range(0, len(text) - self.k + 1, self.stride)]

    def convert_tokens_to_ids(self, tokens):
        # Accepts a single token or a list; unknown k-mers map to the [UNK] id.
        unk_id = self.vocab_dict.get(self.unk_token)
        if isinstance(tokens, str):
            return self.vocab_dict.get(tokens, unk_id)
        return [self.vocab_dict.get(token, unk_id) for token in tokens]

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        id_to_token = {idx: token for token, idx in self.vocab_dict.items()}
        tokens = [id_to_token.get(id_, self.unk_token) for id_ in ids]
        if skip_special_tokens:
            tokens = [t for t in tokens if t not in self.special_tokens]
        return tokens

    def get_vocab(self):
        return self.vocab_dict

    @property
    def vocab_size(self):
        # Size of the base vocabulary; PreTrainedTokenizer expects subclasses
        # to provide this.
        return len(self.vocab_dict)

    def save_vocabulary(self, save_directory, **kwargs):
        vocab_file = os.path.join(save_directory, "tokenizer.json")
        with open(vocab_file, "w", encoding="utf-8") as f:
            vocab_content = {
                "version": "1.0",
                "added_tokens": [
                    {"id": self.vocab_dict[self.mask_token], "content": self.mask_token, "special": True},
                    {"id": self.vocab_dict[self.unk_token], "content": self.unk_token, "special": True},
                ],
                "pre_tokenizer": {
                    "type": "KmerSplitter",
                    "k": self.k,
                    "stride": self.stride,
                },
                "model": {
                    "type": "k-mer",
                    "k": self.k,
                    "stride": self.stride,
                    "unk_token": self.unk_token,
                    "vocab": self.vocab_dict,
                },
            }
            json.dump(vocab_content, f, ensure_ascii=False, indent=2)

        tokenizer_config = {
            "added_tokens_decoder": {
                "0": {"content": "[MASK]", "lstrip": False, "normalized": False, "rstrip": False,
                      "single_word": False, "special": True},
                "1": {"content": "[UNK]", "lstrip": False, "normalized": False, "rstrip": False,
                      "single_word": False, "special": True},
            },
            "auto_map": {
                "AutoTokenizer": [
                    "tokenizer.KmerTokenizer",
                    None,
                ]
            },
            "clean_up_tokenization_spaces": True,
            "mask_token": "[MASK]",
            "model_max_length": 1e12,
            "tokenizer_class": "KmerTokenizer",
            "unk_token": "[UNK]",
            "k": self.k,
            "stride": self.stride,
        }
        tokenizer_config_file = os.path.join(save_directory, "tokenizer_config.json")
        with open(tokenizer_config_file, "w", encoding="utf-8") as f:
            json.dump(tokenizer_config, f, ensure_ascii=False, indent=2)

        return vocab_file, tokenizer_config_file

    @classmethod
    def from_pretrained(cls, pretrained_dir, **kwargs):
        # `pretrained_dir` is treated as a Hub repo id; hf_hub_download raises
        # if a requested file does not exist in the repo.
        vocab_file = hf_hub_download(repo_id=pretrained_dir, filename="tokenizer.json")
        with open(vocab_file, "r", encoding="utf-8") as f:
            vocab_content = json.load(f)
        vocab = vocab_content["model"]["vocab"]

        tokenizer_config_file = hf_hub_download(repo_id=pretrained_dir, filename="tokenizer_config.json")
        with open(tokenizer_config_file, "r", encoding="utf-8") as f:
            tokenizer_config = json.load(f)
        k = tokenizer_config.get("k", 4)
        stride = tokenizer_config.get("stride", k)

        return cls(vocab_dict=vocab, k=k, stride=stride, **kwargs)
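

# --- Usage sketch (illustrative, not part of the tokenizer itself) ---
# A minimal sanity check assuming the class above is used as-is: build the
# default 4-mer vocabulary, tokenize a short example sequence, and round-trip
# the resulting ids back to k-mers. The sequence is made-up example data.
if __name__ == "__main__":
    tokenizer = KmerTokenizer(k=4, stride=4)

    sequence = "ACGTACGTACGTA"  # toy DNA sequence
    tokens = tokenizer.tokenize(sequence)          # ['ACGT', 'ACGT', 'ACGT']
    ids = tokenizer.convert_tokens_to_ids(tokens)
    print(tokens)
    print(ids)
    print(tokenizer.convert_ids_to_tokens(ids))

    # Writes tokenizer.json and tokenizer_config.json via save_vocabulary():
    # tokenizer.save_vocabulary(".")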