Spaces: Build error
zhenyundeng committed
Commit · 200e5b6 · 1 Parent(s): 8a7fa89
update app.py
app.py CHANGED

@@ -7,6 +7,7 @@ import tqdm
 import torch
 import numpy as np
 from time import sleep
+from datetime import datetime
 import threading
 import gc
 import os
@@ -34,7 +35,6 @@ from averitec.data.sample_claims import CLAIMS_Type
 from utils import create_user_id
 user_id = create_user_id()
 
-from datetime import datetime
 from azure.storage.fileshare import ShareServiceClient
 try:
     from dotenv import load_dotenv
@@ -86,8 +86,21 @@ LABEL = [
     "Not Enough Evidence",
     "Conflicting Evidence/Cherrypicking",
 ]
-
+
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
+# Question Generation
+qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-1b1")
+qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-1b1", torch_dtype=torch.bfloat16).to(device)
+# qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
+# qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
+
+# rerank
+rerank_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+rereank_bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2, problem_type="single_label_classification")  # Must specify single_label for some reason
+best_checkpoint = "averitec/pretrained_models/bert_dual_encoder.ckpt"
+rerank_trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=rerank_tokenizer, model=rereank_bert_model).to(device)
+
+# Veracity
 veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
 bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
 veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
@@ -99,7 +112,6 @@ best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=
 justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
 # ---------------------------------------------------------------------------
 
-
 # Set up Gradio Theme
 theme = gr.themes.Base(
     primary_hue="blue",
@@ -182,7 +194,7 @@ class SequenceClassificationDataLoader(pl.LightningDataModule):
     )
 
 
-def
+def google_veracity_prediction(claim, qa_evidence):
     bert_model_name = "bert-base-uncased"
     tokenizer = BertTokenizer.from_pretrained(bert_model_name)
     bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4,
@@ -340,7 +352,7 @@ def extract_claim_str(claim, qa_evidence, verdict_label):
     return claim_str
 
 
-def
+def google_justification_generation(claim, qa_evidence, verdict_label):
     #
     claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
     claim_str.strip()
@@ -383,21 +395,19 @@ def QAprediction(claim, evidence, sources):
     for i, evi in enumerate(evidence, 1):
         part = f"""<span>Doc {i}</span>"""
         subpart = f"""<a href="#doc{i}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{i}</sup></span></a>"""
-        # subpart = f"""<span class='doc-ref'>{i}</sup></span>"""
         subparts = "".join([part, subpart])
         parts.append(subparts)
 
     evidence_part = ", ".join(parts)
 
     prediction_title = f"""<h5>Prediction:</h5>"""
-    # if 'Google' in sources
-    #     verdict_label =
-    #     justification_label =
-    #     # justification_label = "See retrieved docs."
+    # if 'Google' in sources:
+    #     verdict_label = google_veracity_prediction(claim, evidence)
+    #     justification_label = google_justification_generation(claim, evidence, verdict_label)
    #     justification_part = f"""<span>Justification: {justification_label}</span>"""
    # if 'WikiPedia' in sources:
-    #
-    #     justification_label =
+    #     verdict_label = wikipedia_veracity_prediction(claim, evidence)
+    #     justification_label = wikipedia_justification_generation(claim, evidence, verdict_label)
    #     # justification_label = "See retrieved docs."
    #     justification_part = f"""<span>Justification: {justification_label}</span>"""
 
@@ -406,11 +416,8 @@ def QAprediction(claim, evidence, sources):
     # justification_label = "See retrieved docs."
     justification_part = f"""<span>Justification: {justification_label}</span>"""
 
-
     verdict_part = f"""Verdict: <span>{verdict_label}.</span><br>"""
-
     content_parts = "".join([evidence_title, evidence_part, prediction_title, verdict_part, justification_part])
-    # content_parts = "".join([evidence_title, evidence_part, verdict_title, verdict_part, justification_title, justification_part])
 
     return content_parts, [verdict_label, justification_label]
 
@@ -418,8 +425,8 @@ def QAprediction(claim, evidence, sources):
 # ----------GoogleAPIretriever---------
 def generate_reference_corpus(reference_file):
     with open(reference_file) as f:
-
-        train_examples =
+        #
+        train_examples = json.load(f)
 
     all_data_corpus = []
     tokenized_corpus = []
@@ -456,16 +463,16 @@ def docs2prompt(top_docs):
 
 def prompt_question_generation(test_claim, speaker="they", topk=10):
     #
-    reference_file = "
+    reference_file = "averitec/data/train.json"
     tokenized_corpus, all_data_corpus = generate_reference_corpus(reference_file)
     bm25 = BM25Okapi(tokenized_corpus)
 
     # Define the bloom model:
     accelerator = Accelerator()
-    accel_device = accelerator.device
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
-    model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
+    # accel_device = accelerator.device
+    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    # tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
+    # model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
 
     # --------------------------------------------------
     # test claim
@@ -478,11 +485,11 @@ def prompt_question_generation(test_claim, speaker="they", topk=10):
         "\". Criticism includes questions like: "
     sentences = [prompt]
 
-    inputs =
-    outputs =
+    inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
+    outputs = qg_model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2,
                                 early_stopping=True)
 
-    tgt_text =
+    tgt_text = qg_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
     in_len = len(sentences[0])
     questions_str = tgt_text[in_len:].split("\n")[0]
 
@@ -592,7 +599,7 @@ def get_google_search_results(api_key, search_engine_id, google_search, sort_dat
     return search_results
 
 
-def averitec_search(claim, generate_question, speaker="they", check_date="2024-
+def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):  # n_pages=3
     # default config
     api_key = os.environ["GOOGLE_API_KEY"]
     search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]
@@ -612,13 +619,14 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-0
     ]
 
     # save to folder
-    store_folder = "
+    store_folder = "averitec/data/store/retrieved_docs"
     #
     index = 0
     questions = [q["question"] for q in generate_question]
 
     # check the date of the claim
-
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    sort_date = check_claim_date(current_date)  # check_date="2022-01-01"
 
     #
     search_strings = []
@@ -643,6 +651,7 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-0
     for page_num in range(n_pages):
         search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
                                                    this_search_string, page=page_num)
+        search_results = search_results[:5]
 
         for result in search_results:
             link = str(result["link"])
@@ -733,21 +742,21 @@ def generate_step2_reference_corpus(reference_file):
 
 def decorate_with_questions(claim, retrieve_evidence, top_k=10):  # top_k=100
     #
-    reference_file = "
+    reference_file = "averitec/data/train.json"
     tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
     prompt_bm25 = BM25Okapi(tokenized_corpus)
 
     # Define the bloom model:
     accelerator = Accelerator()
     accel_device = accelerator.device
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
-    model = BloomForCausalLM.from_pretrained(
-        "bigscience/bloom-7b1",
-        device_map="auto",
-        torch_dtype=torch.bfloat16,
-        offload_folder="./offload"
-    )
+    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    # tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
+    # model = BloomForCausalLM.from_pretrained(
+    #     "bigscience/bloom-7b1",
+    #     device_map="auto",
+    #     torch_dtype=torch.bfloat16,
+    #     offload_folder="./offload"
+    # )
 
     #
     tokenized_corpus = []
@@ -795,11 +804,11 @@ def decorate_with_questions(claim, retrieve_evidence, top_k=10): # top_k=100
     prompt = "\n\n".join(prompt_docs + [claim_prompt])
     sentences = [prompt]
 
-    inputs =
-    outputs =
+    inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
+    outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2,
                                 early_stopping=True)
 
-    tgt_text =
+    tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
     # We are not allowed to generate more than 250 characters:
     tgt_text = tgt_text[:250]
 
@@ -815,13 +824,13 @@ def triple_to_string(x):
 
 def rerank_questions(claim, bm25_qas, topk=3):
     #
-    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-    bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2,
-                                                               problem_type="single_label_classification")  # Must specify single_label for some reason
-    best_checkpoint = "averitec/pretrained_models/bert_dual_encoder.ckpt"
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer, model=bert_model).to(
-        device)
+    # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+    # bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2,
+    #                                                            problem_type="single_label_classification")  # Must specify single_label for some reason
+    # best_checkpoint = "averitec/pretrained_models/bert_dual_encoder.ckpt"
+    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    # trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer, model=bert_model).to(
+    #     device)
 
     #
     strs_to_score = []
@@ -834,13 +843,13 @@ def rerank_questions(claim, bm25_qas, topk=3):
         values.append([question, answer, source])
 
     if len(bm25_qas) > 0:
-        encoded_dict =
+        encoded_dict = rerank_tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True,
                                         return_tensors="pt").to(device)
 
         input_ids = encoded_dict['input_ids']
         attention_masks = encoded_dict['attention_mask']
 
-        scores = torch.softmax(
+        scores = torch.softmax(rerank_trained_model(input_ids, attention_mask=attention_masks).logits, axis=-1)[:, 1]
 
         top_n = torch.argsort(scores, descending=True)[:topk]
         pass_through = [{"question": values[i][0], "answers": values[i][1], "source_url": values[i][2]} for i in top_n]
@@ -852,20 +861,16 @@ def rerank_questions(claim, bm25_qas, topk=3):
     return top3_qa_pairs
 
 
-def
+def Googleretriever(query, sources):
     # ----- Generate QA pairs using AVeriTeC
-
-
-
-
-
-
-
-    # step 4: rerank QA pairs
-        top3_qa_pairs = rerank_questions(query, bm25_qa_pairs)
-    else:
-
-        top3_qa_pairs = json.load(open(top3_qa_pairs_path, 'r'))
+    # step 1: generate questions for the query/claim using Bloom
+    generate_question = prompt_question_generation(query)
+    # step 2: retrieve evidence for the generated questions using Google API
+    retrieve_evidence = averitec_search(query, generate_question)
+    # step 3: generate QA pairs for each retrieved document
+    bm25_qa_pairs = decorate_with_questions(query, retrieve_evidence)
+    # step 4: rerank QA pairs
+    top3_qa_pairs = rerank_questions(query, bm25_qa_pairs)
 
     # Add score to metadata
     results = []
@@ -877,12 +882,14 @@ def GoogleAPIretriever(query):
         metadata['cached_source_url'] = qa['source_url']
         metadata['short_name'] = "Evidence {}".format(i + 1)
         metadata['page_number'] = ""
+        metadata['title'] = qa['question']
+        metadata['evidence'] = qa['answers']
         metadata['query'] = qa['question']
         metadata['answer'] = qa['answers']
         metadata['page_content'] = "<b>Question</b>: " + qa['question'] + "<br>" + "<b>Answer</b>: " + qa['answers']
         page_content = f"""{metadata['page_content']}"""
-        results.append((metadata, page_content))
 
+        results.append(Docs(metadata, page_content))
     return results
 
 
@@ -1181,11 +1188,8 @@ def log_on_azure(file, logs, azure_share_client):
 
 def chat(claim, history, sources):
     evidence = []
-
-
-
-    # if 'WikiPediaDumps' in sources:
-    #     evidence = WikipediaDumpsretriever(query)
+    if 'Google' in sources:
+        evidence = Googleretriever(claim, sources)
 
     if 'WikiPedia' in sources:
         evidence = Wikipediaretriever(claim, sources)
@@ -1212,7 +1216,8 @@ def chat(claim, history, sources):
     for evi in evidence:
         title_str = evi.metadata['title']
         evi_str = evi.metadata['evidence']
-
+        url_str = evi.metadata['url']
+        evi_list.append([title_str, evi_str, url_str])
 
     try:
         # Log answer on Azure Blob Storage
@@ -1226,7 +1231,6 @@ def chat(claim, history, sources):
                 "claim": claim,
                 "sources": sources,
                 "evidence": evi_list,
-                "url": url_of_evidence,
                 "answer": answer_output,
                 "time": timestamp,
             }
@@ -1254,7 +1258,7 @@ def main():
         chatbot = gr.Chatbot(
             value=[(None, init_prompt)],
             show_copy_button=True, show_label=False, elem_id="chatbot", layout="panel",
-            avatar_images=(None, "assets/averitec.png")
+            avatar_images = (None, "assets/averitec.png")
         )  # avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"),
 
         with gr.Row(elem_id="input-message"):
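
Taken together, this commit wires a Google-based evidence retrieval path into the app: the Bloom question generator and the BERT dual-encoder reranker are now loaded once at module level, and Googleretriever() chains four steps at query time. A minimal sketch of that flow, assuming the helpers and module-level models defined in app.py are in scope; the claim string and the print loop are illustrative only, not part of the committed code:

# Sketch of the Googleretriever() flow introduced in this commit.
claim = "An example claim to fact-check"              # illustrative input

questions = prompt_question_generation(claim)          # step 1: Bloom generates critical questions
docs = averitec_search(claim, questions)               # step 2: Google API retrieval for those questions
qa_pairs = decorate_with_questions(claim, docs)        # step 3: QA pairs for each retrieved document
top3 = rerank_questions(claim, qa_pairs)                # step 4: BERT dual encoder keeps the top 3 pairs

for qa in top3:                                         # each entry carries question, answers, source_url
    print(qa["question"], "->", qa["answers"], qa["source_url"])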
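
The question generator itself is the module-level qg_tokenizer / qg_model pair (bloom-1b1), used in both prompt_question_generation() and decorate_with_questions(). A rough single-prompt sketch following the generate/decode calls in the diff; the prompt text below is a stand-in, since the real prompt is assembled from BM25-retrieved AVeriTeC training examples:

# Sketch only: one call to the module-level Bloom generator from this commit.
prompt = "Claim: <example claim>. Criticism includes questions like: "   # hypothetical prompt text
inputs = qg_tokenizer([prompt], padding=True, return_tensors="pt").to(device)
outputs = qg_model.generate(inputs["input_ids"], max_length=2000, num_beams=2,
                            no_repeat_ngram_size=2, early_stopping=True)
decoded = qg_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
generated_questions = decoded[len(prompt):].split("\n")[0]   # keep only the model's continuation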