zhenyundeng committed
Commit 200e5b6 · 1 Parent(s): 8a7fa89

update app.py

Files changed (1):
  1. app.py +75 -71
app.py CHANGED
@@ -7,6 +7,7 @@ import tqdm
 import torch
 import numpy as np
 from time import sleep
+from datetime import datetime
 import threading
 import gc
 import os
@@ -34,7 +35,6 @@ from averitec.data.sample_claims import CLAIMS_Type
 from utils import create_user_id
 user_id = create_user_id()
 
-from datetime import datetime
 from azure.storage.fileshare import ShareServiceClient
 try:
     from dotenv import load_dotenv
@@ -86,8 +86,21 @@ LABEL = [
     "Not Enough Evidence",
     "Conflicting Evidence/Cherrypicking",
 ]
-# Veracity
+
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
+# Question Generation
+qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-1b1")
+qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-1b1", torch_dtype=torch.bfloat16).to(device)
+# qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
+# qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
+
+# rerank
+rerank_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+rerank_bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2, problem_type="single_label_classification")  # Must specify single_label for some reason
+best_checkpoint = "averitec/pretrained_models/bert_dual_encoder.ckpt"
+rerank_trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=rerank_tokenizer, model=rerank_bert_model).to(device)
+
+# Veracity
 veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
 bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
 veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
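Note: this hunk replaces the per-request bloom-7b1 loads (removed further down) with a single module-level bloom-1b1 load. A minimal sketch of the generate/decode round trip these globals serve, mirroring the calls in prompt_question_generation and decorate_with_questions below (the prompt text is a placeholder):

    # Sketch only: how the module-level QG model is used later in this diff.
    inputs = qg_tokenizer(["Example prompt"], padding=True, return_tensors="pt").to(device)
    outputs = qg_model.generate(inputs["input_ids"], max_length=200, num_beams=2,
                                no_repeat_ngram_size=2, early_stopping=True)
    text = qg_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]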
@@ -99,7 +112,6 @@ best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=
 justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
 # ---------------------------------------------------------------------------
 
-
 # Set up Gradio Theme
 theme = gr.themes.Base(
     primary_hue="blue",
@@ -182,7 +194,7 @@ class SequenceClassificationDataLoader(pl.LightningDataModule):
     )
 
 
-def averitec_veracity_prediction(claim, qa_evidence):
+def google_veracity_prediction(claim, qa_evidence):
     bert_model_name = "bert-base-uncased"
     tokenizer = BertTokenizer.from_pretrained(bert_model_name)
     bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4,
@@ -340,7 +352,7 @@ def extract_claim_str(claim, qa_evidence, verdict_label):
     return claim_str
 
 
-def averitec_justification_generation(claim, qa_evidence, verdict_label):
+def google_justification_generation(claim, qa_evidence, verdict_label):
     #
     claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
     claim_str.strip()
@@ -383,21 +395,19 @@ def QAprediction(claim, evidence, sources):
     for i, evi in enumerate(evidence, 1):
         part = f"""<span>Doc {i}</span>"""
         subpart = f"""<a href="#doc{i}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{i}</sup></span></a>"""
-        # subpart = f"""<span class='doc-ref'>{i}</sup></span>"""
         subparts = "".join([part, subpart])
         parts.append(subparts)
 
     evidence_part = ", ".join(parts)
 
     prediction_title = f"""<h5>Prediction:</h5>"""
-    # if 'Google' in sources or 'AVeriTeC' in sources:
-    #     verdict_label = averitec_veracity_prediction(claim, evidence)
-    #     justification_label = averitec_justification_generation(claim, evidence, verdict_label)
-    #     # justification_label = "See retrieved docs."
+    # if 'Google' in sources:
+    #     verdict_label = google_veracity_prediction(claim, evidence)
+    #     justification_label = google_justification_generation(claim, evidence, verdict_label)
    #     justification_part = f"""<span>Justification: {justification_label}</span>"""
     # if 'WikiPedia' in sources:
-    #     # verdict_label = fever_veracity_prediction(claim, evidence)
-    #     justification_label = averitec_justification_generation(claim, evidence, verdict_label)
+    #     verdict_label = wikipedia_veracity_prediction(claim, evidence)
+    #     justification_label = wikipedia_justification_generation(claim, evidence, verdict_label)
     #     # justification_label = "See retrieved docs."
     #     justification_part = f"""<span>Justification: {justification_label}</span>"""
 
@@ -406,11 +416,8 @@ def QAprediction(claim, evidence, sources):
     #     justification_label = "See retrieved docs."
     justification_part = f"""<span>Justification: {justification_label}</span>"""
 
-
     verdict_part = f"""Verdict: <span>{verdict_label}.</span><br>"""
-
     content_parts = "".join([evidence_title, evidence_part, prediction_title, verdict_part, justification_part])
-    # content_parts = "".join([evidence_title, evidence_part, verdict_title, verdict_part, justification_title, justification_part])
 
     return content_parts, [verdict_label, justification_label]
 
@@ -418,8 +425,8 @@ def QAprediction(claim, evidence, sources):
 # ----------GoogleAPIretriever---------
 def generate_reference_corpus(reference_file):
     with open(reference_file) as f:
-        j = json.load(f)
-        train_examples = j
+        #
+        train_examples = json.load(f)
 
     all_data_corpus = []
     tokenized_corpus = []
@@ -456,16 +463,16 @@ def docs2prompt(top_docs):
 
 def prompt_question_generation(test_claim, speaker="they", topk=10):
     #
-    reference_file = "averitec_code/data/train.json"
+    reference_file = "averitec/data/train.json"
     tokenized_corpus, all_data_corpus = generate_reference_corpus(reference_file)
     bm25 = BM25Okapi(tokenized_corpus)
 
     # Define the bloom model:
     accelerator = Accelerator()
-    accel_device = accelerator.device
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
-    model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
+    # accel_device = accelerator.device
+    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    # tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
+    # model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
 
     # --------------------------------------------------
     # test claim
@@ -478,11 +485,11 @@ def prompt_question_generation(test_claim, speaker="they", topk=10):
              "\". Criticism includes questions like: "
     sentences = [prompt]
 
-    inputs = tokenizer(sentences, padding=True, return_tensors="pt").to(device)
-    outputs = model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2,
+    inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
+    outputs = qg_model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2,
                              early_stopping=True)
 
-    tgt_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+    tgt_text = qg_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
     in_len = len(sentences[0])
     questions_str = tgt_text[in_len:].split("\n")[0]
 
@@ -592,7 +599,7 @@ def get_google_search_results(api_key, search_engine_id, google_search, sort_date,
     return search_results
 
 
-def averitec_search(claim, generate_question, speaker="they", check_date="2024-01-01", n_pages=1):  # n_pages=3
+def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):  # n_pages=3
     # default config
     api_key = os.environ["GOOGLE_API_KEY"]
     search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]
@@ -612,13 +619,14 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-01-01", n_pages=1):
     ]
 
     # save to folder
-    store_folder = "averitec_code/store/retrieved_docs"
+    store_folder = "averitec/data/store/retrieved_docs"
     #
     index = 0
     questions = [q["question"] for q in generate_question]
 
     # check the date of the claim
-    sort_date = check_claim_date(check_date)  # check_date="2022-01-01"
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    sort_date = check_claim_date(current_date)  # check_date="2022-01-01"
 
     #
     search_strings = []
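Note: averitec_search now derives the evidence cut-off from the current date instead of the check_date argument, which is still accepted but no longer used; check_claim_date (defined elsewhere in app.py) is assumed to take a "%Y-%m-%d" string. A minimal sketch of the new date handling:

    from datetime import datetime

    # The claim-check date now tracks "today" rather than a hard-coded default.
    current_date = datetime.now().strftime("%Y-%m-%d")
    print(current_date)  # e.g. "2024-07-01"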
@@ -643,6 +651,7 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-01-01", n_pages=1):
     for page_num in range(n_pages):
         search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
                                                    this_search_string, page=page_num)
+        search_results = search_results[:5]
 
         for result in search_results:
             link = str(result["link"])
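Note: each Google results page is now truncated to its first five hits before scraping, bounding the per-question work. A self-contained sketch of the effect (the URLs are placeholders):

    search_results = [{"link": f"https://example.com/{n}"} for n in range(10)]
    search_results = search_results[:5]  # keep at most 5 hits per page, as added above
    assert len(search_results) == 5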
@@ -733,21 +742,21 @@ def generate_step2_reference_corpus(reference_file):
 
 def decorate_with_questions(claim, retrieve_evidence, top_k=10):  # top_k=100
     #
-    reference_file = "averitec_code/data/train.json"
+    reference_file = "averitec/data/train.json"
     tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
     prompt_bm25 = BM25Okapi(tokenized_corpus)
 
     # Define the bloom model:
     accelerator = Accelerator()
     accel_device = accelerator.device
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
-    model = BloomForCausalLM.from_pretrained(
-        "bigscience/bloom-7b1",
-        device_map="auto",
-        torch_dtype=torch.bfloat16,
-        offload_folder="./offload"
-    )
+    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    # tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
+    # model = BloomForCausalLM.from_pretrained(
+    #     "bigscience/bloom-7b1",
+    #     device_map="auto",
+    #     torch_dtype=torch.bfloat16,
+    #     offload_folder="./offload"
+    # )
 
     #
     tokenized_corpus = []
@@ -795,11 +804,11 @@ def decorate_with_questions(claim, retrieve_evidence, top_k=10):
     prompt = "\n\n".join(prompt_docs + [claim_prompt])
     sentences = [prompt]
 
-    inputs = tokenizer(sentences, padding=True, return_tensors="pt").to(device)
-    outputs = model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2,
+    inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
+    outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2,
                              early_stopping=True)
 
-    tgt_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
+    tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
     # We are not allowed to generate more than 250 characters:
     tgt_text = tgt_text[:250]
 
@@ -815,13 +824,13 @@ def triple_to_string(x):
 
 def rerank_questions(claim, bm25_qas, topk=3):
     #
-    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-    bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2,
-                                                               problem_type="single_label_classification")  # Must specify single_label for some reason
-    best_checkpoint = "averitec_code/pretrained_models/bert_dual_encoder.ckpt"
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer, model=bert_model).to(
-        device)
+    # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+    # bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2,
+    #                                                            problem_type="single_label_classification")  # Must specify single_label for some reason
+    # best_checkpoint = "averitec/pretrained_models/bert_dual_encoder.ckpt"
+    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    # trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer, model=bert_model).to(
+    #     device)
 
     #
     strs_to_score = []
@@ -834,13 +843,13 @@ def rerank_questions(claim, bm25_qas, topk=3):
         values.append([question, answer, source])
 
     if len(bm25_qas) > 0:
-        encoded_dict = tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True,
+        encoded_dict = rerank_tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True,
                                  return_tensors="pt").to(device)
 
         input_ids = encoded_dict['input_ids']
         attention_masks = encoded_dict['attention_mask']
 
-        scores = torch.softmax(trained_model(input_ids, attention_mask=attention_masks).logits, axis=-1)[:, 1]
+        scores = torch.softmax(rerank_trained_model(input_ids, attention_mask=attention_masks).logits, axis=-1)[:, 1]
 
         top_n = torch.argsort(scores, descending=True)[:topk]
         pass_through = [{"question": values[i][0], "answers": values[i][1], "source_url": values[i][2]} for i in top_n]
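Note: scoring now runs through the module-level rerank_trained_model loaded at the top of the file instead of reloading the dual encoder on every call. A self-contained sketch of the top-k selection it performs, with dummy logits in place of the model output:

    import torch

    logits = torch.tensor([[0.2, 1.5], [2.0, -1.0], [0.0, 0.3]])  # dummy [N, 2] model output
    scores = torch.softmax(logits, dim=-1)[:, 1]                  # P(relevant) per QA pair
    top_n = torch.argsort(scores, descending=True)[:2]            # indices of the top-k pairs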
@@ -852,20 +861,16 @@ def rerank_questions(claim, bm25_qas, topk=3):
     return top3_qa_pairs
 
 
-def GoogleAPIretriever(query):
+def Googleretriever(query, sources):
     # ----- Generate QA pairs using AVeriTeC
-    top3_qa_pairs_path = "averitec_code/top3_qa_pairs1.json"
-    if not os.path.exists(top3_qa_pairs_path):
-        # step 1: generate questions for the query/claim using Bloom
-        generate_question = prompt_question_generation(query)
-        # step 2: retrieve evidence for the generated questions using Google API
-        retrieve_evidence = averitec_search(query, generate_question)
-        # step 3: generate QA pairs for each retrieved document
-        bm25_qa_pairs = decorate_with_questions(query, retrieve_evidence)
-        # step 4: rerank QA pairs
-        top3_qa_pairs = rerank_questions(query, bm25_qa_pairs)
-    else:
-        top3_qa_pairs = json.load(open(top3_qa_pairs_path, 'r'))
+    # step 1: generate questions for the query/claim using Bloom
+    generate_question = prompt_question_generation(query)
+    # step 2: retrieve evidence for the generated questions using Google API
+    retrieve_evidence = averitec_search(query, generate_question)
+    # step 3: generate QA pairs for each retrieved document
+    bm25_qa_pairs = decorate_with_questions(query, retrieve_evidence)
+    # step 4: rerank QA pairs
+    top3_qa_pairs = rerank_questions(query, bm25_qa_pairs)
 
     # Add score to metadata
     results = []
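Note: GoogleAPIretriever is renamed to Googleretriever, gains a sources argument matching Wikipediaretriever, and drops the cached top3_qa_pairs1.json shortcut, so the four-step pipeline always runs. A hypothetical call, matching how chat() invokes it below (the claim text is a placeholder):

    evidence = Googleretriever("Example claim to check.", sources=["Google"])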
@@ -877,12 +882,14 @@ def GoogleAPIretriever(query):
         metadata['cached_source_url'] = qa['source_url']
         metadata['short_name'] = "Evidence {}".format(i + 1)
         metadata['page_number'] = ""
+        metadata['title'] = qa['question']
+        metadata['evidence'] = qa['answers']
         metadata['query'] = qa['question']
         metadata['answer'] = qa['answers']
         metadata['page_content'] = "<b>Question</b>: " + qa['question'] + "<br>" + "<b>Answer</b>: " + qa['answers']
         page_content = f"""{metadata['page_content']}"""
-        results.append((metadata, page_content))
 
+        results.append(Docs(metadata, page_content))
     return results
 
 
@@ -1181,11 +1188,8 @@ def log_on_azure(file, logs, azure_share_client):
 
 def chat(claim, history, sources):
     evidence = []
-    # if 'Google' in sources:
-    #     evidence = GoogleAPIretriever(query)
-
-    # if 'WikiPediaDumps' in sources:
-    #     evidence = WikipediaDumpsretriever(query)
+    if 'Google' in sources:
+        evidence = Googleretriever(claim, sources)
 
     if 'WikiPedia' in sources:
         evidence = Wikipediaretriever(claim, sources)
@@ -1212,7 +1216,8 @@ def chat(claim, history, sources):
     for evi in evidence:
         title_str = evi.metadata['title']
         evi_str = evi.metadata['evidence']
-        evi_list.append([title_str, evi_str])
+        url_str = evi.metadata['url']
+        evi_list.append([title_str, evi_str, url_str])
 
     try:
         # Log answer on Azure Blob Storage
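Note: chat() now logs a [title, evidence, url] triple per document, so every retriever's evidence objects must expose those metadata keys (Googleretriever sets 'title' and 'evidence' above). A minimal sketch of the duck-typed interface assumed here; the real Docs class is defined elsewhere in app.py:

    class Docs:
        """Sketch only: the interface chat() relies on."""
        def __init__(self, metadata: dict, page_content: str = ""):
            self.metadata = metadata          # must carry 'title', 'evidence', 'url'
            self.page_content = page_content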
@@ -1226,7 +1231,6 @@ def chat(claim, history, sources):
             "claim": claim,
             "sources": sources,
             "evidence": evi_list,
-            "url": url_of_evidence,
             "answer": answer_output,
             "time": timestamp,
         }
@@ -1254,7 +1258,7 @@ def main():
         chatbot = gr.Chatbot(
             value=[(None, init_prompt)],
             show_copy_button=True, show_label=False, elem_id="chatbot", layout="panel",
-            avatar_images=(None, "assets/averitec.png")
+            avatar_images = (None, "assets/averitec.png")
         )  # avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"),
 
         with gr.Row(elem_id="input-message"):