import os
from typing import Optional, Union

from transformers import PreTrainedTokenizer


class CustomTokenizer(PreTrainedTokenizer):
|
    def __init__(self, vocab_file, **kwargs):
        # Build the vocabulary before calling super().__init__, because the base
        # class resolves its special tokens against the vocab during initialization.
        self.vocab = self._load_vocab(vocab_file)
        self.ids_to_tokens = {v: k for k, v in self.vocab.items()}

        # Register the special tokens unless the caller already supplied them.
        kwargs.setdefault("bos_token", "[CLS]")
        kwargs.setdefault("eos_token", "[SEP]")
        kwargs.setdefault("pad_token", "[PAD]")
        kwargs.setdefault("unk_token", "[UNK]")
        kwargs.setdefault("mask_token", "[MASK]")
        super().__init__(**kwargs)
|
    def _load_vocab(self, vocab_file):
        # One token per line; a token's id is its line index.
        vocab = {}
        with open(vocab_file, "r", encoding="utf-8") as f:
            for line in f:
                token = line.strip()
                vocab[token] = len(vocab)
        return vocab
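
    # The PreTrainedTokenizer base class expects subclasses to expose their
    # vocabulary through vocab_size and get_vocab(); a minimal sketch of those
    # overrides follows, assuming self.vocab is a plain token -> id dict.
    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab)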
|
    def tokenize(self, text):
        # Whitespace tokenization; out-of-vocabulary words map to the unk token.
        return [word if word in self.vocab else self.unk_token for word in text.split()]
|
    def encode(
        self,
        text,
        text_pair=None,
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        truncation: Optional[Union[bool, str]] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        return_tensors: Optional[str] = None,
        **kwargs,
    ):
        # Tokenize and map to ids. The padding/truncation/tensor arguments are
        # accepted for interface compatibility but ignored by this minimal version.
        ids = self.convert_tokens_to_ids(self.tokenize(text))
        if add_special_tokens:
            bos_id = self.convert_tokens_to_ids(self.bos_token)
            eos_id = self.convert_tokens_to_ids(self.eos_token)
            ids = [bos_id] + ids + [eos_id]
        return ids
|
    def _convert_token_to_id(self, token):
        # Hook used by the base class when mapping tokens to ids; unknown tokens
        # fall back to the unk token's id.
        return self.vocab.get(token, self.vocab[self.unk_token])

    def _convert_id_to_token(self, idx):
        # Hook used by the base class when mapping ids back to tokens.
        return self.ids_to_tokens.get(idx, self.unk_token)
|
    def save_vocabulary(self, save_directory, filename_prefix=None):
        if not os.path.isdir(save_directory):
            return ()
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt"
        )
        # Write one token per line, ordered by id, so _load_vocab can rebuild the mapping.
        with open(vocab_file, "w", encoding="utf-8") as f:
            for token, _ in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                f.write(token + "\n")
        return (vocab_file,)
|
|
if __name__ == "__main__":
    from transformers import AutoTokenizer

    # Quick check against the tokenizer published on the Hub.
    tokenizer = AutoTokenizer.from_pretrained("xiaohua828/MNIST_Demo_1")
    token = tokenizer.encode("我要赚钱")  # sample input: "I want to make money"
    print(token)
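
    # A minimal sketch of using the CustomTokenizer defined above directly.
    # It assumes a local vocab file "vocab.txt" (one token per line, including
    # the special tokens); the path and contents are hypothetical.
    if os.path.exists("vocab.txt"):
        custom = CustomTokenizer(vocab_file="vocab.txt")
        ids = custom.encode("hello world")        # whitespace tokens -> ids with [CLS]/[SEP]
        print(ids)
        print(custom.convert_ids_to_tokens(ids))  # round-trip via the base class helper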