Justcode commited on
Commit
bedb821
1 Parent(s): c94ce49

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +6 -5
README.md CHANGED
@@ -79,19 +79,20 @@ model=MT5ForConditionalGeneration.from_pretrained(pretrain_path)
79
  sample={"context":"在柏林,胡格诺派教徒创建了两个新的社区:多罗西恩斯塔特和弗里德里希斯塔特。到1700年,这个城市五分之一的人口讲法语。柏林胡格诺派在他们的教堂服务中保留了将近一个世纪的法语。他们最终决定改用德语,以抗议1806-1807年拿破仑占领普鲁士。他们的许多后代都有显赫的地位。成立了几个教会,如弗雷德里夏(丹麦)、柏林、斯德哥尔摩、汉堡、法兰克福、赫尔辛基和埃姆登的教会。","question":"除了多罗西恩斯塔特,柏林还有哪个新的社区?","idx":1}
80
  plain_text='question:'+sample['question']+'knowledge:'+sample['context'][:self.max_knowledge_length]
81
 
82
- res_prefix=tokenizer.encode('answer',add_special_token=False)
83
  res_prefix.append(tokenizer.convert_tokens_to_ids('<extra_id_0>'))
84
- res_prefix.append(EOS_TOKEN_ID)
85
  l_rp=len(res_prefix)
86
 
87
- tokenized=tokenizer.encode(plain_text,add_special_tokens=False,truncation=True,max_length=self.max_seq_length-2-l_rp)
88
  tokenized+=res_prefix
89
  batch=[tokenized]*2
90
  input_ids=torch.tensor(np.array(batch),dtype=torch.long)
91
 
92
  # Generate answer
93
- pred_ids = model.generate(input_ids=input_ids,max_new_token=self.max_target_length,do_sample=True,top_p=0.9)
94
- pred_tokens=tokenizer.batch_decode(decode_ids=pred_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
95
  res=pred_tokens.replace('<extra_id_0>','').replace('有答案:','')
96
  ```
97
 
 
79
  sample={"context":"在柏林,胡格诺派教徒创建了两个新的社区:多罗西恩斯塔特和弗里德里希斯塔特。到1700年,这个城市五分之一的人口讲法语。柏林胡格诺派在他们的教堂服务中保留了将近一个世纪的法语。他们最终决定改用德语,以抗议1806-1807年拿破仑占领普鲁士。他们的许多后代都有显赫的地位。成立了几个教会,如弗雷德里夏(丹麦)、柏林、斯德哥尔摩、汉堡、法兰克福、赫尔辛基和埃姆登的教会。","question":"除了多罗西恩斯塔特,柏林还有哪个新的社区?","idx":1}
80
  plain_text='question:'+sample['question']+'knowledge:'+sample['context'][:self.max_knowledge_length]
81
 
82
+ res_prefix=tokenizer.encode('answer',add_special_tokens=False)
83
  res_prefix.append(tokenizer.convert_tokens_to_ids('<extra_id_0>'))
84
+ res_prefix.append(tokenizer.eos_token_id)
85
  l_rp=len(res_prefix)
86
 
87
+ tokenized=tokenizer.encode(plain_text,add_special_tokens=False,truncation=True,max_length=1024-2-l_rp)
88
  tokenized+=res_prefix
89
  batch=[tokenized]*2
90
  input_ids=torch.tensor(np.array(batch),dtype=torch.long)
91
 
92
  # Generate answer
93
+ max_target_length=128
94
+ pred_ids = model.generate(input_ids=input_ids,max_new_tokens=max_target_length,do_sample=True,top_p=0.9)
95
+ pred_tokens=tokenizer.batch_decode(pred_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
96
  res=pred_tokens.replace('<extra_id_0>','').replace('有答案:','')
97
  ```
98