#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Created by zd302 at 17/07/2024
from fastapi import FastAPI
from pydantic import BaseModel
# from averitec.models.AveritecModule import Wikipediaretriever, Googleretriever, veracity_prediction, justification_generation
import uvicorn
import spaces
# FastAPI application; the routes defined below (GET "/" and POST "/predict/") are registered on it.
app = FastAPI()
# ---------------------------------------------------------------------------------------------------------------------
import gradio as gr
import os
import torch
import json
import numpy as np
import requests
from rank_bm25 import BM25Okapi
from bs4 import BeautifulSoup
from datetime import datetime
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import pytorch_lightning as pl
from averitec.models.DualEncoderModule import DualEncoderModule
from averitec.models.SequenceClassificationModule import SequenceClassificationModule
from averitec.models.JustificationGenerationModule import JustificationGenerationModule
# ---------------------------------------------------------------------------------------------------------------------
import wikipediaapi

# Wikipedia API client; the first argument is the User-Agent identifying this app.
wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC (zd302@cam.ac.uk)', 'en')

import nltk

# Tokenizer models for sent_tokenize / word_tokenize. Recent NLTK releases load
# 'punkt_tab' instead of 'punkt'; older releases just report the unknown package
# and continue, so requesting both works across versions.
nltk.download('punkt')
nltk.download('punkt_tab')
from nltk import pos_tag, word_tokenize, sent_tokenize

import spacy
import subprocess
import sys

# Small English spaCy pipeline, used to extract named entities from claims.
# Download it only when it is not installed yet, using the interpreter that is
# actually running this file (a bare "python" on PATH may be a different one).
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
    nlp = spacy.load("en_core_web_sm")
# ---------------------------------------------------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# load .env
from utils import create_user_id
user_id = create_user_id()
from azure.storage.fileshare import ShareServiceClient

# python-dotenv is optional: when it is not installed, the variables below must
# already be set in the process environment. Only the missing-package case is
# ignored; any other error surfaces.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass

# Azure file-share client used for request logging. These variables are
# required: a missing one raises KeyError at import time.
account_url = os.environ["AZURE_ACCOUNT_URL"]
credential = {
    "account_key": os.environ['AZURE_ACCOUNT_KEY'],
    "account_name": os.environ['AZURE_ACCOUNT_NAME']
}
file_share_name = "averitec"
azure_service = ShareServiceClient(account_url=account_url, credential=credential)
azure_share_client = azure_service.get_share_client(file_share_name)
# ---------- Setting ----------
# ---------- Load Veracity and Justification prediction model ----------
# Verdict labels; veracity_prediction returns one of these strings.
LABEL = [
    "Supported",
    "Refuted",
    "Not Enough Evidence",
    "Conflicting Evidence/Cherrypicking",
]

# Veracity
device = "cuda:0" if torch.cuda.is_available() else "cpu"
veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
# Checkpoints are resolved relative to the current working directory.
veracity_checkpoint_path = os.path.join(os.getcwd(), "averitec", "pretrained_models", "bert_veracity.ckpt")
# Both models are used for inference only. Hugging Face from_pretrained already
# returns eval-mode models, but calling .eval() on the loaded wrappers
# guarantees dropout stays off whatever mode the checkpoint loader leaves them in.
veracity_model = SequenceClassificationModule.load_from_checkpoint(
    veracity_checkpoint_path, tokenizer=veracity_tokenizer, model=bert_model
).to(device).eval()

# Justification
justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
best_checkpoint = os.path.join(
    os.getcwd(), "averitec", "pretrained_models",
    "bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt",
)
justification_model = JustificationGenerationModule.load_from_checkpoint(
    best_checkpoint, tokenizer=justification_tokenizer, model=bart_model
).to(device).eval()
# ---------------------------------------------------------------------------
# ----------------------------------------------------------------------------
class Docs:
    """Lightweight evidence document: a metadata dict plus its text content.

    Args:
        metadata: Per-document fields (title, url, evidence, ...). When omitted,
            a fresh empty dict is created so instances never share metadata.
        page_content: The document text.
    """

    def __init__(self, metadata=None, page_content=""):
        # A ``dict()`` default would be evaluated once and shared by every
        # instance created without metadata.
        self.metadata = {} if metadata is None else metadata
        self.page_content = page_content
# ------------------------------ Googleretriever -----------------------------
def Googleretriever(claim=None):
    """Placeholder Google evidence retriever (search is not implemented yet).

    Accepts the claim, because ``fact_checking`` calls ``Googleretriever(claim)``,
    and returns no evidence. Callers iterate the result, so it must be a list;
    with no evidence, ``veracity_prediction`` answers "Not Enough Evidence".

    Args:
        claim: The claim text (currently unused).

    Returns:
        An empty list of evidence documents.
    """
    return []
# ------------------------------ Googleretriever -----------------------------
# ------------------------------ Wikipediaretriever --------------------------
def search_entity_wikipeida(entity):
    """Look up the Wikipedia page titled *entity*.

    Returns:
        ``[[str(entity), summary]]`` when the page exists, otherwise ``[]``.
    """
    page = wiki_wiki.page(entity)
    if not page.exists():
        return []
    summary = page.summary
    return [[str(entity), summary]]
def clean_str(p):
    """Interpret backslash escape sequences in *p* while keeping non-ASCII text.

    The text is encoded to UTF-8 and decoded with the ``unicode-escape`` codec
    (which reads plain bytes as Latin-1). Then it is re-encoded as Latin-1 and
    decoded as UTF-8, so multi-byte characters come back unchanged while
    sequences such as a backslash followed by ``n`` become real characters.
    Escapes producing characters outside Latin-1 raise ``UnicodeEncodeError``.
    """
    utf8_bytes = p.encode()
    unescaped = utf8_bytes.decode("unicode-escape")
    latin1_bytes = unescaped.encode("latin1")
    return latin1_bytes.decode("utf-8")
def find_similar_wikipedia(entity, relevant_wikipages):
    """Top up *relevant_wikipages* with pages from Wikipedia full-text search.

    Used when fewer than 5 pages were found for an entity. Runs a Wikipedia
    search for *entity*. For each of the first 5 result titles not already
    collected, it appends that page's ``[title, summary]`` entry, until the list
    holds 5 entries.

    Args:
        entity: Search string (an entity mention from the claim).
        relevant_wikipages: ``[title, summary]`` pairs collected so far;
            extended in place.

    Returns:
        The same list object, possibly extended. If the search request fails,
        the list is returned unchanged.
    """
    search_url = "https://en.wikipedia.org/w/index.php"
    # Let requests URL-encode the query: characters such as "&" or "#" in the
    # entity would otherwise corrupt a hand-built query string.
    params = {
        "search": entity,
        "title": "Special:Search",
        "profile": "advanced",
        "fulltext": 1,
        "ns0": 1,
    }
    try:
        # A timeout keeps one slow search from blocking the whole request.
        response = requests.get(search_url, params=params, timeout=10)
    except requests.RequestException as err:
        print(f"Wikipedia search failed for {entity!r}: {err}")
        return relevant_wikipages

    soup = BeautifulSoup(response.text, features="html.parser")
    result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
    similar_titles = [clean_str(div.get_text().strip()) for div in result_divs[:5]]

    saved_titles = {page[0] for page in relevant_wikipages}
    for title in similar_titles:
        if len(relevant_wikipages) >= 5:
            break
        if title in saved_titles:
            continue
        new_pages = search_entity_wikipeida(title)
        relevant_wikipages.extend(new_pages)
        # Track what was just added so later results cannot duplicate it.
        saved_titles.update(page[0] for page in new_pages)
    return relevant_wikipages
def find_evidence_from_wikipedia(claim):
    """Collect ``[title, summary]`` Wikipedia entries for the named entities in *claim*.

    Each entity found by the spaCy pipeline ``nlp`` is looked up as a page
    title. ``search_entity_wikipeida`` yields at most one entry, so Wikipedia
    full-text search is then used to add similar pages (up to 5 per entity).

    Returns:
        A flat list of ``[title, summary]`` pairs across all entities. Entries
        may repeat when different entities lead to the same page.
    """
    doc = nlp(claim)
    wikipedia_page = []
    for ent in doc.ents:
        # Use the entity's plain text for both lookups rather than the spaCy Span object.
        entity_text = ent.text
        relevant_wikipages = search_entity_wikipeida(entity_text)
        if len(relevant_wikipages) < 5:
            relevant_wikipages = find_similar_wikipedia(entity_text, relevant_wikipages)
        wikipedia_page.extend(relevant_wikipages)
    return wikipedia_page
def bm25_retriever(query, corpus, topk=3):
    """Rank *corpus* documents against *query* with BM25 and keep the best *topk*.

    Args:
        query: Query string; tokenized with NLTK ``word_tokenize``.
        corpus: Pre-tokenized documents (each a list of tokens).
        topk: Maximum number of results.

    Returns:
        ``(indices, scores)``: indices into *corpus* in descending score order,
        and the matching scores. Both are empty when *corpus* is empty.
    """
    # BM25Okapi divides by the number of documents when computing the average
    # document length, so an empty corpus (e.g. no Wikipedia pages found)
    # would raise ZeroDivisionError. There is simply nothing to rank.
    if not corpus:
        return [], []
    bm25 = BM25Okapi(corpus)
    query_tokens = word_tokenize(query)
    scores = bm25.get_scores(query_tokens)
    top_n = np.argsort(scores)[::-1][:topk]
    top_n_scores = [scores[i] for i in top_n]
    return top_n, top_n_scores
def relevant_sentence_retrieval(query, wiki_intro, k):
    """Pick the *k* summary sentences that BM25 ranks highest for *query*.

    Args:
        query: The claim text.
        wiki_intro: ``[title, summary]`` pairs; every summary is split into sentences.
        k: Number of sentences to return.

    Returns:
        ``(sentences, titles)``: the selected sentences, best first, and the
        page title each one came from.
    """
    sentences, titles = [], []
    for title, intro in wiki_intro:
        for sentence in sent_tokenize(intro):
            sentences.append(sentence)
            titles.append(title)

    corpus = [word_tokenize(sentence) for sentence in sentences]
    ranked, _ranked_scores = bm25_retriever(query, corpus, topk=k)
    return [sentences[idx] for idx in ranked], [titles[idx] for idx in ranked]
# ------------------------------ Wikipediaretriever -----------------------------
def Wikipediaretriever(claim):
    """Retrieve the Wikipedia sentences most relevant to *claim* (at most 3).

    Steps:
        1. Collect ``[title, summary]`` entries for entities in the claim.
        2. Rank the summary sentences against the claim with BM25; keep the top 3.

    Returns:
        A list of ``Docs``, one per retrieved sentence. Each ``metadata`` dict
        has the keys ``name``, ``url``, ``cached_source_url``, ``short_name``,
        ``page_number``, ``query``, ``title``, ``evidence``, ``answer`` and
        ``page_content``. ``query`` and ``evidence`` both hold the sentence, and
        ``answer`` is empty.
    """
    # 1. extract relevant wikipedia pages
    wikipedia_page = find_evidence_from_wikipedia(claim)
    # 2. extract relevant sentences from the extracted wikipedia pages
    sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)

    results = []
    for i, (sent, title) in enumerate(zip(sents, titles)):
        # Wikipedia article URLs use underscores in place of spaces.
        page_url = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata = dict()
        metadata['name'] = claim
        metadata['url'] = page_url
        # Same article URL; joining the raw title string would put "_" between every character.
        metadata['cached_source_url'] = page_url
        metadata['short_name'] = "Evidence {}".format(i + 1)
        metadata['page_number'] = ""
        metadata['query'] = sent
        metadata['title'] = title
        metadata['evidence'] = sent
        metadata['answer'] = ""
        # Title and evidence on separate lines.
        metadata['page_content'] = "Title: " + str(metadata['title']) + "\n" + "Evidence: " + metadata['evidence']
        results.append(Docs(metadata, metadata['page_content']))
    return results
# ------------------------------ Veracity Prediction ------------------------------
class SequenceClassificationDataLoader(pl.LightningDataModule):
    """Lightning data module; ``veracity_prediction`` uses its formatting helpers.

    Args:
        tokenizer: Hugging Face tokenizer used by ``tokenize_strings``.
        data_file: Stored on the instance; not read by the methods shown here.
        batch_size: Stored on the instance; not read by the methods shown here.
        add_extra_nee: Stored on the instance; not read by the methods shown here.
    """

    def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
        super().__init__()
        self.tokenizer, self.data_file = tokenizer, data_file
        self.batch_size, self.add_extra_nee = batch_size, add_extra_nee

    def tokenize_strings(self, source_sentences, max_length=400, pad_to_max_length=False, return_tensors="pt"):
        """Tokenize *source_sentences*, truncating each to *max_length* tokens.

        Pads to *max_length* when *pad_to_max_length* is true, otherwise to the
        longest sequence in the batch.

        Returns:
            ``(input_ids, attention_mask)`` as produced by the tokenizer.
        """
        padding_mode = "max_length" if pad_to_max_length else "longest"
        encoded = self.tokenizer(
            source_sentences,
            max_length=max_length,
            padding=padding_mode,
            truncation=True,
            return_tensors=return_tensors,
        )
        return encoded["input_ids"], encoded["attention_mask"]

    def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
        """Format one evidence item as ``[CLAIM] c [QUESTION] q a`` plus an optional reason.

        A non-empty *bool_explanation* is lower-cased, stripped and appended as
        ``", because <explanation>"``. ``None`` or ``""`` appends nothing.
        """
        if bool_explanation is None or len(bool_explanation) == 0:
            suffix = ""
        else:
            suffix = ", because " + bool_explanation.lower().strip()
        return f"[CLAIM] {claim.strip()} [QUESTION] {question.strip()} {answer.strip()}{suffix}"
# @spaces.GPU
def veracity_prediction(claim, evidence):
    """Predict a verdict label for *claim* from the retrieved *evidence*.

    Each evidence item (its ``metadata["query"]`` and ``metadata["answer"]``)
    is paired with the claim and classified separately by ``veracity_model``.
    The per-item predicted classes are then aggregated:

    * any class 2 or 3            -> "Not Enough Evidence"
    * class 0 present, 1 absent   -> "Supported"
    * class 1 present, 0 absent   -> "Refuted"
    * both 0 and 1 present        -> "Conflicting Evidence/Cherrypicking"

    Returns:
        One of the strings in ``LABEL``; "Not Enough Evidence" when there is no evidence.
    """
    data_loader = SequenceClassificationDataLoader(
        tokenizer=veracity_tokenizer,
        data_file="this_is_discontinued",
        batch_size=32,
        add_extra_nee=False,
    )
    evidence_strings = [
        data_loader.quadruple_to_string(claim, evi.metadata["query"], evi.metadata["answer"], "")
        for evi in evidence
    ]
    # No evidence at all (e.g. the retriever found nothing): nothing to classify.
    if not evidence_strings:
        return "Not Enough Evidence"

    tokenized_strings, attention_mask = data_loader.tokenize_strings(evidence_strings)
    # Inference only: skip building the autograd graph (saves memory and time).
    with torch.no_grad():
        logits = veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits
    predicted = set(torch.argmax(logits, dim=1).tolist())

    # TODO another hack -- we cant have different labels for train and test, so
    # classes 2 and 3 are both treated as "unanswerable".
    if predicted & {2, 3}:
        answer = 2
    elif 0 in predicted and 1 not in predicted:
        answer = 0
    elif 1 in predicted and 0 not in predicted:
        answer = 1
    else:
        answer = 3
    return LABEL[answer]
# ------------------------------ Justification Generation ------------------------------
def extract_claim_str(claim, evidence, verdict_label):
    """Flatten claim, evidence Q/A pairs and verdict into the justification model's input.

    Layout: ``"[CLAIM] <claim> [EVIDENCE] "``, then for every evidence item with
    a non-blank query ``"<query>? <answer>. "``, then ``" [VERDICT] <label>"``.
    A "?" is appended to queries and a "." to non-empty answers when missing.
    Items with an empty or whitespace-only query are skipped. The exact spacing
    (including the double space before ``[VERDICT]``) is part of the format.
    """
    pieces = ["[CLAIM] ", claim, " [EVIDENCE] "]
    for item in evidence:
        question = item.metadata['query'].strip()
        if not question:
            continue
        if not question.endswith("?"):
            question += "?"
        pieces.append(question)
        answer = item.metadata['answer']
        if answer:
            if not answer.endswith("."):
                answer += "."
            pieces.append(" " + answer.strip())
        pieces.append(" ")
    pieces.append(" [VERDICT] ")
    pieces.append(verdict_label)
    return "".join(pieces)
# @spaces.GPU
def justification_generation(claim, evidence, verdict_label):
    """Generate a natural-language justification for *verdict_label*.

    The claim, evidence Q/A pairs and verdict are flattened by
    ``extract_claim_str`` and passed to ``justification_model.generate``.

    Returns:
        The generated justification with surrounding whitespace removed.
    """
    claim_str = extract_claim_str(claim, evidence, verdict_label)
    # Inference only: no autograd graph is needed while generating.
    with torch.no_grad():
        pred_justification = justification_model.generate(claim_str, device=device)
    return pred_justification.strip()
# ---------------------------------------------------------------------------------------------------------------------
class Item(BaseModel):
    """Request body for ``POST /predict/`` (the parameter type of ``fact_checking``)."""
    # The claim text to fact-check.
    claim: str
    # Evidence source; fact_checking handles "Wikipedia" and "Google".
    source: str
@app.get("/")
def greet_json():
    """Root route: answer ``GET /`` with a fixed JSON greeting."""
    greeting = dict(Hello="World!")
    return greeting
def log_on_azure(file, logs, azure_share_client):
    """Serialize *logs* as JSON and upload it as *file* through *azure_share_client*.

    Args:
        file: Name of the file to write within the share.
        logs: JSON-serialisable object.
        azure_share_client: Share client providing ``get_file_client``.
    """
    payload = json.dumps(logs)
    azure_share_client.get_file_client(file).upload_file(payload)
# @spaces.GPU
@app.post("/predict/")
def fact_checking(item: Item):
    """Fact-check a claim end to end and return the verdict with its evidence.

    Pipeline: retrieve evidence from the requested source, predict a veracity
    label, generate a justification, and optionally log the result to the
    Azure file share.

    Args:
        item: Request body with ``claim`` and ``source``. FastAPI passes a
            validated ``Item``; a plain dict with the same keys is also accepted
            for direct (non-HTTP) calls.

    Returns:
        ``{"Verdict": str, "Justification": str,
        "Evidence": [[title, evidence, url], ...]}``.

    Raises:
        HTTPException: 400 when ``source`` is neither "Wikipedia" nor "Google".
    """
    from fastapi import HTTPException

    # pydantic models expose fields as attributes (they are not subscriptable).
    if isinstance(item, dict):
        claim, source = item['claim'], item['source']
    else:
        claim, source = item.claim, item.source

    # Step1: Evidence Retrieval
    if source == "Wikipedia":
        evidence = Wikipediaretriever(claim)
    elif source == "Google":
        evidence = Googleretriever(claim)
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported source: {source!r}")

    # Step2: Veracity Prediction and Justification Generation
    verdict_label = veracity_prediction(claim, evidence)
    justification_label = justification_generation(claim, evidence, verdict_label)

    evidence_list = [
        [evi.metadata['title'], evi.metadata['evidence'], evi.metadata['url']]
        for evi in evidence
    ]

    # Log the answer to the Azure file share only when AZURE_ISSAVE is exactly
    # "TRUE"; an unset variable means "do not log", not an error.
    if os.environ.get("AZURE_ISSAVE") == "TRUE":
        timestamp = str(datetime.now().timestamp())
        logs = {
            "user_id": str(user_id),
            "claim": claim,
            "sources": source,
            "evidence": evidence_list,
            "answer": [verdict_label, justification_label],
            "time": timestamp,
        }
        try:
            log_on_azure(timestamp + ".json", logs, azure_share_client)
        except Exception as e:
            # Logging is auxiliary: report the failure but still return the
            # already-computed prediction instead of failing the request.
            print(f"Error logging on Azure Blob Storage: {e}")

    return {"Verdict": verdict_label, "Justification": justification_label, "Evidence": evidence_list}
if __name__ == "__main__":
    # Serve the FastAPI app on every network interface, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")
# if __name__ == "__main__":
# item = {
# "claim": "England won the Euro 2024.",
# "source": "Wikipedia",
# }
#
# results = fact_checking(item)
#
# print(results)
# # -----------------------------------------------------------------------------------------
# import requests
#
# # Define the API URL
# api_url = "https://zhenyundeng-zd-api.hf.space/generate/"
#
# # Define the request payload
# item = {
# "name": "Alice"
# }
#
# # Send a GET request
# # response = requests.get("https://zhenyundeng-zd-api.hf.space/")
# # Send a POST request
# response = requests.post(api_url, json=item)
#
# # Print the response
# print(response.json())