# -*- coding:utf-8 -*-
"""
@Author             : Bao
@Date               : 2020/9/17
@Desc               : Document selection and sentence ranking code from KGAT. Not used in LOREN.
@Last modified by   : Bao
@Last modified date : 2020/9/17
"""
import re
import time
import json

import nltk
import wikipedia
from tqdm import tqdm
from allennlp.predictors import Predictor
from drqa.retriever import DocDB, utils
from drqa.retriever.utils import normalize

class FeverDocDB(DocDB):
    def __init__(self, path=None):
        super().__init__(path)

    def get_doc_lines(self, doc_id):
        """Fetch the raw text of the doc for 'doc_id'."""
        cursor = self.connection.cursor()
        cursor.execute(
            "SELECT lines FROM documents WHERE id = ?",
            (utils.normalize(doc_id),)
        )
        result = cursor.fetchone()
        cursor.close()
        result = result[0] if result is not None else ''
        doc_lines = []
        for line in result.split('\n'):
            if len(line) == 0:
                continue
            # FEVER stores each sentence as "<index>\t<text>\t<linked entities...>";
            # keep only the sentence text and skip malformed or empty lines.
            parts = line.split('\t')
            if len(parts) < 2 or len(parts[1]) == 0:
                continue
            doc_lines.append((doc_id, len(doc_lines), parts[1], 0))
        return doc_lines
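
    # A hypothetical sketch of the stored format and the resulting value:
    #   lines = "0\tColin Kaepernick is a quarterback .\tColin Kaepernick\n1\tHe plays ."
    #   get_doc_lines('Colin_Kaepernick') ->
    #       [('Colin_Kaepernick', 0, 'Colin Kaepernick is a quarterback .', 0),
    #        ('Colin_Kaepernick', 1, 'He plays .', 0)]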

    def get_non_empty_doc_ids(self):
        """Fetch all ids of docs stored in the db."""
        cursor = self.connection.cursor()
        cursor.execute("SELECT id FROM documents WHERE length(trim(text)) > 0")
        results = [r[0] for r in cursor.fetchall()]
        cursor.close()
        return results


class DocRetrieval:
    def __init__(self, database_path, add_claim=False, k_wiki_results=None):
        self.db = FeverDocDB(database_path)
        self.add_claim = add_claim
        self.k_wiki_results = k_wiki_results
        self.porter_stemmer = nltk.PorterStemmer()
        self.tokenizer = nltk.word_tokenize
        # AllenNLP ELMo-based constituency parser, used to extract noun phrases from claims.
        self.predictor = Predictor.from_path(
            "https://storage.googleapis.com/allennlp-public-models/elmo-constituency-parser-2020.02.10.tar.gz"
        )

    def get_NP(self, tree, nps):
        """Recursively collect the text of every NP node in a hierplane parse tree."""
        if isinstance(tree, dict):
            if "children" not in tree:
                if tree['nodeType'] == "NP":
                    nps.append(tree['word'])
            else:
                if tree['nodeType'] == "NP":
                    nps.append(tree['word'])
                self.get_NP(tree['children'], nps)
        elif isinstance(tree, list):
            for sub_tree in tree:
                self.get_NP(sub_tree, nps)
        return nps
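
    # The hierplane tree is a nested dict roughly of this shape (a sketch; only
    # the fields used above are shown):
    #   {'word': 'the 2009 film', 'nodeType': 'NP',
    #    'children': [{'word': 'the', 'nodeType': 'DT'}, ...]}
    # Nested NPs are collected as well, since recursion continues below NP nodes.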

    def get_subjects(self, tree):
        """Collect candidate subjects: the words preceding each top-level VP/S/VBZ node."""
        subject_words = []
        subjects = []
        for subtree in tree['children']:
            if subtree['nodeType'] in ("VP", "S", "VBZ"):
                subjects.append(' '.join(subject_words))
            subject_words.append(subtree['word'])
        return subjects
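
    # E.g. (sketch) for "Colin Kaepernick became a starting quarterback", the root's
    # children are roughly [NP "Colin Kaepernick", VP "became ..."], so the words
    # seen before the VP yield the subject candidate "Colin Kaepernick".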

    def get_noun_phrases(self, claim):
        tokens = self.predictor.predict(claim)
        nps = []
        tree = tokens['hierplane_tree']['root']
        noun_phrases = self.get_NP(tree, nps)
        subjects = self.get_subjects(tree)
        for subject in subjects:
            if len(subject) > 0:
                noun_phrases.append(subject)
        if self.add_claim:
            noun_phrases.append(claim)
        return list(set(noun_phrases))
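
    # Sketch of a call (actual output depends on the parser):
    #   client.get_noun_phrases("Colin Kaepernick became a starting quarterback.")
    #   -> ['Colin Kaepernick', 'a starting quarterback', ...]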

    def get_doc_for_claim(self, noun_phrases):
        """Query the Wikipedia search API with each noun phrase and collect page titles."""
        predicted_pages = []
        for np in noun_phrases:
            if len(np) > 300:
                continue
            i = 1
            while i < 12:
                try:
                    docs = wikipedia.search(np)
                    if self.k_wiki_results is not None:
                        predicted_pages.extend(docs[:self.k_wiki_results])
                    else:
                        predicted_pages.extend(docs)
                except (ConnectionResetError, ConnectionError, ConnectionAbortedError, ConnectionRefusedError):
                    # Back off and retry on transient connection failures.
                    print("Connection reset error received! Trial #" + str(i))
                    time.sleep(600 * i)
                    i += 1
                else:
                    break

        # Convert titles to the naming scheme used by the FEVER database.
        predicted_pages = set(predicted_pages)
        processed_pages = []
        for page in predicted_pages:
            page = page.replace(" ", "_")
            page = page.replace("(", "-LRB-")
            page = page.replace(")", "-RRB-")
            page = page.replace(":", "-COLON-")
            processed_pages.append(page)
        return processed_pages
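
    # E.g. the Wikipedia title "Up (2009 film)" becomes "Up_-LRB-2009_film-RRB-",
    # matching the page ids stored in the FEVER database.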

    def np_conc(self, noun_phrases):
        """Turn space-tokenized noun phrases back into FEVER page ids, keeping those present in the db."""
        noun_phrases = set(noun_phrases)
        predicted_pages = []
        for np in noun_phrases:
            page = np.replace('( ', '-LRB-')
            page = page.replace(' )', '-RRB-')
            page = page.replace(' - ', '-')
            page = page.replace(' :', '-COLON-')
            page = page.replace(' ,', ',')
            page = page.replace(" 's", "'s")
            page = page.replace(' ', '_')
            if len(page) < 1:
                continue
            doc_lines = self.db.get_doc_lines(page)
            if len(doc_lines) > 0:
                predicted_pages.append(page)
        return predicted_pages
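
    # E.g. (sketch) the tokenized phrase "Up ( 2009 film )" maps to the page id
    # "Up_-LRB-2009_film-RRB-", which is kept only if the db has lines for it.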

    def exact_match(self, claim):
        """Combine db-matched noun phrases with Wikipedia results whose titles are fully covered by the claim."""
        noun_phrases = self.get_noun_phrases(claim)
        wiki_results = self.get_doc_for_claim(noun_phrases)
        wiki_results = list(set(wiki_results))

        claim = claim.replace(".", "")
        claim = claim.replace("-", " ")
        words = [self.porter_stemmer.stem(word.lower()) for word in self.tokenizer(claim)]
        words = set(words)

        predicted_pages = self.np_conc(noun_phrases)
        for page in wiki_results:
            page = normalize(page)
            # Strip the FEVER title encoding and any disambiguation suffix before matching.
            processed_page = re.sub("-LRB-.*?-RRB-", "", page)
            processed_page = re.sub("_", " ", processed_page)
            processed_page = re.sub("-COLON-", ":", processed_page)
            processed_page = processed_page.replace("-", " ")
            processed_page = processed_page.replace("–", " ")
            processed_page = processed_page.replace(".", "")
            page_words = [
                self.porter_stemmer.stem(word.lower())
                for word in self.tokenizer(processed_page) if len(word) > 0
            ]
            # Keep the page only if every stemmed title word also appears in the claim.
            if all([item in words for item in page_words]):
                if ':' in page:
                    page = page.replace(":", "-COLON-")
                predicted_pages.append(page)
        predicted_pages = list(set(predicted_pages))
        return noun_phrases, wiki_results, predicted_pages
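
    # Sketch of a call (results depend on the parser, the Wikipedia API, and the db):
    #   nps, wiki_results, pages = client.exact_match("Up is a 2009 film.")
    #   # pages might contain 'Up_-LRB-2009_film-RRB-'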


def save_to_file(results, client, filename):
    """Dump each claim with the full text lines of its predicted pages, one JSON object per line."""
    with open(filename, 'w', encoding='utf-8') as fout:
        for line in results:
            claim = line['claim']
            evidence = []
            for page in line['predicted_pages']:
                evidence.extend(client.db.get_doc_lines(page))
            print(json.dumps({'claim': claim, 'evidence': evidence}, ensure_ascii=False), file=fout)


if __name__ == '__main__':
    database_path = 'data/fever.db'
    add_claim = True
    k_wiki_results = 7
    client = DocRetrieval(database_path, add_claim, k_wiki_results)

    # Each input line is expected to be a JSON object with at least a 'claim' field.
    results = []
    with open('data/claims.json', 'r', encoding='utf-8') as fin:
        for line in tqdm(fin):
            line = json.loads(line)
            _, _, predicted_pages = client.exact_match(line['claim'])
            evidence = []
            for page in predicted_pages:
                evidence.extend(client.db.get_doc_lines(page))
            line['evidence'] = evidence
            results.append(line)

    with open('data/pages.json', 'w', encoding='utf-8') as fout:
        for line in results:
            print(json.dumps(line, ensure_ascii=False), file=fout)