# MNIST_Demo_1 / tokenizer.py
import os
from typing import Optional, Union

from transformers import PreTrainedTokenizer


class CustomTokenizer(PreTrainedTokenizer):
    """Whitespace tokenizer backed by a plain-text vocab file (one token per line)."""

    def __init__(self, vocab_file, **kwargs):
        # Build the token <-> id tables before calling the base constructor,
        # which may consult them (e.g. via get_vocab).
        self.vocab = self._load_vocab(vocab_file)
        self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
        super().__init__(**kwargs)
        self.bos_token, self.eos_token, self.pad_token = "[CLS]", "[SEP]", "[PAD]"
        self.unk_token, self.mask_token = "[UNK]", "[MASK]"
    def _load_vocab(self, vocab_file):
        # One token per line; the line number (0-based) becomes the token id.
        vocab = {}
        with open(vocab_file, "r", encoding="utf-8") as f:
            for line in f:
                token = line.strip()
                vocab[token] = len(vocab)
        return vocab
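
    # PreTrainedTokenizer leaves vocab_size and get_vocab to subclasses; a
    # minimal sketch of both, assuming self.vocab maps token -> id, so that
    # inherited helpers such as len(tokenizer) work.
    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab)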
    def tokenize(self, text):
        # Whitespace tokenization; anything outside the vocab maps to [UNK].
        tokens = []
        for word in text.split():
            if word in self.vocab:
                tokens.append(word)
            else:
                tokens.append(self.unk_token)
        return tokens
    def encode(
        self,
        text,
        text_pair=None,
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        truncation: Optional[Union[bool, str]] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        return_tensors: Optional[str] = None,
        **kwargs,
    ):
        # Whitespace-tokenize and map tokens to ids; padding, truncation and
        # return_tensors are accepted for API compatibility but not applied.
        ids = [self.convert_token_to_id(token) for token in self.tokenize(text)]
        if add_special_tokens:
            ids = [self.convert_token_to_id(self.bos_token)] + ids + [self.convert_token_to_id(self.eos_token)]
        return ids
    def convert_token_to_id(self, token):
        # Fall back to the [UNK] id for out-of-vocabulary tokens.
        return self.vocab.get(token, self.vocab[self.unk_token])

    def convert_id_to_token(self, idx):
        return self.ids_to_tokens.get(idx, self.unk_token)
    def save_vocabulary(self, save_directory, filename_prefix=None):
        if not os.path.isdir(save_directory):
            return
        prefix = filename_prefix + "-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.txt")
        # Write tokens in id order so the file round-trips through _load_vocab.
        with open(vocab_file, "w", encoding="utf-8") as f:
            for token, index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                f.write(token + "\n")
        return (vocab_file,)
if __name__ == "__main__":
    # import login
    # tokenizer = CustomTokenizer(vocab_file = "./vocab.txt")
    # tokenizer.push_to_hub(login.name_or_path)
    from transformers import AutoTokenizer

    # Loading custom tokenizer code from the Hub generally requires trust_remote_code=True.
    tokenizer = AutoTokenizer.from_pretrained("xiaohua828/MNIST_Demo_1", trust_remote_code=True)
    token = tokenizer.encode("我要赚钱")  # "I want to make money"
    print(token)
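
    # Local round trip (a sketch, assuming a vocab.txt with one token per line,
    # including the [CLS]/[SEP]/[PAD]/[UNK]/[MASK] specials, sits next to this file).
    if os.path.exists("./vocab.txt"):
        local_tokenizer = CustomTokenizer(vocab_file="./vocab.txt")
        print(local_tokenizer.tokenize("hello world"))  # tokens; OOV words become [UNK]
        print(local_tokenizer.encode("hello world"))    # ids, wrapped in [CLS]/[SEP] ids
        local_tokenizer.save_vocabulary(".")            # rewrites ./vocab.txt in id order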