#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Created by zd302 at 08/07/2024
import gradio as gr
import tqdm
import torch
import numpy as np
from time import sleep
import threading
import gc
import os
import json
import pytorch_lightning as pl
from urllib.parse import urlparse
from accelerate import Accelerator
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from rank_bm25 import BM25Okapi
# import bm25s
# import Stemmer # optional: for stemming
from html2lines import url2lines
from googleapiclient.discovery import build
from averitec.models.DualEncoderModule import DualEncoderModule
from averitec.models.SequenceClassificationModule import SequenceClassificationModule
from averitec.models.JustificationGenerationModule import JustificationGenerationModule
from averitec.data.sample_claims import CLAIMS_Type
# ---------------------------------------------------------------------------
# load .env
from utils import create_user_id
user_id = create_user_id()
from datetime import datetime
from azure.storage.fileshare import ShareServiceClient
try:
from dotenv import load_dotenv
load_dotenv()
except Exception as e:
pass
account_url = os.environ["AZURE_ACCOUNT_URL"]
credential = {
"account_key": os.environ['AZURE_ACCOUNT_KEY'],
"account_name": os.environ['AZURE_ACCOUNT_NAME']
}
file_share_name = "averitec"
azure_service = ShareServiceClient(account_url=account_url, credential=credential)
azure_share_client = azure_service.get_share_client(file_share_name)
# ---------- Setting ----------
import requests
from bs4 import BeautifulSoup
import wikipediaapi
wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC ([email protected])', 'en')
import nltk
nltk.download('punkt')
from nltk import pos_tag, word_tokenize, sent_tokenize
import spacy
os.system("python -m spacy download en_core_web_sm")
nlp = spacy.load("en_core_web_sm")
# ---------------------------------------------------------------------------
# Load sample dict for AVeriTeC search
# all_samples_dict = json.load(open('averitec/data/all_samples.json', 'r'))
# ---------------------------------------------------------------------------
# ---------- Load pretrained models ----------
# ---------- load Evidence retrieval model ----------
# from drqa import retriever
# db_class = retriever.get_class('sqlite')
# doc_db = db_class("averitec/data/wikipedia_dumps/enwiki.db")
# ranker = retriever.get_class('tfidf')(tfidf_path="averitec/data/wikipedia_dumps/enwiki-tfidf-with-id-title.npz")
# ---------- Load Veracity and Justification prediction model ----------
print("Loading models ...")
LABEL = [
"Supported",
"Refuted",
"Not Enough Evidence",
"Conflicting Evidence/Cherrypicking",
]
# Veracity
device = "cuda:0" if torch.cuda.is_available() else "cpu"
veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
tokenizer=veracity_tokenizer, model=bert_model).to(device)
# Justification
justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
# ---------------------------------------------------------------------------
# Set up Gradio Theme
theme = gr.themes.Base(
primary_hue="blue",
secondary_hue="red",
font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)
# ---------- Setting ----------
class Docs:
    def __init__(self, metadata=None, page_content=""):
        # Avoid sharing a mutable default dict between instances.
        self.metadata = metadata if metadata is not None else {}
        self.page_content = page_content
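# Render one retrieved evidence document as an HTML card for the "Sources" panel.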
def make_html_source(source, i):
meta = source.metadata
content = source.page_content.strip()
card = f"""
<div class="card" id="doc{i}">
<div class="card-content">
<h2>Doc {i} - URL: <a href="{meta['url']}" target="_blank" class="pdf-link">{meta['url']}</a></h2>
<p>{content}</p>
</div>
<div class="card-footer">
<span>CACHED SOURCE URL:</span>
<a href="{meta['cached_source_url']}" target="_blank" class="pdf-link">
<span role="img" aria-label="Open PDF">🔗</span>
</a>
</div>
</div>
"""
return card
# ----- veracity_prediction -----
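# Lightning data module that turns (claim, question, answer) evidence strings into tokenized BERT inputs.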
class SequenceClassificationDataLoader(pl.LightningDataModule):
def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
super().__init__()
self.tokenizer = tokenizer
self.data_file = data_file
self.batch_size = batch_size
self.add_extra_nee = add_extra_nee
def tokenize_strings(
self,
source_sentences,
max_length=400,
pad_to_max_length=False,
return_tensors="pt",
):
encoded_dict = self.tokenizer(
source_sentences,
max_length=max_length,
padding="max_length" if pad_to_max_length else "longest",
truncation=True,
return_tensors=return_tensors,
)
input_ids = encoded_dict["input_ids"]
attention_masks = encoded_dict["attention_mask"]
return input_ids, attention_masks
def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
if bool_explanation is not None and len(bool_explanation) > 0:
bool_explanation = ", because " + bool_explanation.lower().strip()
else:
bool_explanation = ""
return (
"[CLAIM] "
+ claim.strip()
+ " [QUESTION] "
+ question.strip()
+ " "
+ answer.strip()
+ bool_explanation
)
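# Legacy variant of veracity_prediction that reloads the BERT checkpoint on every call;
# QAprediction below uses the globally loaded veracity_model instead.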
def averitec_veracity_prediction(claim, qa_evidence):
bert_model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4,
problem_type="single_label_classification")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
trained_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
tokenizer=tokenizer, model=bert_model).to(device)
dataLoader = SequenceClassificationDataLoader(
tokenizer=tokenizer,
data_file="this_is_discontinued",
batch_size=32,
add_extra_nee=False,
)
evidence_strings = []
for evidence in qa_evidence:
evidence_strings.append(
dataLoader.quadruple_to_string(claim, evidence.metadata["query"], evidence.metadata["answer"], ""))
if len(evidence_strings) == 0: # If we found no evidence e.g. because google returned 0 pages, just output NEI.
pred_label = "Not Enough Evidence"
return pred_label
tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
example_support = torch.argmax(
trained_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
has_unanswerable = False
has_true = False
has_false = False
for v in example_support:
if v == 0:
has_true = True
if v == 1:
has_false = True
        if v in (2, 3,):  # TODO another hack -- we can't have different labels for train and test, so we do this
has_unanswerable = True
if has_unanswerable:
answer = 2
elif has_true and not has_false:
answer = 0
elif not has_true and has_false:
answer = 1
else:
answer = 3
pred_label = LABEL[answer]
return pred_label
def fever_veracity_prediction(claim, evidence):
tokenizer = RobertaTokenizer.from_pretrained('Dzeniks/roberta-fact-check')
model = RobertaForSequenceClassification.from_pretrained('Dzeniks/roberta-fact-check')
evidence_string = ""
for evi in evidence:
evidence_string += evi.metadata['title'] + evi.metadata['evidence'] + ' '
input_sequence = tokenizer.encode_plus(claim, evidence_string, return_tensors="pt")
with torch.no_grad():
prediction = model(**input_sequence)
label = torch.argmax(prediction[0]).item()
pred_label = LABEL[label]
return pred_label
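# Predict one of the four AVeriTeC verdicts: each QA evidence pair is scored with the
# fine-tuned BERT model and the per-evidence labels are aggregated into a single verdict.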
def veracity_prediction(claim, qa_evidence):
# bert_model_name = "bert-base-uncased"
# tokenizer = BertTokenizer.from_pretrained(bert_model_name)
# bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4,
# problem_type="single_label_classification")
#
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
# trained_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
# tokenizer=tokenizer, model=bert_model).to(device)
dataLoader = SequenceClassificationDataLoader(
tokenizer=veracity_tokenizer,
data_file="this_is_discontinued",
batch_size=32,
add_extra_nee=False,
)
evidence_strings = []
for evidence in qa_evidence:
evidence_strings.append(
dataLoader.quadruple_to_string(claim, evidence.metadata["query"], evidence.metadata["answer"], ""))
if len(evidence_strings) == 0: # If we found no evidence e.g. because google returned 0 pages, just output NEI.
pred_label = "Not Enough Evidence"
return pred_label
tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
has_unanswerable = False
has_true = False
has_false = False
for v in example_support:
if v == 0:
has_true = True
if v == 1:
has_false = True
        if v in (2, 3,):  # TODO another hack -- we can't have different labels for train and test, so we do this
has_unanswerable = True
if has_unanswerable:
answer = 2
elif has_true and not has_false:
answer = 0
elif not has_true and has_false:
answer = 1
else:
answer = 3
pred_label = LABEL[answer]
return pred_label
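# Build the "[CLAIM] ... [EVIDENCE] ... [VERDICT] ..." input string for the justification generator.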
def extract_claim_str(claim, qa_evidence, verdict_label):
claim_str = "[CLAIM] " + claim + " [EVIDENCE] "
for evidence in qa_evidence:
q_text = evidence.metadata['query'].strip()
if len(q_text) == 0:
continue
if not q_text[-1] == "?":
q_text += "?"
answer_strings = []
answer_strings.append(evidence.metadata['answer'])
claim_str += q_text
for a_text in answer_strings:
if a_text:
if not a_text[-1] == ".":
a_text += "."
claim_str += " " + a_text.strip()
claim_str += " "
claim_str += " [VERDICT] " + verdict_label
return claim_str
def averitec_justification_generation(claim, qa_evidence, verdict_label):
#
claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
    claim_str = claim_str.strip()
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
trained_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer,
model=bart_model).to(device)
pred_justification = trained_model.generate(claim_str, device=device)
return pred_justification.strip()
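# Generate a natural-language justification with the globally loaded BART model,
# conditioned on the claim, the QA evidence and the predicted verdict.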
def justification_generation(claim, qa_evidence, verdict_label):
#
claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
    claim_str = claim_str.strip()
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
# tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
# bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
#
# best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
# trained_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer,
# model=bart_model).to(device)
pred_justification = justification_model.generate(claim_str, device=device)
return pred_justification.strip()
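# Run veracity prediction and justification generation and assemble the HTML snippet shown in the chat window.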
def QAprediction(claim, evidence, sources):
parts = []
#
evidence_title = f"""<h5>Retrieved Evidence:</h5>"""
for i, evi in enumerate(evidence, 1):
part = f"""<span>Doc {i}</span>"""
subpart = f"""<a href="#doc{i}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{i}</sup></span></a>"""
# subpart = f"""<span class='doc-ref'>{i}</sup></span>"""
subparts = "".join([part, subpart])
parts.append(subparts)
evidence_part = ", ".join(parts)
prediction_title = f"""<h5>Prediction:</h5>"""
# if 'Google' in sources or 'AVeriTeC' in sources:
# verdict_label = averitec_veracity_prediction(claim, evidence)
# justification_label = averitec_justification_generation(claim, evidence, verdict_label)
# # justification_label = "See retrieved docs."
# justification_part = f"""<span>Justification: {justification_label}</span>"""
# if 'WikiPedia' in sources:
# # verdict_label = fever_veracity_prediction(claim, evidence)
# justification_label = averitec_justification_generation(claim, evidence, verdict_label)
# # justification_label = "See retrieved docs."
# justification_part = f"""<span>Justification: {justification_label}</span>"""
verdict_label = veracity_prediction(claim, evidence)
justification_label = justification_generation(claim, evidence, verdict_label)
# justification_label = "See retrieved docs."
justification_part = f"""<span>Justification: {justification_label}</span>"""
verdict_part = f"""Verdict: <span>{verdict_label}.</span><br>"""
content_parts = "".join([evidence_title, evidence_part, prediction_title, verdict_part, justification_part])
# content_parts = "".join([evidence_title, evidence_part, verdict_title, verdict_part, justification_title, justification_part])
return content_parts, [verdict_label, justification_label]
# ----------GoogleAPIretriever---------
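# Load the AVeriTeC training file and tokenize its claims so similar examples can be retrieved with BM25.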
def generate_reference_corpus(reference_file):
with open(reference_file) as f:
j = json.load(f)
train_examples = j
all_data_corpus = []
tokenized_corpus = []
for train_example in train_examples:
train_claim = train_example["claim"]
speaker = train_example["speaker"].strip() if train_example["speaker"] is not None and len(
train_example["speaker"]) > 1 else "they"
questions = [q["question"] for q in train_example["questions"]]
claim_dict_builder = {}
claim_dict_builder["claim"] = train_claim
claim_dict_builder["speaker"] = speaker
claim_dict_builder["questions"] = questions
tokenized_corpus.append(nltk.word_tokenize(claim_dict_builder["claim"]))
all_data_corpus.append(claim_dict_builder)
return tokenized_corpus, all_data_corpus
def doc2prompt(doc):
prompt_parts = "Outrageously, " + doc["speaker"] + " claimed that \"" + doc[
"claim"].strip() + "\". Criticism includes questions like: "
questions = [q.strip() for q in doc["questions"]]
return prompt_parts + " ".join(questions)
def docs2prompt(top_docs):
return "\n\n".join([doc2prompt(d) for d in top_docs])
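# Few-shot question generation: retrieve similar training claims with BM25 and prompt
# BLOOM-7B1 to generate fact-checking questions for the test claim.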
def prompt_question_generation(test_claim, speaker="they", topk=10):
#
reference_file = "averitec_code/data/train.json"
tokenized_corpus, all_data_corpus = generate_reference_corpus(reference_file)
bm25 = BM25Okapi(tokenized_corpus)
# Define the bloom model:
accelerator = Accelerator()
accel_device = accelerator.device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
# --------------------------------------------------
# test claim
s = bm25.get_scores(nltk.word_tokenize(test_claim))
top_n = np.argsort(s)[::-1][:topk]
docs = [all_data_corpus[i] for i in top_n]
# --------------------------------------------------
prompt = docs2prompt(docs) + "\n\n" + "Outrageously, " + speaker + " claimed that \"" + test_claim.strip() + \
"\". Criticism includes questions like: "
sentences = [prompt]
inputs = tokenizer(sentences, padding=True, return_tensors="pt").to(device)
outputs = model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2,
early_stopping=True)
tgt_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
in_len = len(sentences[0])
questions_str = tgt_text[in_len:].split("\n")[0]
qs = questions_str.split("?")
qs = [q.strip() + "?" for q in qs if q.strip() and len(q.strip()) < 300]
#
generate_question = [{"question": q, "answers": []} for q in qs]
return generate_question
def check_claim_date(check_date):
try:
year, month, date = check_date.split("-")
    except (ValueError, AttributeError):  # fall back to a default date if parsing fails
month, date, year = "01", "01", "2022"
if len(year) == 2 and int(year) <= 30:
year = "20" + year
elif len(year) == 2:
year = "19" + year
elif len(year) == 1:
year = "200" + year
if len(month) == 1:
month = "0" + month
if len(date) == 1:
date = "0" + date
sort_date = year + month + date
return sort_date
def string_to_search_query(text, author):
parts = word_tokenize(text.strip())
tags = pos_tag(parts)
keep_tags = ["CD", "JJ", "NN", "VB"]
if author is not None:
search_string = author.split()
else:
search_string = []
for token, tag in zip(parts, tags):
for keep_tag in keep_tags:
if tag[1].startswith(keep_tag):
search_string.append(token)
search_string = " ".join(search_string)
return search_string
def google_search(search_term, api_key, cse_id, **kwargs):
service = build("customsearch", "v1", developerKey=api_key)
res = service.cse().list(q=search_term, cx=cse_id, **kwargs).execute()
if "items" in res:
return res['items']
else:
return []
def get_domain_name(url):
if '://' not in url:
url = 'http://' + url
domain = urlparse(url).netloc
if domain.startswith("www."):
return domain[4:]
else:
return domain
def get_and_store(url_link, fp, worker, worker_stack):
page_lines = url2lines(url_link)
with open(fp, "w") as out_f:
print("\n".join([url_link] + page_lines), file=out_f)
worker_stack.append(worker)
gc.collect()
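# Query the Google Custom Search API (up to three attempts), restricting results to pages dated before the claim.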
def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0):
search_results = []
for i in range(3):
try:
search_results += google_search(
search_string,
api_key,
search_engine_id,
num=10,
start=0 + 10 * page,
sort="date:r:19000101:" + sort_date,
dateRestrict=None,
gl="US"
)
break
        except Exception:
sleep(3)
return search_results
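# Search Google for the claim and the generated questions, skip blacklisted domains and
# file types, and scrape each remaining hit to disk in a background thread.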
def averitec_search(claim, generate_question, speaker="they", check_date="2024-01-01", n_pages=1): # n_pages=3
# default config
api_key = os.environ["GOOGLE_API_KEY"]
search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]
blacklist = [
"jstor.org", # Blacklisted because their pdfs are not labelled as such, and clog up the download
"facebook.com", # Blacklisted because only post titles can be scraped, but the scraper doesn't know this,
"ftp.cs.princeton.edu", # Blacklisted because it hosts many large NLP corpora that keep showing up
"nlp.cs.princeton.edu",
"huggingface.co"
]
blacklist_files = [ # Blacklisted some NLP nonsense that crashes my machine with OOM errors
"/glove.",
"ftp://ftp.cs.princeton.edu/pub/cs226/autocomplete/words-333333.txt",
"https://web.mit.edu/adamrose/Public/googlelist",
]
# save to folder
store_folder = "averitec_code/store/retrieved_docs"
#
index = 0
questions = [q["question"] for q in generate_question]
# check the date of the claim
sort_date = check_claim_date(check_date) # check_date="2022-01-01"
#
search_strings = []
search_types = []
search_string_2 = string_to_search_query(claim, None)
search_strings += [search_string_2, claim, ]
search_types += ["claim", "claim-noformat", ]
search_strings += questions
search_types += ["question" for _ in questions]
# start to search
search_results = []
visited = {}
store_counter = 0
worker_stack = list(range(10))
retrieve_evidence = []
for this_search_string, this_search_type in zip(search_strings, search_types):
for page_num in range(n_pages):
search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
this_search_string, page=page_num)
for result in search_results:
link = str(result["link"])
domain = get_domain_name(link)
if domain in blacklist:
continue
broken = False
for b_file in blacklist_files:
if b_file in link:
broken = True
if broken:
continue
if link.endswith(".pdf") or link.endswith(".doc"):
continue
store_file_path = ""
if link in visited:
store_file_path = visited[link]
else:
store_counter += 1
store_file_path = store_folder + "/search_result_" + str(index) + "_" + str(
store_counter) + ".store"
visited[link] = store_file_path
                while len(worker_stack) == 0:  # Wait for a worker to become available. Check every second.
sleep(1)
worker = worker_stack.pop()
t = threading.Thread(target=get_and_store, args=(link, store_file_path, worker, worker_stack))
t.start()
line = [str(index), claim, link, str(page_num), this_search_string, this_search_type, store_file_path]
retrieve_evidence.append(line)
return retrieve_evidence
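# Yield (lookup string, prompt) pairs from one training example, used as few-shot prompts for question generation.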
def claim2prompts(example):
claim = example["claim"]
# claim_str = "Claim: " + claim + "||Evidence: "
claim_str = "Evidence: "
for question in example["questions"]:
q_text = question["question"].strip()
if len(q_text) == 0:
continue
if not q_text[-1] == "?":
q_text += "?"
answer_strings = []
for a in question["answers"]:
if a["answer_type"] in ["Extractive", "Abstractive"]:
answer_strings.append(a["answer"])
if a["answer_type"] == "Boolean":
answer_strings.append(a["answer"] + ", because " + a["boolean_explanation"].lower().strip())
for a_text in answer_strings:
if not a_text[-1] in [".", "!", ":", "?"]:
a_text += "."
# prompt_lookup_str = claim + " " + a_text
prompt_lookup_str = a_text
this_q_claim_str = claim_str + " " + a_text.strip() + "||Question answered: " + q_text
yield (prompt_lookup_str, this_q_claim_str.replace("\n", " ").replace("||", "\n"))
def generate_step2_reference_corpus(reference_file):
with open(reference_file) as f:
train_examples = json.load(f)
prompt_corpus = []
tokenized_corpus = []
for example in train_examples:
for lookup_str, prompt in claim2prompts(example):
entry = nltk.word_tokenize(lookup_str)
tokenized_corpus.append(entry)
prompt_corpus.append(prompt)
return tokenized_corpus, prompt_corpus
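# Rank the scraped passages against the claim with BM25 and prompt BLOOM to generate a question for each top passage.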
def decorate_with_questions(claim, retrieve_evidence, top_k=10): # top_k=100
#
reference_file = "averitec_code/data/train.json"
tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
prompt_bm25 = BM25Okapi(tokenized_corpus)
# Define the bloom model:
accelerator = Accelerator()
accel_device = accelerator.device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
model = BloomForCausalLM.from_pretrained(
"bigscience/bloom-7b1",
device_map="auto",
torch_dtype=torch.bfloat16,
offload_folder="./offload"
)
#
tokenized_corpus = []
all_data_corpus = []
for retri_evi in tqdm.tqdm(retrieve_evidence):
store_file = retri_evi[-1]
with open(store_file, 'r') as f:
first = True
for line in f:
line = line.strip()
if first:
first = False
location_url = line
continue
if len(line) > 3:
entry = nltk.word_tokenize(line)
if (location_url, line) not in all_data_corpus:
tokenized_corpus.append(entry)
all_data_corpus.append((location_url, line))
    if len(tokenized_corpus) == 0:
        print("No evidence documents were retrieved; skipping question generation.")
        return []
bm25 = BM25Okapi(tokenized_corpus)
s = bm25.get_scores(nltk.word_tokenize(claim))
top_n = np.argsort(s)[::-1][:top_k]
docs = [all_data_corpus[i] for i in top_n]
generate_qa_pairs = []
# Then, generate questions for those top 50:
for doc in tqdm.tqdm(docs):
# prompt_lookup_str = example["claim"] + " " + doc[1]
prompt_lookup_str = doc[1]
prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str))
prompt_n = 10
prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
prompt_docs = [prompt_corpus[i] for i in prompt_top_n]
claim_prompt = "Evidence: " + doc[1].replace("\n", " ") + "\nQuestion answered: "
prompt = "\n\n".join(prompt_docs + [claim_prompt])
sentences = [prompt]
inputs = tokenizer(sentences, padding=True, return_tensors="pt").to(device)
outputs = model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2,
early_stopping=True)
tgt_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
# We are not allowed to generate more than 250 characters:
tgt_text = tgt_text[:250]
qa_pair = [tgt_text.strip().split("?")[0].replace("\n", " ") + "?", doc[1].replace("\n", " "), doc[0]]
generate_qa_pairs.append(qa_pair)
return generate_qa_pairs
def triple_to_string(x):
return " </s> ".join([item.strip() for item in x])
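# Score (claim, question, answer) triples with the dual-encoder checkpoint and keep the top-k QA pairs.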
def rerank_questions(claim, bm25_qas, topk=3):
#
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2,
problem_type="single_label_classification") # Must specify single_label for some reason
best_checkpoint = "averitec_code/pretrained_models/bert_dual_encoder.ckpt"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer, model=bert_model).to(
device)
#
strs_to_score = []
values = []
for question, answer, source in bm25_qas:
str_to_score = triple_to_string([claim, question, answer])
strs_to_score.append(str_to_score)
values.append([question, answer, source])
if len(bm25_qas) > 0:
encoded_dict = tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True,
return_tensors="pt").to(device)
input_ids = encoded_dict['input_ids']
attention_masks = encoded_dict['attention_mask']
scores = torch.softmax(trained_model(input_ids, attention_mask=attention_masks).logits, axis=-1)[:, 1]
top_n = torch.argsort(scores, descending=True)[:topk]
pass_through = [{"question": values[i][0], "answers": values[i][1], "source_url": values[i][2]} for i in top_n]
else:
pass_through = []
top3_qa_pairs = pass_through
return top3_qa_pairs
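# Full Google-based evidence pipeline: question generation, web search, QA-pair generation and reranking.
# If a cached top-3 QA-pair file exists, it is loaded instead of rerunning the pipeline.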
def GoogleAPIretriever(query):
# ----- Generate QA pairs using AVeriTeC
top3_qa_pairs_path = "averitec_code/top3_qa_pairs1.json"
if not os.path.exists(top3_qa_pairs_path):
# step 1: generate questions for the query/claim using Bloom
generate_question = prompt_question_generation(query)
# step 2: retrieve evidence for the generated questions using Google API
retrieve_evidence = averitec_search(query, generate_question)
# step 3: generate QA pairs for each retrieved document
bm25_qa_pairs = decorate_with_questions(query, retrieve_evidence)
# step 4: rerank QA pairs
top3_qa_pairs = rerank_questions(query, bm25_qa_pairs)
else:
top3_qa_pairs = json.load(open(top3_qa_pairs_path, 'r'))
# Add score to metadata
results = []
for i, qa in enumerate(top3_qa_pairs):
metadata = dict()
metadata['name'] = qa['question']
metadata['url'] = qa['source_url']
metadata['cached_source_url'] = qa['source_url']
metadata['short_name'] = "Evidence {}".format(i + 1)
metadata['page_number'] = ""
metadata['query'] = qa['question']
metadata['answer'] = qa['answers']
metadata['page_content'] = "<b>Question</b>: " + qa['question'] + "<br>" + "<b>Answer</b>: " + qa['answers']
page_content = f"""{metadata['page_content']}"""
results.append((metadata, page_content))
return results
# ----------GoogleAPIretriever---------
# ----------Wikipediaretriever---------
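# Rank a pre-tokenized corpus against the query with BM25Okapi and return the top-k indices and scores.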
def bm25_retriever(query, corpus, topk=3):
bm25 = BM25Okapi(corpus)
#
query_tokens = word_tokenize(query)
scores = bm25.get_scores(query_tokens)
top_n = np.argsort(scores)[::-1][:topk]
top_n_scores = [scores[i] for i in top_n]
return top_n, top_n_scores
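# Alternative retriever based on the bm25s library; it requires the `bm25s` and `Stemmer`
# imports that are commented out at the top of this file.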
def bm25s_retriever(query, corpus, topk=3):
# optional: create a stemmer
stemmer = Stemmer.Stemmer("english")
# Tokenize the corpus and only keep the ids (faster and saves memory)
corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
# Create the BM25 model and index the corpus
retriever = bm25s.BM25()
retriever.index(corpus_tokens)
# Query the corpus
query_tokens = bm25s.tokenize(query, stemmer=stemmer)
# Get top-k results as a tuple of (doc ids, scores). Both are arrays of shape (n_queries, k)
results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=topk)
top_n = [corpus.index(res) for res in results[0]]
return top_n, scores
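# Dump-based retrieval: relies on the DrQA `ranker` and `doc_db` objects that are commented
# out above, so this path only works when that setup is enabled.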
def find_evidence_from_wikipedia_dumps(claim):
#
doc = nlp(claim)
entities_in_claim = [str(ent).lower() for ent in doc.ents]
title2id = ranker.doc_dict[0]
wiki_intro, ent_list = [], []
for ent in entities_in_claim:
if ent in title2id.keys():
ids = title2id[ent]
introduction = doc_db.get_doc_intro(ids)
wiki_intro.append([ent, introduction])
# fulltext = doc_db.get_doc_text(ids)
# evidence.append([ent, fulltext])
ent_list.append(ent)
if len(wiki_intro) < 5:
evidence_tfidf = process_topk(claim, title2id, ent_list, k=5)
wiki_intro.extend(evidence_tfidf)
return wiki_intro, doc
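# Split the retrieved page introductions into sentences and return the k sentences most relevant to the query under BM25.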
def relevant_sentence_retrieval(query, wiki_intro, k):
# 1. Create corpus here
corpus, sentences = [], []
titles = []
for i, (title, intro) in enumerate(wiki_intro):
sents_in_intro = sent_tokenize(intro)
for sent in sents_in_intro:
corpus.append(word_tokenize(sent))
sentences.append(sent)
titles.append(title)
#
# ----- BM25
bm25_top_n, bm25_top_n_scores = bm25_retriever(query, corpus, topk=k)
bm25_top_n_sents = [sentences[i] for i in bm25_top_n]
bm25_top_n_titles = [titles[i] for i in bm25_top_n]
# ----- BM25s
# bm25s_top_n, bm25s_top_n_scores = bm25s_retriever(query, sentences, topk=k) # corpus->sentences
# bm25s_top_n_sents = [sentences[i] for i in bm25s_top_n]
# bm25s_top_n_titles = [titles[i] for i in bm25s_top_n]
return bm25_top_n_sents, bm25_top_n_titles
def process_topk(query, title2id, ent_list, k=1):
doc_names, doc_scores = ranker.closest_docs(query, k)
evidence_tfidf = []
for _name in doc_names:
if _name not in ent_list and len(ent_list) < 5:
ent_list.append(_name)
idx = title2id[_name]
introduction = doc_db.get_doc_intro(idx)
evidence_tfidf.append([_name, introduction])
# fulltext = doc_db.get_doc_text(idx)
# evidence_tfidf.append([_name,fulltext])
return evidence_tfidf
def WikipediaDumpsretriever(claim):
#
# 1. extract relevant wikipedia pages from wikipedia dumps
wiki_intro, doc = find_evidence_from_wikipedia_dumps(claim)
# wiki_intro = [['trump', "'''Trump''' most commonly refers to:\n* Donald Trump (born 1946), President of the United States from 2017 to 2021 \n* Trump (card games), any playing card given an ad-hoc high rank\n\n'''Trump''' may also refer to:"]]
# 2. extract relevant sentences from extracted wikipedia pages
sents, titles = relevant_sentence_retrieval(claim, wiki_intro, k=3)
#
results = []
for i, (sent, title) in enumerate(zip(sents, titles)):
metadata = dict()
metadata['name'] = claim
metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
metadata['short_name'] = "Evidence {}".format(i + 1)
metadata['page_number'] = ""
metadata['query'] = sent
metadata['title'] = title
metadata['evidence'] = sent
metadata['answer'] = ""
metadata['page_content'] = "<b>Title</b>: " + str(metadata['title']) + "<br>" + "<b>Evidence</b>: " + metadata[
'evidence']
page_content = f"""{metadata['page_content']}"""
results.append(Docs(metadata, page_content))
return results
# ----------WikipediaAPIretriever---------
def clean_str(p):
return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")
def get_page_obs(page):
# find all paragraphs
paragraphs = page.split("\n")
paragraphs = [p.strip() for p in paragraphs if p.strip()]
# # find all sentence
# sentences = []
# for p in paragraphs:
# sentences += p.split('. ')
# sentences = [s.strip() + '.' for s in sentences if s.strip()]
# # return ' '.join(sentences[:5])
# return ' '.join(sentences)
return ' '.join(paragraphs[:5])
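# Fetch the summary of the entity's Wikipedia page via the Wikipedia API, if the page exists.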
def search_entity_wikipedia(entity):
find_evidence = []
page_py = wiki_wiki.page(entity)
if page_py.exists():
introduction = page_py.summary
find_evidence.append([str(entity), introduction])
return find_evidence
def search_step(entity):
ent_ = entity.replace(" ", "+")
search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}"
response_text = requests.get(search_url).text
soup = BeautifulSoup(response_text, features="html.parser")
result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
find_evidence = []
if result_divs: # mismatch
        # If the Wikipedia page for the entity does not exist, find similar Wikipedia pages.
result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
similar_titles = result_titles[:5]
for _t in similar_titles:
if len(find_evidence) < 5:
_evi = search_step(_t)
find_evidence.extend(_evi)
else:
page = [p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")]
if any("may refer to:" in p for p in page):
_evi = search_step("[" + entity + "]")
find_evidence.extend(_evi)
else:
# page_py = wiki_wiki.page(entity)
#
# if page_py.exists():
# introduction = page_py.summary
# else:
page_text = ""
for p in page:
if len(p.split(" ")) > 2:
page_text += clean_str(p)
if not p.endswith("\n"):
page_text += "\n"
introduction = get_page_obs(page_text)
find_evidence.append([entity, introduction])
return find_evidence
def find_similar_wikipedia(entity, relevant_wikipages):
    # If fewer than five relevant Wikipedia pages were found for the entity, search for similar Wikipedia pages.
ent_ = entity.replace(" ", "+")
search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}&title=Special:Search&profile=advanced&fulltext=1&ns0=1"
response_text = requests.get(search_url).text
soup = BeautifulSoup(response_text, features="html.parser")
result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
if result_divs:
result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
similar_titles = result_titles[:5]
saved_titles = [ent[0] for ent in relevant_wikipages] if relevant_wikipages else relevant_wikipages
for _t in similar_titles:
if _t not in saved_titles and len(relevant_wikipages) < 5:
                _evi = search_entity_wikipedia(_t)
# _evi = search_step(_t)
relevant_wikipages.extend(_evi)
return relevant_wikipages
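# Collect Wikipedia introductions for every named entity in the claim, falling back to similar pages when needed.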
def find_evidence_from_wikipedia(claim):
#
doc = nlp(claim)
#
wikipedia_page = []
for ent in doc.ents:
        relevant_wikipages = search_entity_wikipedia(ent)
if len(relevant_wikipages) < 5:
relevant_wikipages = find_similar_wikipedia(str(ent), relevant_wikipages)
wikipedia_page.extend(relevant_wikipages)
return wikipedia_page
def relevant_wikipedia_API_retriever(claim):
#
doc = nlp(claim)
wiki_intro = []
for ent in doc.ents:
page_py = wiki_wiki.page(ent)
if page_py.exists():
introduction = page_py.summary
else:
introduction = "No documents found."
wiki_intro.append([str(ent), introduction])
return wiki_intro, doc
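# Wikipedia-based evidence pipeline: retrieve candidate pages, select the most relevant sentences,
# and wrap them as Docs objects for downstream prediction.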
def Wikipediaretriever(claim, sources):
#
# 1. extract relevant wikipedia pages from wikipedia dumps
if "Dump" in sources:
        # find_evidence_from_wikipedia_dumps returns (wiki_intro, doc); keep only the page list here.
        wikipedia_page, _ = find_evidence_from_wikipedia_dumps(claim)
else:
wikipedia_page = find_evidence_from_wikipedia(claim)
# wiki_intro, doc = relevant_wikipedia_API_retriever(claim)
# 2. extract relevant sentences from extracted wikipedia pages
sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)
#
results = []
for i, (sent, title) in enumerate(zip(sents, titles)):
metadata = dict()
metadata['name'] = claim
metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
metadata['short_name'] = "Evidence {}".format(i + 1)
metadata['page_number'] = ""
metadata['query'] = sent
metadata['title'] = title
metadata['evidence'] = sent
metadata['answer'] = ""
metadata['page_content'] = "<b>Title</b>: " + str(metadata['title']) + "<br>" + "<b>Evidence</b>: " + metadata['evidence']
page_content = f"""{metadata['page_content']}"""
results.append(Docs(metadata, page_content))
return results
def log_on_azure(file, logs, azure_share_client):
logs = json.dumps(logs)
file_client = azure_share_client.get_file_client(file)
file_client.upload_file(logs)
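# Main Gradio callback: retrieve evidence for the claim, predict verdict and justification,
# render the sources HTML, and optionally log the interaction to Azure.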
def chat(claim, history, sources):
evidence = []
# if 'Google' in sources:
# evidence = GoogleAPIretriever(query)
# if 'WikiPediaDumps' in sources:
# evidence = WikipediaDumpsretriever(query)
if 'WikiPedia' in sources:
evidence = Wikipediaretriever(claim, sources)
answer_set, answer_output = QAprediction(claim, evidence, sources)
docs_html = ""
if len(evidence) > 0:
docs_html = []
for i, evi in enumerate(evidence, 1):
docs_html.append(make_html_source(evi, i))
docs_html = "".join(docs_html)
else:
print("No documents found")
url_of_evidence = ""
output_language = "English"
output_query = claim
history[-1] = (claim, answer_set)
history = [tuple(x) for x in history]
############################################################
evi_list = []
for evi in evidence:
title_str = evi.metadata['title']
evi_str = evi.metadata['evidence']
evi_list.append([title_str, evi_str])
try:
        # Log the answer on the Azure file share.
        # If AZURE_ISSAVE is set to TRUE, save the logs via the Azure share client.
        if os.environ.get("AZURE_ISSAVE", "").lower() == "true":
timestamp = str(datetime.now().timestamp())
# timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
file = timestamp + ".json"
logs = {
"user_id": str(user_id),
"claim": claim,
"sources": sources,
"evidence": evi_list,
"url": url_of_evidence,
"answer": answer_output,
"time": timestamp,
}
log_on_azure(file, logs, azure_share_client)
except Exception as e:
        print(f"Error logging on Azure file share: {e}")
raise gr.Error(
f"AVeriTeC Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
##########
return history, docs_html, output_query, output_language
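# Build and launch the Gradio Blocks interface.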
def main():
init_prompt = """
Hello, I am a fact-checking assistant designed to help you find appropriate evidence to predict the veracity of claims.
What do you want to fact-check?
"""
with gr.Blocks(title="AVeriTeC fact-checker", css="style.css", theme=theme, elem_id="main-component") as demo:
with gr.Tab("AVeriTeC"):
with gr.Row(elem_id="chatbot-row"):
with gr.Column(scale=2):
chatbot = gr.Chatbot(
value=[(None, init_prompt)],
show_copy_button=True, show_label=False, elem_id="chatbot", layout="panel",
avatar_images=(None, "assets/averitec.png")
) # avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"),
with gr.Row(elem_id="input-message"):
textbox = gr.Textbox(placeholder="Ask me what claim do you want to check!", show_label=False,
scale=7, lines=1, interactive=True, elem_id="input-textbox")
# submit = gr.Button("",elem_id = "submit-button",scale = 1,interactive = True,icon = "https://static-00.iconduck.com/assets.00/settings-icon-2048x2046-cw28eevx.png")
with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
with gr.Tabs() as tabs:
with gr.TabItem("Examples", elem_id="tab-examples", id=0):
examples_hidden = gr.Textbox(visible=False)
first_key = list(CLAIMS_Type.keys())[0]
dropdown_samples = gr.Dropdown(CLAIMS_Type.keys(), value=first_key, interactive=True,
show_label=True,
label="Select claim type",
elem_id="dropdown-samples")
samples = []
for i, key in enumerate(CLAIMS_Type.keys()):
examples_visible = True if i == 0 else False
with gr.Row(visible=examples_visible) as group_examples:
examples_questions = gr.Examples(
CLAIMS_Type[key],
[examples_hidden],
examples_per_page=8,
run_on_click=False,
elem_id=f"examples{i}",
api_name=f"examples{i}",
# label = "Click on the example question or enter your own",
# cache_examples=True,
)
samples.append(group_examples)
with gr.Tab("Sources", elem_id="tab-citations", id=1):
sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
docs_textbox = gr.State("")
with gr.Tab("Configuration", elem_id="tab-config", id=2):
gr.Markdown("Reminder: We currently only support fact-checking in English!")
# dropdown_sources = gr.Radio(
# ["AVeriTeC", "WikiPediaDumps", "Google", "WikiPediaAPI"],
# label="Select source",
# value="WikiPediaAPI",
# interactive=True,
# )
dropdown_sources = gr.Radio(
["Google", "WikiPedia"],
label="Select source",
value="WikiPedia",
interactive=True,
)
dropdown_retriever = gr.Dropdown(
["BM25", "BM25s"],
label="Select evidence retriever",
multiselect=False,
value="BM25",
interactive=True,
)
output_query = gr.Textbox(label="Query used for retrieval", show_label=True,
elem_id="reformulated-query", lines=2, interactive=False)
output_language = gr.Textbox(label="Language", show_label=True, elem_id="language", lines=1,
interactive=False)
with gr.Tab("About", elem_classes="max-height other-tabs"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("See more info at [https://fever.ai/task.html](https://fever.ai/task.html)")
def start_chat(query, history):
history = history + [(query, None)]
history = [tuple(x) for x in history]
return (gr.update(interactive=False), gr.update(selected=1), history)
def finish_chat():
return (gr.update(interactive=True, value=""))
(textbox
.submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
.then(chat, [textbox, chatbot, dropdown_sources],
[chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_textbox")
.then(finish_chat, None, [textbox], api_name="finish_chat_textbox")
)
(examples_hidden
.change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False,
api_name="start_chat_examples")
.then(chat, [examples_hidden, chatbot, dropdown_sources],
[chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_examples")
.then(finish_chat, None, [textbox], api_name="finish_chat_examples")
)
def change_sample_questions(key):
index = list(CLAIMS_Type.keys()).index(key)
visible_bools = [False] * len(samples)
visible_bools[index] = True
return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
dropdown_samples.change(change_sample_questions, dropdown_samples, samples)
demo.queue()
demo.launch(share=True)
if __name__ == "__main__":
main()