dongxiaoqun commited on
Commit
2794e9d
·
1 Parent(s): ee0ea74

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +33 -1
README.md CHANGED
@@ -25,7 +25,39 @@ Task: Summarization
25
  import jieba_fast as jieba
26
  jieba.initialize()
27
  from transformers import PegasusForConditionalGeneration,BertTokenizer
28
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  model = PegasusForConditionalGeneration.from_pretrained('dongxq/test_model')
30
  tokenizer = BertTokenizer.from_pretrained('dongxq/test_model')
31
 
 
25
  import jieba_fast as jieba
26
  jieba.initialize()
27
  from transformers import PegasusForConditionalGeneration,BertTokenizer
28
class PegasusTokenizer(BertTokenizer):
    """BertTokenizer variant for Chinese Pegasus summarization models.

    Pre-tokenizes text with jieba word segmentation before WordPiece, and
    follows the Pegasus special-token convention: no BOS/CLS, a single EOS
    appended at the end, plus the ``<mask_1>`` sentence-mask token.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, **kwargs):
        # jieba segmentation runs ahead of WordPiece; HMM is disabled for
        # deterministic cuts that align with the pretraining vocabulary.
        super().__init__(pre_tokenizer=lambda x: jieba.cut(x, HMM=False), **kwargs)
        self.add_special_tokens({'additional_special_tokens': ["<mask_1>"]})

    def build_inputs_with_special_tokens(
            self,
            token_ids_0: List[int],
            token_ids_1: Optional[List[int]] = None) -> List[int]:
        """Build model inputs by appending the EOS token.

        Pegasus adds no BOS and no separator between sequence pairs —
        only a trailing EOS on the concatenated sequence.
        """
        if token_ids_1 is None:
            return token_ids_0 + [self.eos_token_id]
        return token_ids_0 + token_ids_1 + [self.eos_token_id]

    def _special_token_mask(self, seq):
        """Return a binary (0/1) list marking which ids in ``seq`` are special."""
        all_special_ids = set(self.all_special_ids)  # build once, not per element
        # NOTE: <unk> is deliberately left in — it is only sometimes special.
        return [1 if x in all_special_ids else 0 for x in seq]

    def get_special_tokens_mask(
            self,
            token_ids_0: List[int],
            token_ids_1: Optional[List[int]] = None,
            already_has_special_tokens: bool = False) -> List[int]:
        """Return the special-tokens mask aligned with
        ``build_inputs_with_special_tokens``.

        Bug fix: the position of the appended EOS must be marked with ``1``
        (the mask is binary), not with ``self.eos_token_id`` as the original
        code did — that injected an arbitrary token id into a 0/1 mask.
        """
        if already_has_special_tokens:
            return self._special_token_mask(token_ids_0)
        if token_ids_1 is None:
            return self._special_token_mask(token_ids_0) + [1]
        return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
61
  model = PegasusForConditionalGeneration.from_pretrained('dongxq/test_model')
62
  tokenizer = BertTokenizer.from_pretrained('dongxq/test_model')
63