import jieba
from functools import partial

from transformers import BertTokenizer


class T5PegasusTokenizer(BertTokenizer):
    """BertTokenizer variant that pre-segments Chinese text with jieba and
    falls back to WordPiece for words outside the vocabulary."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Coarse word segmentation; HMM disabled for deterministic cuts.
        self.pre_tokenizer = partial(jieba.cut, HMM=False)

    def _tokenize(self, text, *args, **kwargs):
        split_tokens = []
        for word in self.pre_tokenizer(text):
            if word in self.vocab:
                # Whole word is in the vocabulary: keep it as a single token.
                split_tokens.append(word)
            else:
                # Otherwise fall back to BERT's WordPiece tokenization.
                split_tokens.extend(super()._tokenize(word))
        return split_tokens
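
A minimal usage sketch: the checkpoint path below is a placeholder, not a confirmed model name; any T5-PEGASUS checkpoint that ships a BERT-style vocab should work the same way.

    # Hypothetical path; substitute your local or hub T5-PEGASUS checkpoint.
    tokenizer = T5PegasusTokenizer.from_pretrained("path/to/t5-pegasus")
    # jieba words found in the vocab stay intact; the rest are split by WordPiece.
    print(tokenizer.tokenize("蓝天白云"))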