#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Created by zd302 at 17/07/2024
import torch
import numpy as np
import requests
from rank_bm25 import BM25Okapi
from bs4 import BeautifulSoup
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import pytorch_lightning as pl
from averitec.models.DualEncoderModule import DualEncoderModule
from averitec.models.SequenceClassificationModule import SequenceClassificationModule
from averitec.models.JustificationGenerationModule import JustificationGenerationModule
import wikipediaapi
# Wikipedia API client with a descriptive user agent (required by Wikimedia's policy).
wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC (zd302@cam.ac.uk)', 'en')
import os
import nltk
# Tokenizer models needed by word_tokenize/sent_tokenize below.
nltk.download('punkt')
from nltk import pos_tag, word_tokenize, sent_tokenize
import spacy
# NOTE(review): shelling out to download a spaCy model at import time is a heavy
# side effect — presumably acceptable for this demo app; consider moving to setup.
os.system("python -m spacy download en_core_web_sm")
# Small English pipeline; used only for named-entity recognition on claims.
nlp = spacy.load("en_core_web_sm")
# ---------- Load Veracity and Justification prediction model ----------
# Verdict classes; the list index matches the class id aggregated in
# veracity_prediction (0=Supported, 1=Refuted, 2=NEI, 3=Conflicting).
LABEL = [
    "Supported",
    "Refuted",
    "Not Enough Evidence",
    "Conflicting Evidence/Cherrypicking",
]
# Veracity
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# 4-way BERT sequence classifier restored from a project checkpoint.
veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
                                                                   tokenizer=veracity_tokenizer, model=bert_model).to(device)
# Justification
# BART-large generator restored from the best justification checkpoint.
justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
# ---------------------------------------------------------------------------
# ----------------------------------------------------------------------------
class Docs:
    """Minimal document container mimicking LangChain-style retriever output.

    Attributes:
        metadata: dict of evidence fields (url, title, evidence, ...).
        page_content: rendered text content of the document.
    """

    def __init__(self, metadata=None, page_content=""):
        # Bug fix: the original used `metadata=dict()` as the default, which is
        # evaluated once and SHARED by every instance built without an explicit
        # metadata argument (mutable-default-argument pitfall). Use None and
        # create a fresh dict per instance instead.
        self.metadata = {} if metadata is None else metadata
        self.page_content = page_content
# ------------------------------ Googleretriever -----------------------------
def Googleretriever():
    """Placeholder for a Google-search-based retriever; not implemented.

    Returns:
        int: always 0 (stub return value).
    """
    return 0
# ------------------------------ Googleretriever -----------------------------
# ------------------------------ Wikipediaretriever --------------------------
def search_entity_wikipeida(entity):
    """Look up *entity* on Wikipedia.

    Returns:
        list: [[title, summary]] when a page with that exact title exists,
        otherwise an empty list.
    """
    evidence = []
    page = wiki_wiki.page(entity)
    if page.exists():
        evidence.append([str(entity), page.summary])
    return evidence
def clean_str(p):
    """Undo one layer of unicode-escape mojibake in scraped search-result text."""
    # Literal "\uXXXX" sequences in the scraped text are first interpreted as
    # escapes, then the resulting code points are treated as raw UTF-8 bytes
    # (via latin-1) and decoded back into real unicode characters.
    unescaped = p.encode().decode("unicode-escape")
    return unescaped.encode("latin1").decode("utf-8")
def find_similar_wikipedia(entity, relevant_wikipages):
    """Top up *relevant_wikipages* with pages similar to *entity*.

    Scrapes the Wikipedia full-text search results page for *entity* and adds
    evidence for the first search-result titles not already collected, until
    five pages are gathered. Returns the (possibly extended) list.
    """
    query = entity.replace(" ", "+")
    search_url = f"https://en.wikipedia.org/w/index.php?search={query}&title=Special:Search&profile=advanced&fulltext=1&ns0=1"
    html = requests.get(search_url).text
    soup = BeautifulSoup(html, features="html.parser")
    headings = soup.find_all("div", {"class": "mw-search-result-heading"})
    if not headings:
        return relevant_wikipages
    candidate_titles = [clean_str(h.get_text().strip()) for h in headings][:5]
    # Titles already collected; computed once up front (not refreshed as we add).
    seen = [page[0] for page in relevant_wikipages] if relevant_wikipages else relevant_wikipages
    for title in candidate_titles:
        if title not in seen and len(relevant_wikipages) < 5:
            relevant_wikipages.extend(search_entity_wikipeida(title))
    return relevant_wikipages
def find_evidence_from_wikipedia(claim):
    """Collect Wikipedia [title, summary] evidence for the entities in *claim*.

    Runs spaCy NER over the claim; for each entity, tries a direct page lookup
    and, when fewer than five pages are found, falls back to full-text search
    for similar pages.

    Returns:
        list: [title, summary] pairs, concatenated over all entities.
    """
    doc = nlp(claim)
    wikipedia_page = []
    for ent in doc.ents:
        # Bug fix: pass the entity TEXT, not the spaCy Span object —
        # wikipediaapi expects a page-title string, and the fallback call
        # below already converts with str(ent).
        relevant_wikipages = search_entity_wikipeida(str(ent))
        if len(relevant_wikipages) < 5:
            relevant_wikipages = find_similar_wikipedia(str(ent), relevant_wikipages)
        wikipedia_page.extend(relevant_wikipages)
    return wikipedia_page
def bm25_retriever(query, corpus, topk=3):
    """Rank *corpus* (a list of token lists) against *query* using BM25.

    Args:
        query: raw query string (tokenized internally with word_tokenize).
        corpus: pre-tokenized documents.
        topk: number of documents to return.

    Returns:
        (indices, scores): indices of the top-`topk` documents in descending
        BM25-score order, and their corresponding scores.
    """
    ranker = BM25Okapi(corpus)
    doc_scores = ranker.get_scores(word_tokenize(query))
    best = np.argsort(doc_scores)[::-1][:topk]
    best_scores = [doc_scores[i] for i in best]
    return best, best_scores
def relevant_sentence_retrieval(query, wiki_intro, k):
    """Select the k sentences from the Wikipedia summaries most relevant to *query*.

    Args:
        query: the claim text.
        wiki_intro: iterable of (title, summary) pairs.
        k: number of sentences to return.

    Returns:
        (sentences, titles): the top-k sentences (BM25 order) and the page
        title each sentence came from.
    """
    # Flatten every summary into sentences, keeping a parallel record of the
    # raw sentence and its source page title for each tokenized entry.
    tokenized_corpus, flat_sentences, source_titles = [], [], []
    for title, intro in wiki_intro:
        for sentence in sent_tokenize(intro):
            tokenized_corpus.append(word_tokenize(sentence))
            flat_sentences.append(sentence)
            source_titles.append(title)
    # ----- BM25 ranking over the flattened sentence corpus
    top_idx, _scores = bm25_retriever(query, tokenized_corpus, topk=k)
    chosen_sents = [flat_sentences[i] for i in top_idx]
    chosen_titles = [source_titles[i] for i in top_idx]
    return chosen_sents, chosen_titles
# ------------------------------ Wikipediaretriever -----------------------------
def Wikipediaretriever(claim):
    """Retrieve the Wikipedia evidence sentences most relevant to *claim*.

    Returns:
        list[Docs]: one Docs per selected sentence, with retrieval metadata.
    """
    # 1. Find candidate Wikipedia pages for the entities mentioned in the claim.
    wikipedia_page = find_evidence_from_wikipedia(claim)
    # 2. Keep the 3 sentences most relevant to the claim (BM25).
    sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)

    results = []
    for i, (sent, title) in enumerate(zip(sents, titles)):
        url = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata = dict()
        metadata['name'] = claim
        metadata['url'] = url
        # Bug fix: the original used "_".join(title), which joins the
        # *characters* of the title string; reuse the word-joined URL.
        metadata['cached_source_url'] = url
        metadata['short_name'] = "Evidence {}".format(i + 1)
        metadata['page_number'] = ""
        metadata['query'] = sent
        metadata['title'] = title
        metadata['evidence'] = sent
        metadata['answer'] = ""
        # Bug fix: the original string literal was split across two physical
        # lines (a SyntaxError); join title and evidence with an explicit "\n".
        metadata['page_content'] = "Title: " + str(metadata['title']) + "\n" + "Evidence: " + metadata['evidence']
        results.append(Docs(metadata, metadata['page_content']))
    return results
# ------------------------------ Veracity Prediction ------------------------------
class SequenceClassificationDataLoader(pl.LightningDataModule):
    """Data module exposing the tokenization helpers used by the veracity model.

    Only `tokenize_strings` and `quadruple_to_string` are used at inference
    time; `data_file`/`add_extra_nee` are retained for interface compatibility.
    """

    def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
        super().__init__()
        self.tokenizer = tokenizer
        self.data_file = data_file
        self.batch_size = batch_size
        self.add_extra_nee = add_extra_nee

    def tokenize_strings(
        self,
        source_sentences,
        max_length=400,
        pad_to_max_length=False,
        return_tensors="pt",
    ):
        """Tokenize a batch of strings.

        Returns:
            (input_ids, attention_masks) tensors from the underlying tokenizer.
        """
        padding_strategy = "max_length" if pad_to_max_length else "longest"
        encoded = self.tokenizer(
            source_sentences,
            max_length=max_length,
            padding=padding_strategy,
            truncation=True,
            return_tensors=return_tensors,
        )
        return encoded["input_ids"], encoded["attention_mask"]

    def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
        """Flatten a (claim, question, answer, explanation) quadruple into the
        "[CLAIM] ... [QUESTION] ..." format the classifier was trained on."""
        if bool_explanation is not None and len(bool_explanation) > 0:
            suffix = ", because " + bool_explanation.lower().strip()
        else:
            suffix = ""
        body = " ".join(
            ["[CLAIM]", claim.strip(), "[QUESTION]", question.strip(), answer.strip()]
        )
        return body + suffix
def veracity_prediction(claim, evidence):
    """Predict a verdict label for *claim* from retrieved *evidence* Docs.

    Each evidence item is flattened to a "[CLAIM] ... [QUESTION] ..." string,
    classified by the BERT veracity model, and the per-evidence class ids are
    aggregated into one of the four LABEL strings.

    Returns:
        str: one of the LABEL entries.
    """
    loader = SequenceClassificationDataLoader(
        tokenizer=veracity_tokenizer,
        data_file="this_is_discontinued",
        batch_size=32,
        add_extra_nee=False,
    )

    evidence_strings = [
        loader.quadruple_to_string(claim, evi.metadata["query"], evi.metadata["answer"], "")
        for evi in evidence
    ]
    # No evidence found (e.g. search returned zero pages): default to NEI.
    if not evidence_strings:
        return "Not Enough Evidence"

    input_ids, attention_mask = loader.tokenize_strings(evidence_strings)
    logits = veracity_model(input_ids.to(device), attention_mask=attention_mask.to(device)).logits
    per_evidence_preds = torch.argmax(logits, axis=1)

    predicted_ids = {int(v) for v in per_evidence_preds}
    has_true = 0 in predicted_ids
    has_false = 1 in predicted_ids
    # TODO another hack -- train/test label sets differ, so classes 2 and 3
    # are both collapsed into "unanswerable" here.
    has_unanswerable = bool(predicted_ids & {2, 3})

    if has_unanswerable:
        verdict_idx = 2
    elif has_true and not has_false:
        verdict_idx = 0
    elif has_false and not has_true:
        verdict_idx = 1
    else:
        verdict_idx = 3
    return LABEL[verdict_idx]
# ------------------------------ Justification Generation ------------------------------
def extract_claim_str(claim, evidence, verdict_label):
    """Serialize claim + QA evidence + verdict into the BART justification input.

    Produces "[CLAIM] <claim> [EVIDENCE] <q>? <a>. ... [VERDICT] <label>".
    Evidence items whose query is empty are skipped; questions are forced to
    end with "?" and non-empty answers with ".".
    """
    pieces = ["[CLAIM] " + claim + " [EVIDENCE] "]
    for item in evidence:
        question = item.metadata['query'].strip()
        if not question:
            continue
        if not question.endswith("?"):
            question += "?"
        pieces.append(question)
        answer = item.metadata['answer']
        if answer:
            if not answer.endswith("."):
                answer += "."
            pieces.append(" " + answer.strip())
        pieces.append(" ")
    pieces.append(" [VERDICT] " + verdict_label)
    return "".join(pieces)
def justification_generation(claim, evidence, verdict_label):
    """Generate a natural-language justification for the predicted verdict.

    Builds the "[CLAIM] ... [EVIDENCE] ... [VERDICT] ..." input string and
    decodes a justification with the fine-tuned BART model.

    Returns:
        str: the stripped generated justification text.
    """
    claim_str = extract_claim_str(claim, evidence, verdict_label)
    # Bug fix: str.strip() returns a new string; the original called it as a
    # bare statement (`claim_str.strip()`) and discarded the result.
    claim_str = claim_str.strip()
    pred_justification = justification_model.generate(claim_str, device=device)
    return pred_justification.strip()