zhenyundeng commited on
Commit
cb71ea1
·
1 Parent(s): 55ca411

update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -648,7 +648,8 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-0
648
  store_folder = "averitec/data/store/retrieved_docs"
649
  #
650
  index = 0
651
- questions = [q["question"] for q in generate_question]
 
652
 
653
  # check the date of the claim
654
  current_date = datetime.now().strftime("%Y-%m-%d")
@@ -783,8 +784,8 @@ def decorate_with_questions(claim, retrieve_evidence, top_k=3): # top_k=5, 10,
783
 
784
  inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(qg_model.device)
785
  # inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
786
- outputs = qg_model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2, early_stopping=True)
787
- # max_length=5000
788
  tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
789
  # We are not allowed to generate more than 250 characters:
790
  tgt_text = tgt_text[:250]
@@ -1423,8 +1424,8 @@ def main():
1423
  dropdown_samples.change(change_sample_questions, dropdown_samples, samples)
1424
  demo.queue()
1425
 
1426
- # demo.launch()
1427
- demo.launch(share=True)
1428
 
1429
 
1430
  if __name__ == "__main__":
 
648
  store_folder = "averitec/data/store/retrieved_docs"
649
  #
650
  index = 0
651
+ questions = [q["question"] for q in generate_question][:5]
652
+ # questions = [q["question"] for q in generate_question] # ori
653
 
654
  # check the date of the claim
655
  current_date = datetime.now().strftime("%Y-%m-%d")
 
784
 
785
  inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(qg_model.device)
786
  # inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
787
+ outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2, early_stopping=True)
788
+
789
  tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
790
  # We are not allowed to generate more than 250 characters:
791
  tgt_text = tgt_text[:250]
 
1424
  dropdown_samples.change(change_sample_questions, dropdown_samples, samples)
1425
  demo.queue()
1426
 
1427
+ demo.launch()
1428
+ # demo.launch(share=True)
1429
 
1430
 
1431
  if __name__ == "__main__":