Update README.md
Browse files
README.md
CHANGED
@@ -79,19 +79,20 @@ model=MT5ForConditionalGeneration.from_pretrained(pretrain_path)
|
|
79 |
sample={"context":"在柏林,胡格诺派教徒创建了两个新的社区:多罗西恩斯塔特和弗里德里希斯塔特。到1700年,这个城市五分之一的人口讲法语。柏林胡格诺派在他们的教堂服务中保留了将近一个世纪的法语。他们最终决定改用德语,以抗议1806-1807年拿破仑占领普鲁士。他们的许多后代都有显赫的地位。成立了几个教会,如弗雷德里夏(丹麦)、柏林、斯德哥尔摩、汉堡、法兰克福、赫尔辛基和埃姆登的教会。","question":"除了多罗西恩斯塔特,柏林还有哪个新的社区?","idx":1}
|
80 |
plain_text='question:'+sample['question']+'knowledge:'+sample['context'][:self.max_knowledge_length]
|
81 |
|
82 |
-
res_prefix=tokenizer.encode('answer',
|
83 |
res_prefix.append(tokenizer.convert_tokens_to_ids('<extra_id_0>'))
|
84 |
-
res_prefix.append(
|
85 |
l_rp=len(res_prefix)
|
86 |
|
87 |
-
tokenized=tokenizer.encode(plain_text,add_special_tokens=False,truncation=True,max_length=
|
88 |
tokenized+=res_prefix
|
89 |
batch=[tokenized]*2
|
90 |
input_ids=torch.tensor(np.array(batch),dtype=torch.long)
|
91 |
|
92 |
# Generate answer
|
93 |
-
|
94 |
-
|
|
|
95 |
res=pred_tokens.replace('<extra_id_0>','').replace('有答案:','')
|
96 |
```
|
97 |
|
|
|
79 |
sample={"context":"在柏林,胡格诺派教徒创建了两个新的社区:多罗西恩斯塔特和弗里德里希斯塔特。到1700年,这个城市五分之一的人口讲法语。柏林胡格诺派在他们的教堂服务中保留了将近一个世纪的法语。他们最终决定改用德语,以抗议1806-1807年拿破仑占领普鲁士。他们的许多后代都有显赫的地位。成立了几个教会,如弗雷德里夏(丹麦)、柏林、斯德哥尔摩、汉堡、法兰克福、赫尔辛基和埃姆登的教会。","question":"除了多罗西恩斯塔特,柏林还有哪个新的社区?","idx":1}
|
80 |
plain_text='question:'+sample['question']+'knowledge:'+sample['context'][:self.max_knowledge_length]
|
81 |
|
82 |
+
res_prefix=tokenizer.encode('answer',add_special_tokens=False)
|
83 |
res_prefix.append(tokenizer.convert_tokens_to_ids('<extra_id_0>'))
|
84 |
+
res_prefix.append(tokenizer.eos_token_id)
|
85 |
l_rp=len(res_prefix)
|
86 |
|
87 |
+
tokenized=tokenizer.encode(plain_text,add_special_tokens=False,truncation=True,max_length=1024-2-l_rp)
|
88 |
tokenized+=res_prefix
|
89 |
batch=[tokenized]*2
|
90 |
input_ids=torch.tensor(np.array(batch),dtype=torch.long)
|
91 |
|
92 |
# Generate answer
|
93 |
+
max_target_length=128
|
94 |
+
pred_ids = model.generate(input_ids=input_ids,max_new_tokens=max_target_length,do_sample=True,top_p=0.9)
|
95 |
+
pred_tokens=tokenizer.batch_decode(pred_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
96 |
res=pred_tokens.replace('<extra_id_0>','').replace('有答案:','')
|
97 |
```
|
98 |
|