loren-fact-checking/src/er_client/document_retrieval.py
# -*- coding:utf-8 -*-
"""
@Author : Bao
@Date : 2020/9/17
@Desc : Document selection and sentence ranking code from KGAT. Not used in LOREN.
@Last modified by : Bao
@Last modified date : 2020/9/17
"""
import re
import time
import json
import nltk
from tqdm import tqdm
from allennlp.predictors import Predictor
from drqa.retriever import DocDB, utils
from drqa.retriever.utils import normalize
import wikipedia


class FeverDocDB(DocDB):
    """Wrapper around the DrQA document database holding the FEVER Wikipedia dump."""

    def __init__(self, path=None):
        super().__init__(path)

    def get_doc_lines(self, doc_id):
        """Fetch the raw text of the doc for 'doc_id'."""
        cursor = self.connection.cursor()
        cursor.execute(
            "SELECT lines FROM documents WHERE id = ?",
            (utils.normalize(doc_id),)
        )
        result = cursor.fetchone()
        cursor.close()
        result = result[0] if result is not None else ''

        doc_lines = []
        for line in result.split('\n'):
            if len(line) == 0:
                continue
            # Each stored line is tab-separated; the sentence text sits in the second field.
            line = line.split('\t')[1]
            if len(line) == 0:
                continue
            doc_lines.append((doc_id, len(doc_lines), line, 0))
        return doc_lines

    def get_non_empty_doc_ids(self):
        """Fetch all ids of docs stored in the db."""
        cursor = self.connection.cursor()
        cursor.execute("SELECT id FROM documents WHERE length(trim(text)) > 0")
        results = [r[0] for r in cursor.fetchall()]
        cursor.close()
        return results


class DocRetrieval:
    """Retrieves candidate Wikipedia pages for a claim via noun-phrase extraction and Wikipedia search."""

    def __init__(self, database_path, add_claim=False, k_wiki_results=None):
        self.db = FeverDocDB(database_path)
        self.add_claim = add_claim
        self.k_wiki_results = k_wiki_results
        self.porter_stemmer = nltk.PorterStemmer()
        self.tokenizer = nltk.word_tokenize
        self.predictor = Predictor.from_path(
            "https://storage.googleapis.com/allennlp-public-models/elmo-constituency-parser-2020.02.10.tar.gz"
        )
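
    # Illustrative sketch (not from the source): the constituency parser's prediction contains a
    # nested tree under tokens['hierplane_tree']['root'], with each node roughly shaped like
    #     {'word': 'the free agent', 'nodeType': 'NP',
    #      'children': [{'word': 'the', 'nodeType': 'DT'}, ...]}
    # The helpers below rely only on the 'word', 'nodeType', and 'children' keys of that dict.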

    def get_NP(self, tree, nps):
        """Recursively collect the text of every NP node in the constituency tree."""
        if isinstance(tree, dict):
            if "children" not in tree:
                if tree['nodeType'] == "NP":
                    # print(tree['word'])
                    # print(tree)
                    nps.append(tree['word'])
            elif "children" in tree:
                if tree['nodeType'] == "NP":
                    # print(tree['word'])
                    nps.append(tree['word'])
                    self.get_NP(tree['children'], nps)
                else:
                    self.get_NP(tree['children'], nps)
        elif isinstance(tree, list):
            for sub_tree in tree:
                self.get_NP(sub_tree, nps)
        return nps

    def get_subjects(self, tree):
        """Collect candidate subject spans: the words seen before each VP/S/VBZ child of the root."""
        subject_words = []
        subjects = []
        for subtree in tree['children']:
            if subtree['nodeType'] == "VP" or subtree['nodeType'] == 'S' or subtree['nodeType'] == 'VBZ':
                subjects.append(' '.join(subject_words))
                subject_words.append(subtree['word'])
            else:
                subject_words.append(subtree['word'])
        return subjects

    def get_noun_phrases(self, claim):
        """Parse the claim and return its noun phrases (plus subjects and, optionally, the full claim)."""
        tokens = self.predictor.predict(claim)
        nps = []
        tree = tokens['hierplane_tree']['root']
        noun_phrases = self.get_NP(tree, nps)
        subjects = self.get_subjects(tree)
        for subject in subjects:
            if len(subject) > 0:
                noun_phrases.append(subject)
        if self.add_claim:
            noun_phrases.append(claim)
        return list(set(noun_phrases))

    def get_doc_for_claim(self, noun_phrases):
        """Query the Wikipedia search API with each noun phrase and return FEVER-style page ids."""
        predicted_pages = []
        for np in noun_phrases:
            if len(np) > 300:
                continue
            i = 1
            while i < 12:
                try:
                    # print(np)
                    # res = server.lookup(np, keep_all=True)
                    # docs = [y for _, y in res] if res is not None else []
                    docs = wikipedia.search(np)
                    if self.k_wiki_results is not None:
                        predicted_pages.extend(docs[:self.k_wiki_results])
                    else:
                        predicted_pages.extend(docs)
                except (ConnectionResetError, ConnectionError, ConnectionAbortedError, ConnectionRefusedError):
                    # Back off and retry on transient connection failures.
                    print("Connection reset error received! Trial #" + str(i))
                    time.sleep(600 * i)
                    i += 1
                else:
                    break
            # sleep_num = random.uniform(0.1,0.7)
            # time.sleep(sleep_num)

        predicted_pages = set(predicted_pages)
        processed_pages = []
        for page in predicted_pages:
            # Convert Wikipedia titles to the FEVER page-id convention.
            page = page.replace(" ", "_")
            page = page.replace("(", "-LRB-")
            page = page.replace(")", "-RRB-")
            page = page.replace(":", "-COLON-")
            processed_pages.append(page)
        return processed_pages

    def np_conc(self, noun_phrases):
        """Convert noun phrases directly into FEVER page ids and keep those that exist in the db."""
        noun_phrases = set(noun_phrases)
        predicted_pages = []
        for np in noun_phrases:
            page = np.replace('( ', '-LRB-')
            page = page.replace(' )', '-RRB-')
            page = page.replace(' - ', '-')
            page = page.replace(' :', '-COLON-')
            page = page.replace(' ,', ',')
            page = page.replace(" 's", "'s")
            page = page.replace(' ', '_')
            if len(page) < 1:
                continue
            doc_lines = self.db.get_doc_lines(page)
            if len(doc_lines) > 0:
                predicted_pages.append(page)
        return predicted_pages
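
    # Illustrative example (not from the source): a tokenized noun phrase such as
    # "Savages ( 2012 film )" becomes the FEVER page id "Savages_-LRB-2012_film-RRB-"
    # after the replacements in np_conc, and is kept only if that id exists in the db.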

    def exact_match(self, claim):
        """Return noun phrases, Wikipedia search results, and the final set of predicted pages for a claim."""
        noun_phrases = self.get_noun_phrases(claim)
        wiki_results = self.get_doc_for_claim(noun_phrases)
        wiki_results = list(set(wiki_results))

        claim = claim.replace(".", "")
        claim = claim.replace("-", " ")
        words = [self.porter_stemmer.stem(word.lower()) for word in self.tokenizer(claim)]
        words = set(words)

        predicted_pages = self.np_conc(noun_phrases)
        for page in wiki_results:
            page = normalize(page)
            processed_page = re.sub("-LRB-.*?-RRB-", "", page)
            processed_page = re.sub("_", " ", processed_page)
            processed_page = re.sub("-COLON-", ":", processed_page)
            processed_page = processed_page.replace("-", " ")
            processed_page = processed_page.replace("–", " ")
            processed_page = processed_page.replace(".", "")
            page_words = [self.porter_stemmer.stem(word.lower())
                          for word in self.tokenizer(processed_page) if len(word) > 0]
            # Keep a search result only if every stemmed word of its title also appears in the claim.
            if all([item in words for item in page_words]):
                if ':' in page:
                    page = page.replace(":", "-COLON-")
                predicted_pages.append(page)
        predicted_pages = list(set(predicted_pages))
        return noun_phrases, wiki_results, predicted_pages


def save_to_file(results, client, filename):
    """Dump claims with their retrieved evidence sentences as JSON lines."""
    with open(filename, 'w', encoding='utf-8') as fout:
        for _id, line in enumerate(results):
            claim = line['claim']
            evidence = []
            for page in line['predicted_pages']:
                evidence.extend(client.db.get_doc_lines(page))
            print(json.dumps({'claim': claim, 'evidence': evidence}, ensure_ascii=False), file=fout)
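

# A note on the expected data layout, inferred from the driver below: 'data/claims.json' is read
# as JSON lines, one object per line with at least a "claim" field; the retrieved sentences are
# attached under "evidence" and each enriched object is written, one per line, to 'data/pages.json'.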


if __name__ == '__main__':
    database_path = 'data/fever.db'
    add_claim = True
    k_wiki_results = 7
    client = DocRetrieval(database_path, add_claim, k_wiki_results)

    results = []
    with open('data/claims.json', 'r', encoding='utf-8') as fin:
        for line in tqdm(fin):
            line = json.loads(line)
            _, _, predicted_pages = client.exact_match(line['claim'])
            evidence = []
            for page in predicted_pages:
                evidence.extend(client.db.get_doc_lines(page))
            line['evidence'] = evidence
            results.append(line)

    with open('data/pages.json', 'w', encoding='utf-8') as fout:
        for line in results:
            print(json.dumps(line, ensure_ascii=False), file=fout)