import jieba
from functools import partial

from transformers import BertTokenizer


class T5PegasusTokenizer(BertTokenizer):
    """BertTokenizer with a jieba pre-tokenization step, for T5-PEGASUS vocabularies."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Segment the input with jieba first (HMM-based new-word discovery disabled).
        self.pre_tokenizer = partial(jieba.cut, HMM=False)

    def _tokenize(self, text, *args, **kwargs):
        split_tokens = []
        for word in self.pre_tokenizer(text):
            if word in self.vocab:
                # Keep whole words that already exist in the vocabulary.
                split_tokens.append(word)
            else:
                # Fall back to BERT's WordPiece tokenization for out-of-vocab words.
                split_tokens.extend(super()._tokenize(word))
        return split_tokens
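

# Minimal usage sketch, not part of the original snippet: "path/to/t5-pegasus" is a
# placeholder for a local T5-PEGASUS checkpoint directory containing a vocab file.
if __name__ == "__main__":
    tokenizer = T5PegasusTokenizer.from_pretrained("path/to/t5-pegasus")
    # Whole words found in the vocab stay intact; others fall back to WordPiece pieces.
    print(tokenizer.tokenize("今天天气很好"))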