#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Created by zd302 at 17/07/2024
"""AVeriTeC pipeline: Wikipedia evidence retrieval, veracity prediction and
justification generation.

Importing this module has side effects: it prepares the NLTK / spaCy
resources used for retrieval and loads the BERT veracity model and the BART
justification model from ``averitec/pretrained_models/``.
"""

import os
import subprocess
import sys

import numpy as np
import requests
import torch
import nltk
import spacy
import wikipediaapi
import pytorch_lightning as pl
from bs4 import BeautifulSoup
from nltk import pos_tag, word_tokenize, sent_tokenize
from rank_bm25 import BM25Okapi
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
from transformers import RobertaTokenizer, RobertaForSequenceClassification

from averitec.models.DualEncoderModule import DualEncoderModule
from averitec.models.SequenceClassificationModule import SequenceClassificationModule
from averitec.models.JustificationGenerationModule import JustificationGenerationModule

# Wikipedia API client (English); the first argument is the user agent.
wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC (zd302@cam.ac.uk)', 'en')

# Tokenizer data used by sent_tokenize / word_tokenize below.
nltk.download('punkt')

# spaCy pipeline used for named-entity extraction from claims. Download the
# model only when it is missing, and install it into the running interpreter
# (sys.executable) rather than whichever "python" happens to be first on PATH.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=False)
    nlp = spacy.load("en_core_web_sm")

# ---------- Load Veracity and Justification prediction model ----------
# LABEL[i] is the verdict name for aggregated class id i.
LABEL = [
    "Supported",
    "Refuted",
    "Not Enough Evidence",
    "Conflicting Evidence/Cherrypicking",
]

# Veracity
device = "cuda:0" if torch.cuda.is_available() else "cpu"
veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=4, problem_type="single_label_classification"
)
# Inference only: make eval mode (no dropout) explicit.
veracity_model = SequenceClassificationModule.load_from_checkpoint(
    "averitec/pretrained_models/bert_veracity.ckpt", tokenizer=veracity_tokenizer, model=bert_model
).to(device).eval()

# Justification
justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
justification_model = JustificationGenerationModule.load_from_checkpoint(
    best_checkpoint, tokenizer=justification_tokenizer, model=bart_model
).to(device).eval()
# ---------------------------------------------------------------------------


# ----------------------------------------------------------------------------
class Docs:
    """Minimal document container: a ``metadata`` dict plus ``page_content`` text."""

    def __init__(self, metadata=None, page_content=""):
        # Fresh dict per instance: a ``dict()`` default would be created once
        # and shared by every Docs constructed without metadata.
        self.metadata = {} if metadata is None else metadata
        self.page_content = page_content


# ------------------------------ Googleretriever -----------------------------
def Googleretriever():
    """Placeholder for a Google-search evidence retriever (not implemented).

    Always returns ``0``.
    """
    return 0
# ------------------------------ Googleretriever -----------------------------


# ------------------------------ Wikipediaretriever --------------------------
def search_entity_wikipeida(entity):
    """Look up the Wikipedia page titled ``str(entity)``.

    Returns ``[[str(entity), summary]]`` if the page exists, otherwise ``[]``.
    """
    find_evidence = []
    # wikipediaapi expects a plain string title; callers may pass spaCy spans.
    page_py = wiki_wiki.page(str(entity))
    if page_py.exists():
        introduction = page_py.summary
        find_evidence.append([str(entity), introduction])
    return find_evidence


def clean_str(p):
    """Interpret backslash escape sequences in *p*.

    Non-ASCII characters are round-tripped through their UTF-8 bytes so they
    come back unchanged.
    """
    return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")


def find_similar_wikipedia(entity, relevant_wikipages):
    """Top up *relevant_wikipages* with pages from a Wikipedia full-text search.

    Searches Wikipedia for *entity*, then, for each of the first five result
    titles not already present, appends ``[title, summary]`` entries until the
    list holds five pages. *relevant_wikipages* is extended in place and also
    returned. If the search request fails, the list is returned unchanged.
    """
    # If fewer than 5 relevant Wikipedia pages were found for the entity, find similar pages.
    try:
        # params= URL-encodes the entity, so characters such as '&', '#' or
        # '?' in claim text cannot corrupt the query string.
        response = requests.get(
            "https://en.wikipedia.org/w/index.php",
            params={
                "search": entity,
                "title": "Special:Search",
                "profile": "advanced",
                "fulltext": "1",
                "ns0": "1",
            },
            timeout=10,
        )
    except requests.RequestException:
        # Best effort: a failed search should not abort evidence retrieval.
        return relevant_wikipages

    soup = BeautifulSoup(response.text, features="html.parser")
    result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
    if result_divs:
        result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
        similar_titles = result_titles[:5]
        saved_titles = {ent[0] for ent in relevant_wikipages}
        for _t in similar_titles:
            if _t not in saved_titles and len(relevant_wikipages) < 5:
                relevant_wikipages.extend(search_entity_wikipeida(_t))
                saved_titles.add(_t)
    return relevant_wikipages


def find_evidence_from_wikipedia(claim):
    """Collect ``[title, summary]`` Wikipedia pages for the named entities in *claim*.

    For each entity spaCy finds, the exact-title page is fetched and, when
    fewer than five pages were found, topped up from Wikipedia search results.
    """
    doc = nlp(claim)
    wikipedia_page = []
    for ent in doc.ents:
        entity = str(ent)
        relevant_wikipages = search_entity_wikipeida(entity)
        if len(relevant_wikipages) < 5:
            relevant_wikipages = find_similar_wikipedia(entity, relevant_wikipages)
        wikipedia_page.extend(relevant_wikipages)
    return wikipedia_page


def bm25_retriever(query, corpus, topk=3):
    """Rank *corpus* (a list of token lists) against *query* with BM25.

    Returns ``(indices, scores)`` for the *topk* highest-scoring documents,
    best first. An empty corpus yields ``([], [])``.
    """
    if not corpus:
        # BM25Okapi divides by the corpus size when built, so it cannot be
        # constructed from an empty corpus (e.g. a claim with no entities).
        return [], []
    bm25 = BM25Okapi(corpus)
    query_tokens = word_tokenize(query)
    scores = bm25.get_scores(query_tokens)
    top_n = np.argsort(scores)[::-1][:topk]
    top_n_scores = [scores[i] for i in top_n]
    return top_n, top_n_scores


def relevant_sentence_retrieval(query, wiki_intro, k):
    """Return the *k* sentences from *wiki_intro* that best match *query*.

    *wiki_intro* is an iterable of ``(title, intro_text)`` pairs. Each intro is
    split into sentences, and the sentences are ranked with BM25. The result is
    ``(sentences, titles)``, where ``titles[i]`` is the page ``sentences[i]``
    came from. Both lists are empty when there is nothing to rank.
    """
    # 1. Create corpus here
    corpus, sentences, titles = [], [], []
    for title, intro in wiki_intro:
        for sent in sent_tokenize(intro):
            corpus.append(word_tokenize(sent))
            sentences.append(sent)
            titles.append(title)

    # ----- BM25
    bm25_top_n, _ = bm25_retriever(query, corpus, topk=k)
    bm25_top_n_sents = [sentences[i] for i in bm25_top_n]
    bm25_top_n_titles = [titles[i] for i in bm25_top_n]
    return bm25_top_n_sents, bm25_top_n_titles
# ------------------------------ Wikipediaretriever -----------------------------


def Wikipediaretriever(claim):
    """Retrieve up to three Wikipedia evidence sentences for *claim*.

    Returns a list of :class:`Docs`, one per sentence. Each has metadata naming
    the source page (title and URL) and uses the sentence as both ``query`` and
    ``evidence``; ``answer`` is left empty. Returns ``[]`` when no evidence is
    found.
    """
    # 1. find relevant Wikipedia pages (via the Wikipedia API and search)
    wikipedia_page = find_evidence_from_wikipedia(claim)
    # 2. extract relevant sentences from the retrieved pages
    sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)

    results = []
    for i, (sent, title) in enumerate(zip(sents, titles)):
        # Page URL: title words joined by "_" (both url fields point here).
        page_url = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata = dict()
        metadata['name'] = claim
        metadata['url'] = page_url
        metadata['cached_source_url'] = page_url
        metadata['short_name'] = "Evidence {}".format(i + 1)
        metadata['page_number'] = ""
        metadata['query'] = sent
        metadata['title'] = title
        metadata['evidence'] = sent
        metadata['answer'] = ""
        metadata['page_content'] = "Title: " + str(metadata['title']) + "\n" + "Evidence: " + metadata['evidence']
        page_content = metadata['page_content']
        results.append(Docs(metadata, page_content))
    return results


# ------------------------------ Veracity Prediction ------------------------------
class SequenceClassificationDataLoader(pl.LightningDataModule):
    """Tokenisation and formatting helpers for the veracity classifier.

    Only ``tokenize_strings`` and ``quadruple_to_string`` are used by the code
    here; ``data_file``, ``batch_size`` and ``add_extra_nee`` are stored but
    not read.
    """

    def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
        super().__init__()
        self.tokenizer = tokenizer
        self.data_file = data_file
        self.batch_size = batch_size
        self.add_extra_nee = add_extra_nee

    def tokenize_strings(
        self,
        source_sentences,
        max_length=400,
        pad_to_max_length=False,
        return_tensors="pt",
    ):
        """Tokenize *source_sentences*, returning ``(input_ids, attention_mask)``.

        Sequences are truncated to *max_length*. They are padded to the longest
        sequence in the batch, or to *max_length* if *pad_to_max_length* is set.
        """
        encoded_dict = self.tokenizer(
            source_sentences,
            max_length=max_length,
            padding="max_length" if pad_to_max_length else "longest",
            truncation=True,
            return_tensors=return_tensors,
        )

        input_ids = encoded_dict["input_ids"]
        attention_masks = encoded_dict["attention_mask"]

        return input_ids, attention_masks

    def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
        """Format a claim/question/answer triple as classifier input text.

        The result looks like ``"[CLAIM] c [QUESTION] q a"``. A non-empty
        *bool_explanation* is appended, lower-cased, as ``", because ..."``.
        """
        if bool_explanation is not None and len(bool_explanation) > 0:
            bool_explanation = ", because " + bool_explanation.lower().strip()
        else:
            bool_explanation = ""
        return (
            "[CLAIM] "
            + claim.strip()
            + " [QUESTION] "
            + question.strip()
            + " "
            + answer.strip()
            + bool_explanation
        )


def veracity_prediction(claim, evidence):
    """Predict a verdict label for *claim* from a list of evidence :class:`Docs`.

    Each evidence item is classified independently against the claim, then
    the per-item predictions are aggregated:

    * any item predicted as class 2 or 3 -> "Not Enough Evidence";
    * otherwise only class-0 items -> "Supported";
    * otherwise only class-1 items -> "Refuted";
    * otherwise (both 0 and 1) -> "Conflicting Evidence/Cherrypicking".

    With no evidence at all, returns "Not Enough Evidence".
    """
    dataLoader = SequenceClassificationDataLoader(
        tokenizer=veracity_tokenizer,
        data_file="this_is_discontinued",
        batch_size=32,
        add_extra_nee=False,
    )

    evidence_strings = [
        dataLoader.quadruple_to_string(claim, evi.metadata["query"], evi.metadata["answer"], "")
        for evi in evidence
    ]

    if not evidence_strings:
        # If we found no evidence (e.g. the retriever returned 0 pages), just output NEI.
        return "Not Enough Evidence"

    tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        logits = veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits
    example_support = torch.argmax(logits, dim=1).tolist()

    has_true = 0 in example_support
    has_false = 1 in example_support
    # TODO another hack -- we cant have different labels for train and test so we do this
    has_unanswerable = any(v in (2, 3) for v in example_support)

    if has_unanswerable:
        answer = 2
    elif has_true and not has_false:
        answer = 0
    elif not has_true and has_false:
        answer = 1
    else:
        answer = 3

    return LABEL[answer]


# ------------------------------ Justification Generation ------------------------------
def extract_claim_str(claim, evidence, verdict_label):
    """Build the justification-model input ``"[CLAIM] ... [EVIDENCE] ... [VERDICT] ..."``.

    Each evidence item contributes its question (``metadata['query']``, with
    a trailing "?" added if missing) followed by its answer
    (``metadata['answer']``, with a trailing "." added if missing; omitted
    when empty). Items with an empty question are skipped.
    """
    claim_str = "[CLAIM] " + claim + " [EVIDENCE] "

    for evi in evidence:
        q_text = evi.metadata['query'].strip()
        if not q_text:
            continue
        if not q_text.endswith("?"):
            q_text += "?"
        claim_str += q_text

        a_text = evi.metadata['answer']
        if a_text:
            if not a_text.endswith("."):
                a_text += "."
            claim_str += " " + a_text.strip()

        claim_str += " "

    claim_str += " [VERDICT] " + verdict_label

    return claim_str


def justification_generation(claim, evidence, verdict_label):
    """Generate a textual justification of *verdict_label* for *claim* from *evidence*."""
    claim_str = extract_claim_str(claim, evidence, verdict_label).strip()
    with torch.no_grad():
        pred_justification = justification_model.generate(claim_str, device=device)
    return pred_justification.strip()