diff --git a/README.md b/README.md index 4718020c6296e8b180cbcffaf5d6d0a179d7653b..66df3a6dfd3ed12f68ec86fa947cb05b0556f6c2 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,13 @@ --- -title: AVeriTeC API -emoji: 🚀 -colorFrom: blue -colorTo: gray +title: AVeriTeC +emoji: 🏆 +colorFrom: purple +colorTo: red sdk: gradio -sdk_version: 4.38.1 +sdk_version: 4.37.2 app_file: app.py pinned: false +license: apache-2.0 --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..fa0ac52430050f3a3d58b64c0de4fa45fc5bce0c --- /dev/null +++ b/app.py @@ -0,0 +1,1368 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Created by zd302 at 08/07/2024 + +import gradio as gr +import tqdm +import torch +import numpy as np +from time import sleep +import threading +import gc +import os +import json +import pytorch_lightning as pl +from urllib.parse import urlparse +from accelerate import Accelerator + +from transformers import BartTokenizer, BartForConditionalGeneration +from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification +from transformers import RobertaTokenizer, RobertaForSequenceClassification + +from rank_bm25 import BM25Okapi +# import bm25s +# import Stemmer # optional: for stemming +from html2lines import url2lines +from googleapiclient.discovery import build +from averitec.models.DualEncoderModule import DualEncoderModule +from averitec.models.SequenceClassificationModule import SequenceClassificationModule +from averitec.models.JustificationGenerationModule import JustificationGenerationModule +from averitec.data.sample_claims import CLAIMS_Type + +# --------------------------------------------------------------------------- +# load .env +from utils import create_user_id +user_id = create_user_id() + +from datetime import datetime +from azure.storage.fileshare import ShareServiceClient +try: + from dotenv import load_dotenv + load_dotenv() +except Exception as e: + pass + +account_url = os.environ["AZURE_ACCOUNT_URL"] +credential = { + "account_key": os.environ['AZURE_ACCOUNT_KEY'], + "account_name": os.environ['AZURE_ACCOUNT_NAME'] +} + +file_share_name = "averitec" +azure_service = ShareServiceClient(account_url=account_url, credential=credential) +azure_share_client = azure_service.get_share_client(file_share_name) + +# ---------- Setting ---------- +import requests +from bs4 import BeautifulSoup +import wikipediaapi +wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC (zd302@cam.ac.uk)', 'en') + +import nltk +nltk.download('punkt') +from nltk import pos_tag, word_tokenize, sent_tokenize + +import spacy +os.system("python -m spacy download en_core_web_sm") +nlp = spacy.load("en_core_web_sm") + +# --------------------------------------------------------------------------- +# Load sample dict for AVeriTeC search +# all_samples_dict = json.load(open('averitec/data/all_samples.json', 'r')) + +# --------------------------------------------------------------------------- +# ---------- Load pretrained models ---------- +# ---------- load Evidence retrieval model ---------- +# from drqa import retriever +# db_class = retriever.get_class('sqlite') +# doc_db = db_class("averitec/data/wikipedia_dumps/enwiki.db") +# ranker = retriever.get_class('tfidf')(tfidf_path="averitec/data/wikipedia_dumps/enwiki-tfidf-with-id-title.npz") + +# ---------- Load Veracity and Justification prediction model ---------- +print("Loading models 
...") +LABEL = [ + "Supported", + "Refuted", + "Not Enough Evidence", + "Conflicting Evidence/Cherrypicking", +] +# Veracity +device = "cuda:0" if torch.cuda.is_available() else "cpu" +veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") +bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification") +veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt", + tokenizer=veracity_tokenizer, model=bert_model).to(device) +# Justification +justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True) +bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large") +best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt' +justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device) +# --------------------------------------------------------------------------- + + +# Set up Gradio Theme +theme = gr.themes.Base( + primary_hue="blue", + secondary_hue="red", + font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"], +) + +# ---------- Setting ---------- + +class Docs: + def __init__(self, metadata=dict(), page_content=""): + self.metadata = metadata + self.page_content = page_content + + +def make_html_source(source, i): + meta = source.metadata + content = source.page_content.strip() + + card = f""" +
+    <div class="card" id="doc{i}">
+        <div class="card-content">
+            <h2>Doc {i} - URL: {meta['url']}</h2>
+            <p>{content}</p>
+        </div>
+    </div>
+
+ """ + + return card + + +# ----- veracity_prediction ----- +class SequenceClassificationDataLoader(pl.LightningDataModule): + def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False): + super().__init__() + self.tokenizer = tokenizer + self.data_file = data_file + self.batch_size = batch_size + self.add_extra_nee = add_extra_nee + + def tokenize_strings( + self, + source_sentences, + max_length=400, + pad_to_max_length=False, + return_tensors="pt", + ): + encoded_dict = self.tokenizer( + source_sentences, + max_length=max_length, + padding="max_length" if pad_to_max_length else "longest", + truncation=True, + return_tensors=return_tensors, + ) + + input_ids = encoded_dict["input_ids"] + attention_masks = encoded_dict["attention_mask"] + + return input_ids, attention_masks + + def quadruple_to_string(self, claim, question, answer, bool_explanation=""): + if bool_explanation is not None and len(bool_explanation) > 0: + bool_explanation = ", because " + bool_explanation.lower().strip() + else: + bool_explanation = "" + return ( + "[CLAIM] " + + claim.strip() + + " [QUESTION] " + + question.strip() + + " " + + answer.strip() + + bool_explanation + ) + + +def averitec_veracity_prediction(claim, qa_evidence): + bert_model_name = "bert-base-uncased" + tokenizer = BertTokenizer.from_pretrained(bert_model_name) + bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4, + problem_type="single_label_classification") + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + trained_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt", + tokenizer=tokenizer, model=bert_model).to(device) + + dataLoader = SequenceClassificationDataLoader( + tokenizer=tokenizer, + data_file="this_is_discontinued", + batch_size=32, + add_extra_nee=False, + ) + + evidence_strings = [] + for evidence in qa_evidence: + evidence_strings.append( + dataLoader.quadruple_to_string(claim, evidence.metadata["query"], evidence.metadata["answer"], "")) + + if len(evidence_strings) == 0: # If we found no evidence e.g. because google returned 0 pages, just output NEI. 
+ pred_label = "Not Enough Evidence" + return pred_label + + tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings) + example_support = torch.argmax( + trained_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1) + + has_unanswerable = False + has_true = False + has_false = False + + for v in example_support: + if v == 0: + has_true = True + if v == 1: + has_false = True + if v in (2, 3,): # TODO another hack -- we cant have different labels for train and test so we do this + has_unanswerable = True + + if has_unanswerable: + answer = 2 + elif has_true and not has_false: + answer = 0 + elif not has_true and has_false: + answer = 1 + else: + answer = 3 + + pred_label = LABEL[answer] + + return pred_label + + +def fever_veracity_prediction(claim, evidence): + tokenizer = RobertaTokenizer.from_pretrained('Dzeniks/roberta-fact-check') + model = RobertaForSequenceClassification.from_pretrained('Dzeniks/roberta-fact-check') + + evidence_string = "" + for evi in evidence: + evidence_string += evi.metadata['title'] + evi.metadata['evidence'] + ' ' + + input_sequence = tokenizer.encode_plus(claim, evidence_string, return_tensors="pt") + with torch.no_grad(): + prediction = model(**input_sequence) + + label = torch.argmax(prediction[0]).item() + pred_label = LABEL[label] + + return pred_label + + +def veracity_prediction(claim, qa_evidence): + # bert_model_name = "bert-base-uncased" + # tokenizer = BertTokenizer.from_pretrained(bert_model_name) + # bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4, + # problem_type="single_label_classification") + # + # device = "cuda:0" if torch.cuda.is_available() else "cpu" + # trained_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt", + # tokenizer=tokenizer, model=bert_model).to(device) + + dataLoader = SequenceClassificationDataLoader( + tokenizer=veracity_tokenizer, + data_file="this_is_discontinued", + batch_size=32, + add_extra_nee=False, + ) + + evidence_strings = [] + for evidence in qa_evidence: + evidence_strings.append( + dataLoader.quadruple_to_string(claim, evidence.metadata["query"], evidence.metadata["answer"], "")) + + if len(evidence_strings) == 0: # If we found no evidence e.g. because google returned 0 pages, just output NEI. + pred_label = "Not Enough Evidence" + return pred_label + + tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings) + example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1) + + has_unanswerable = False + has_true = False + has_false = False + + for v in example_support: + if v == 0: + has_true = True + if v == 1: + has_false = True + if v in (2, 3,): # TODO another hack -- we cant have different labels for train and test so we do this + has_unanswerable = True + + if has_unanswerable: + answer = 2 + elif has_true and not has_false: + answer = 0 + elif not has_true and has_false: + answer = 1 + else: + answer = 3 + + pred_label = LABEL[answer] + + return pred_label + + +def extract_claim_str(claim, qa_evidence, verdict_label): + claim_str = "[CLAIM] " + claim + " [EVIDENCE] " + + for evidence in qa_evidence: + q_text = evidence.metadata['query'].strip() + + if len(q_text) == 0: + continue + + if not q_text[-1] == "?": + q_text += "?" 
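+        # Each evidence answer is appended directly after its question when assembling the input string for the justification model.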
+ + answer_strings = [] + answer_strings.append(evidence.metadata['answer']) + + claim_str += q_text + for a_text in answer_strings: + if a_text: + if not a_text[-1] == ".": + a_text += "." + claim_str += " " + a_text.strip() + + claim_str += " " + + claim_str += " [VERDICT] " + verdict_label + + return claim_str + + +def averitec_justification_generation(claim, qa_evidence, verdict_label): + # + claim_str = extract_claim_str(claim, qa_evidence, verdict_label) + claim_str.strip() + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True) + bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large") + + best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt' + trained_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer, + model=bart_model).to(device) + + pred_justification = trained_model.generate(claim_str, device=device) + + return pred_justification.strip() + + +def justification_generation(claim, qa_evidence, verdict_label): + # + claim_str = extract_claim_str(claim, qa_evidence, verdict_label) + claim_str.strip() + + # device = "cuda:0" if torch.cuda.is_available() else "cpu" + # tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True) + # bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large") + # + # best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt' + # trained_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer, + # model=bart_model).to(device) + + pred_justification = justification_model.generate(claim_str, device=device) + + return pred_justification.strip() + + +def QAprediction(claim, evidence, sources): + parts = [] + # + evidence_title = f"""
<h5>Retrieved Evidence:</h5>
""" + for i, evi in enumerate(evidence, 1): + part = f"""Doc {i}""" + subpart = f"""{i}""" + # subpart = f"""{i}""" + subparts = "".join([part, subpart]) + parts.append(subparts) + + evidence_part = ", ".join(parts) + + prediction_title = f"""
<h5>Prediction:</h5>
""" + # if 'Google' in sources or 'AVeriTeC' in sources: + # verdict_label = averitec_veracity_prediction(claim, evidence) + # justification_label = averitec_justification_generation(claim, evidence, verdict_label) + # # justification_label = "See retrieved docs." + # justification_part = f"""Justification: {justification_label}""" + # if 'WikiPedia' in sources: + # # verdict_label = fever_veracity_prediction(claim, evidence) + # justification_label = averitec_justification_generation(claim, evidence, verdict_label) + # # justification_label = "See retrieved docs." + # justification_part = f"""Justification: {justification_label}""" + + verdict_label = veracity_prediction(claim, evidence) + justification_label = justification_generation(claim, evidence, verdict_label) + # justification_label = "See retrieved docs." + justification_part = f"""Justification: {justification_label}""" + + + verdict_part = f"""Verdict: {verdict_label}.
""" + + content_parts = "".join([evidence_title, evidence_part, prediction_title, verdict_part, justification_part]) + # content_parts = "".join([evidence_title, evidence_part, verdict_title, verdict_part, justification_title, justification_part]) + + return content_parts, [verdict_label, justification_label] + + +# ----------GoogleAPIretriever--------- +def generate_reference_corpus(reference_file): + with open(reference_file) as f: + j = json.load(f) + train_examples = j + + all_data_corpus = [] + tokenized_corpus = [] + + for train_example in train_examples: + train_claim = train_example["claim"] + + speaker = train_example["speaker"].strip() if train_example["speaker"] is not None and len( + train_example["speaker"]) > 1 else "they" + + questions = [q["question"] for q in train_example["questions"]] + + claim_dict_builder = {} + claim_dict_builder["claim"] = train_claim + claim_dict_builder["speaker"] = speaker + claim_dict_builder["questions"] = questions + + tokenized_corpus.append(nltk.word_tokenize(claim_dict_builder["claim"])) + all_data_corpus.append(claim_dict_builder) + + return tokenized_corpus, all_data_corpus + + +def doc2prompt(doc): + prompt_parts = "Outrageously, " + doc["speaker"] + " claimed that \"" + doc[ + "claim"].strip() + "\". Criticism includes questions like: " + questions = [q.strip() for q in doc["questions"]] + return prompt_parts + " ".join(questions) + + +def docs2prompt(top_docs): + return "\n\n".join([doc2prompt(d) for d in top_docs]) + + +def prompt_question_generation(test_claim, speaker="they", topk=10): + # + reference_file = "averitec_code/data/train.json" + tokenized_corpus, all_data_corpus = generate_reference_corpus(reference_file) + bm25 = BM25Okapi(tokenized_corpus) + + # Define the bloom model: + accelerator = Accelerator() + accel_device = accelerator.device + device = "cuda:0" if torch.cuda.is_available() else "cpu" + tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1") + model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device) + + # -------------------------------------------------- + # test claim + s = bm25.get_scores(nltk.word_tokenize(test_claim)) + top_n = np.argsort(s)[::-1][:topk] + docs = [all_data_corpus[i] for i in top_n] + # -------------------------------------------------- + + prompt = docs2prompt(docs) + "\n\n" + "Outrageously, " + speaker + " claimed that \"" + test_claim.strip() + \ + "\". Criticism includes questions like: " + sentences = [prompt] + + inputs = tokenizer(sentences, padding=True, return_tensors="pt").to(device) + outputs = model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2, + early_stopping=True) + + tgt_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] + in_len = len(sentences[0]) + questions_str = tgt_text[in_len:].split("\n")[0] + + qs = questions_str.split("?") + qs = [q.strip() + "?" 
for q in qs if q.strip() and len(q.strip()) < 300] + + # + generate_question = [{"question": q, "answers": []} for q in qs] + + return generate_question + + +def check_claim_date(check_date): + try: + year, month, date = check_date.split("-") + except: + month, date, year = "01", "01", "2022" + + if len(year) == 2 and int(year) <= 30: + year = "20" + year + elif len(year) == 2: + year = "19" + year + elif len(year) == 1: + year = "200" + year + + if len(month) == 1: + month = "0" + month + + if len(date) == 1: + date = "0" + date + + sort_date = year + month + date + + return sort_date + + +def string_to_search_query(text, author): + parts = word_tokenize(text.strip()) + tags = pos_tag(parts) + + keep_tags = ["CD", "JJ", "NN", "VB"] + + if author is not None: + search_string = author.split() + else: + search_string = [] + + for token, tag in zip(parts, tags): + for keep_tag in keep_tags: + if tag[1].startswith(keep_tag): + search_string.append(token) + + search_string = " ".join(search_string) + return search_string + + +def google_search(search_term, api_key, cse_id, **kwargs): + service = build("customsearch", "v1", developerKey=api_key) + res = service.cse().list(q=search_term, cx=cse_id, **kwargs).execute() + + if "items" in res: + return res['items'] + else: + return [] + + +def get_domain_name(url): + if '://' not in url: + url = 'http://' + url + + domain = urlparse(url).netloc + + if domain.startswith("www."): + return domain[4:] + else: + return domain + + +def get_and_store(url_link, fp, worker, worker_stack): + page_lines = url2lines(url_link) + + with open(fp, "w") as out_f: + print("\n".join([url_link] + page_lines), file=out_f) + + worker_stack.append(worker) + gc.collect() + + +def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0): + search_results = [] + for i in range(3): + try: + search_results += google_search( + search_string, + api_key, + search_engine_id, + num=10, + start=0 + 10 * page, + sort="date:r:19000101:" + sort_date, + dateRestrict=None, + gl="US" + ) + break + except: + sleep(3) + + return search_results + + +def averitec_search(claim, generate_question, speaker="they", check_date="2024-01-01", n_pages=1): # n_pages=3 + # default config + api_key = os.environ["GOOGLE_API_KEY"] + search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"] + + blacklist = [ + "jstor.org", # Blacklisted because their pdfs are not labelled as such, and clog up the download + "facebook.com", # Blacklisted because only post titles can be scraped, but the scraper doesn't know this, + "ftp.cs.princeton.edu", # Blacklisted because it hosts many large NLP corpora that keep showing up + "nlp.cs.princeton.edu", + "huggingface.co" + ] + + blacklist_files = [ # Blacklisted some NLP nonsense that crashes my machine with OOM errors + "/glove.", + "ftp://ftp.cs.princeton.edu/pub/cs226/autocomplete/words-333333.txt", + "https://web.mit.edu/adamrose/Public/googlelist", + ] + + # save to folder + store_folder = "averitec_code/store/retrieved_docs" + # + index = 0 + questions = [q["question"] for q in generate_question] + + # check the date of the claim + sort_date = check_claim_date(check_date) # check_date="2022-01-01" + + # + search_strings = [] + search_types = [] + + search_string_2 = string_to_search_query(claim, None) + search_strings += [search_string_2, claim, ] + search_types += ["claim", "claim-noformat", ] + + search_strings += questions + search_types += ["question" for _ in questions] + + # start to search + search_results = [] + 
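+    # 'visited' maps each URL to its local store path so a page is downloaded only once; the worker stack caps concurrent download threads at 10.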
visited = {} + store_counter = 0 + worker_stack = list(range(10)) + + retrieve_evidence = [] + + for this_search_string, this_search_type in zip(search_strings, search_types): + for page_num in range(n_pages): + search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date, + this_search_string, page=page_num) + + for result in search_results: + link = str(result["link"]) + domain = get_domain_name(link) + + if domain in blacklist: + continue + broken = False + for b_file in blacklist_files: + if b_file in link: + broken = True + if broken: + continue + if link.endswith(".pdf") or link.endswith(".doc"): + continue + + store_file_path = "" + + if link in visited: + store_file_path = visited[link] + else: + store_counter += 1 + store_file_path = store_folder + "/search_result_" + str(index) + "_" + str( + store_counter) + ".store" + visited[link] = store_file_path + + while len(worker_stack) == 0: # Wait for a wrrker to become available. Check every second. + sleep(1) + + worker = worker_stack.pop() + + t = threading.Thread(target=get_and_store, args=(link, store_file_path, worker, worker_stack)) + t.start() + + line = [str(index), claim, link, str(page_num), this_search_string, this_search_type, store_file_path] + retrieve_evidence.append(line) + + return retrieve_evidence + + +def claim2prompts(example): + claim = example["claim"] + + # claim_str = "Claim: " + claim + "||Evidence: " + claim_str = "Evidence: " + + for question in example["questions"]: + q_text = question["question"].strip() + if len(q_text) == 0: + continue + + if not q_text[-1] == "?": + q_text += "?" + + answer_strings = [] + + for a in question["answers"]: + if a["answer_type"] in ["Extractive", "Abstractive"]: + answer_strings.append(a["answer"]) + if a["answer_type"] == "Boolean": + answer_strings.append(a["answer"] + ", because " + a["boolean_explanation"].lower().strip()) + + for a_text in answer_strings: + if not a_text[-1] in [".", "!", ":", "?"]: + a_text += "." 
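+            # The answer sentence alone serves as the BM25 lookup key for retrieving similar QA prompts.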
+ + # prompt_lookup_str = claim + " " + a_text + prompt_lookup_str = a_text + this_q_claim_str = claim_str + " " + a_text.strip() + "||Question answered: " + q_text + yield (prompt_lookup_str, this_q_claim_str.replace("\n", " ").replace("||", "\n")) + + +def generate_step2_reference_corpus(reference_file): + with open(reference_file) as f: + train_examples = json.load(f) + + prompt_corpus = [] + tokenized_corpus = [] + + for example in train_examples: + for lookup_str, prompt in claim2prompts(example): + entry = nltk.word_tokenize(lookup_str) + tokenized_corpus.append(entry) + prompt_corpus.append(prompt) + + return tokenized_corpus, prompt_corpus + + +def decorate_with_questions(claim, retrieve_evidence, top_k=10): # top_k=100 + # + reference_file = "averitec_code/data/train.json" + tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file) + prompt_bm25 = BM25Okapi(tokenized_corpus) + + # Define the bloom model: + accelerator = Accelerator() + accel_device = accelerator.device + device = "cuda:0" if torch.cuda.is_available() else "cpu" + tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1") + model = BloomForCausalLM.from_pretrained( + "bigscience/bloom-7b1", + device_map="auto", + torch_dtype=torch.bfloat16, + offload_folder="./offload" + ) + + # + tokenized_corpus = [] + all_data_corpus = [] + + for retri_evi in tqdm.tqdm(retrieve_evidence): + store_file = retri_evi[-1] + + with open(store_file, 'r') as f: + first = True + for line in f: + line = line.strip() + + if first: + first = False + location_url = line + continue + + if len(line) > 3: + entry = nltk.word_tokenize(line) + if (location_url, line) not in all_data_corpus: + tokenized_corpus.append(entry) + all_data_corpus.append((location_url, line)) + + if len(tokenized_corpus) == 0: + print("") + + bm25 = BM25Okapi(tokenized_corpus) + s = bm25.get_scores(nltk.word_tokenize(claim)) + top_n = np.argsort(s)[::-1][:top_k] + docs = [all_data_corpus[i] for i in top_n] + + generate_qa_pairs = [] + # Then, generate questions for those top 50: + for doc in tqdm.tqdm(docs): + # prompt_lookup_str = example["claim"] + " " + doc[1] + prompt_lookup_str = doc[1] + + prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str)) + prompt_n = 10 + prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n] + prompt_docs = [prompt_corpus[i] for i in prompt_top_n] + + claim_prompt = "Evidence: " + doc[1].replace("\n", " ") + "\nQuestion answered: " + prompt = "\n\n".join(prompt_docs + [claim_prompt]) + sentences = [prompt] + + inputs = tokenizer(sentences, padding=True, return_tensors="pt").to(device) + outputs = model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2, + early_stopping=True) + + tgt_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0] + # We are not allowed to generate more than 250 characters: + tgt_text = tgt_text[:250] + + qa_pair = [tgt_text.strip().split("?")[0].replace("\n", " ") + "?", doc[1].replace("\n", " "), doc[0]] + generate_qa_pairs.append(qa_pair) + + return generate_qa_pairs + + +def triple_to_string(x): + return " ".join([item.strip() for item in x]) + + +def rerank_questions(claim, bm25_qas, topk=3): + # + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2, + problem_type="single_label_classification") # Must specify single_label for some reason + best_checkpoint = 
"averitec_code/pretrained_models/bert_dual_encoder.ckpt" + device = "cuda:0" if torch.cuda.is_available() else "cpu" + trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer, model=bert_model).to( + device) + + # + strs_to_score = [] + values = [] + + for question, answer, source in bm25_qas: + str_to_score = triple_to_string([claim, question, answer]) + + strs_to_score.append(str_to_score) + values.append([question, answer, source]) + + if len(bm25_qas) > 0: + encoded_dict = tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True, + return_tensors="pt").to(device) + + input_ids = encoded_dict['input_ids'] + attention_masks = encoded_dict['attention_mask'] + + scores = torch.softmax(trained_model(input_ids, attention_mask=attention_masks).logits, axis=-1)[:, 1] + + top_n = torch.argsort(scores, descending=True)[:topk] + pass_through = [{"question": values[i][0], "answers": values[i][1], "source_url": values[i][2]} for i in top_n] + else: + pass_through = [] + + top3_qa_pairs = pass_through + + return top3_qa_pairs + + +def GoogleAPIretriever(query): + # ----- Generate QA pairs using AVeriTeC + top3_qa_pairs_path = "averitec_code/top3_qa_pairs1.json" + if not os.path.exists(top3_qa_pairs_path): + # step 1: generate questions for the query/claim using Bloom + generate_question = prompt_question_generation(query) + # step 2: retrieve evidence for the generated questions using Google API + retrieve_evidence = averitec_search(query, generate_question) + # step 3: generate QA pairs for each retrieved document + bm25_qa_pairs = decorate_with_questions(query, retrieve_evidence) + # step 4: rerank QA pairs + top3_qa_pairs = rerank_questions(query, bm25_qa_pairs) + else: + top3_qa_pairs = json.load(open(top3_qa_pairs_path, 'r')) + + # Add score to metadata + results = [] + for i, qa in enumerate(top3_qa_pairs): + metadata = dict() + + metadata['name'] = qa['question'] + metadata['url'] = qa['source_url'] + metadata['cached_source_url'] = qa['source_url'] + metadata['short_name'] = "Evidence {}".format(i + 1) + metadata['page_number'] = "" + metadata['query'] = qa['question'] + metadata['answer'] = qa['answers'] + metadata['page_content'] = "Question: " + qa['question'] + "
" + "Answer: " + qa['answers'] + page_content = f"""{metadata['page_content']}""" + results.append((metadata, page_content)) + + return results + + +# ----------GoogleAPIretriever--------- + +# ----------Wikipediaretriever--------- +def bm25_retriever(query, corpus, topk=3): + bm25 = BM25Okapi(corpus) + # + query_tokens = word_tokenize(query) + scores = bm25.get_scores(query_tokens) + top_n = np.argsort(scores)[::-1][:topk] + top_n_scores = [scores[i] for i in top_n] + + return top_n, top_n_scores + + +def bm25s_retriever(query, corpus, topk=3): + # optional: create a stemmer + stemmer = Stemmer.Stemmer("english") + # Tokenize the corpus and only keep the ids (faster and saves memory) + corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer) + # Create the BM25 model and index the corpus + retriever = bm25s.BM25() + retriever.index(corpus_tokens) + # Query the corpus + query_tokens = bm25s.tokenize(query, stemmer=stemmer) + # Get top-k results as a tuple of (doc ids, scores). Both are arrays of shape (n_queries, k) + results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=topk) + top_n = [corpus.index(res) for res in results[0]] + return top_n, scores + + +def find_evidence_from_wikipedia_dumps(claim): + # + doc = nlp(claim) + entities_in_claim = [str(ent).lower() for ent in doc.ents] + title2id = ranker.doc_dict[0] + wiki_intro, ent_list = [], [] + for ent in entities_in_claim: + if ent in title2id.keys(): + ids = title2id[ent] + introduction = doc_db.get_doc_intro(ids) + wiki_intro.append([ent, introduction]) + # fulltext = doc_db.get_doc_text(ids) + # evidence.append([ent, fulltext]) + ent_list.append(ent) + + if len(wiki_intro) < 5: + evidence_tfidf = process_topk(claim, title2id, ent_list, k=5) + wiki_intro.extend(evidence_tfidf) + + return wiki_intro, doc + + +def relevant_sentence_retrieval(query, wiki_intro, k): + # 1. Create corpus here + corpus, sentences = [], [] + titles = [] + for i, (title, intro) in enumerate(wiki_intro): + sents_in_intro = sent_tokenize(intro) + + for sent in sents_in_intro: + corpus.append(word_tokenize(sent)) + sentences.append(sent) + titles.append(title) + # + # ----- BM25 + bm25_top_n, bm25_top_n_scores = bm25_retriever(query, corpus, topk=k) + bm25_top_n_sents = [sentences[i] for i in bm25_top_n] + bm25_top_n_titles = [titles[i] for i in bm25_top_n] + + # ----- BM25s + # bm25s_top_n, bm25s_top_n_scores = bm25s_retriever(query, sentences, topk=k) # corpus->sentences + # bm25s_top_n_sents = [sentences[i] for i in bm25s_top_n] + # bm25s_top_n_titles = [titles[i] for i in bm25s_top_n] + + return bm25_top_n_sents, bm25_top_n_titles + + +def process_topk(query, title2id, ent_list, k=1): + doc_names, doc_scores = ranker.closest_docs(query, k) + evidence_tfidf = [] + + for _name in doc_names: + if _name not in ent_list and len(ent_list) < 5: + ent_list.append(_name) + idx = title2id[_name] + introduction = doc_db.get_doc_intro(idx) + evidence_tfidf.append([_name, introduction]) + # fulltext = doc_db.get_doc_text(idx) + # evidence_tfidf.append([_name,fulltext]) + + return evidence_tfidf + + +def WikipediaDumpsretriever(claim): + # + # 1. extract relevant wikipedia pages from wikipedia dumps + wiki_intro, doc = find_evidence_from_wikipedia_dumps(claim) + # wiki_intro = [['trump', "'''Trump''' most commonly refers to:\n* Donald Trump (born 1946), President of the United States from 2017 to 2021 \n* Trump (card games), any playing card given an ad-hoc high rank\n\n'''Trump''' may also refer to:"]] + + # 2. 
extract relevant sentences from extracted wikipedia pages + sents, titles = relevant_sentence_retrieval(claim, wiki_intro, k=3) + + # + results = [] + for i, (sent, title) in enumerate(zip(sents, titles)): + metadata = dict() + metadata['name'] = claim + metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split()) + metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split()) + metadata['short_name'] = "Evidence {}".format(i + 1) + metadata['page_number'] = "" + metadata['query'] = sent + metadata['title'] = title + metadata['evidence'] = sent + metadata['answer'] = "" + metadata['page_content'] = "Title: " + str(metadata['title']) + "
" + "Evidence: " + metadata[ + 'evidence'] + page_content = f"""{metadata['page_content']}""" + + results.append(Docs(metadata, page_content)) + + return results + +# ----------WikipediaAPIretriever--------- +def clean_str(p): + return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8") + + +def get_page_obs(page): + # find all paragraphs + paragraphs = page.split("\n") + paragraphs = [p.strip() for p in paragraphs if p.strip()] + + # # find all sentence + # sentences = [] + # for p in paragraphs: + # sentences += p.split('. ') + # sentences = [s.strip() + '.' for s in sentences if s.strip()] + # # return ' '.join(sentences[:5]) + # return ' '.join(sentences) + + return ' '.join(paragraphs[:5]) + + +def search_entity_wikipeida(entity): + find_evidence = [] + + page_py = wiki_wiki.page(entity) + if page_py.exists(): + introduction = page_py.summary + + find_evidence.append([str(entity), introduction]) + + return find_evidence + + +def search_step(entity): + ent_ = entity.replace(" ", "+") + search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}" + response_text = requests.get(search_url).text + soup = BeautifulSoup(response_text, features="html.parser") + result_divs = soup.find_all("div", {"class": "mw-search-result-heading"}) + + find_evidence = [] + + if result_divs: # mismatch + # If the wikipeida page of the entity is not exist, find similar wikipedia pages. + result_titles = [clean_str(div.get_text().strip()) for div in result_divs] + similar_titles = result_titles[:5] + + for _t in similar_titles: + if len(find_evidence) < 5: + _evi = search_step(_t) + find_evidence.extend(_evi) + else: + page = [p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")] + if any("may refer to:" in p for p in page): + _evi = search_step("[" + entity + "]") + find_evidence.extend(_evi) + else: + # page_py = wiki_wiki.page(entity) + # + # if page_py.exists(): + # introduction = page_py.summary + # else: + page_text = "" + for p in page: + if len(p.split(" ")) > 2: + page_text += clean_str(p) + if not p.endswith("\n"): + page_text += "\n" + introduction = get_page_obs(page_text) + + find_evidence.append([entity, introduction]) + + return find_evidence + + +def find_similar_wikipedia(entity, relevant_wikipages): + # If the relevant wikipeida page of the entity is less than 5, find similar wikipedia pages. 
+ ent_ = entity.replace(" ", "+") + search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}&title=Special:Search&profile=advanced&fulltext=1&ns0=1" + response_text = requests.get(search_url).text + soup = BeautifulSoup(response_text, features="html.parser") + result_divs = soup.find_all("div", {"class": "mw-search-result-heading"}) + + if result_divs: + result_titles = [clean_str(div.get_text().strip()) for div in result_divs] + similar_titles = result_titles[:5] + + saved_titles = [ent[0] for ent in relevant_wikipages] if relevant_wikipages else relevant_wikipages + for _t in similar_titles: + if _t not in saved_titles and len(relevant_wikipages) < 5: + _evi = search_entity_wikipeida(_t) + # _evi = search_step(_t) + relevant_wikipages.extend(_evi) + + return relevant_wikipages + + +def find_evidence_from_wikipedia(claim): + # + doc = nlp(claim) + # + wikipedia_page = [] + for ent in doc.ents: + relevant_wikipages = search_entity_wikipeida(ent) + + if len(relevant_wikipages) < 5: + relevant_wikipages = find_similar_wikipedia(str(ent), relevant_wikipages) + + wikipedia_page.extend(relevant_wikipages) + + return wikipedia_page + + +def relevant_wikipedia_API_retriever(claim): + # + doc = nlp(claim) + + wiki_intro = [] + for ent in doc.ents: + page_py = wiki_wiki.page(ent) + + if page_py.exists(): + introduction = page_py.summary + else: + introduction = "No documents found." + + wiki_intro.append([str(ent), introduction]) + + return wiki_intro, doc + + +def Wikipediaretriever(claim, sources): + # + # 1. extract relevant wikipedia pages from wikipedia dumps + if "Dump" in sources: + wikipedia_page = find_evidence_from_wikipedia_dumps(claim) + else: + wikipedia_page = find_evidence_from_wikipedia(claim) + # wiki_intro, doc = relevant_wikipedia_API_retriever(claim) + + # 2. extract relevant sentences from extracted wikipedia pages + sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3) + + # + results = [] + for i, (sent, title) in enumerate(zip(sents, titles)): + metadata = dict() + metadata['name'] = claim + metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split()) + metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title) + metadata['short_name'] = "Evidence {}".format(i + 1) + metadata['page_number'] = "" + metadata['query'] = sent + metadata['title'] = title + metadata['evidence'] = sent + metadata['answer'] = "" + metadata['page_content'] = "Title: " + str(metadata['title']) + "
" + "Evidence: " + metadata['evidence'] + page_content = f"""{metadata['page_content']}""" + + results.append(Docs(metadata, page_content)) + + return results + + +def log_on_azure(file, logs, azure_share_client): + logs = json.dumps(logs) + file_client = azure_share_client.get_file_client(file) + file_client.upload_file(logs) + + +def chat(claim, history, sources): + evidence = [] + # if 'Google' in sources: + # evidence = GoogleAPIretriever(query) + + # if 'WikiPediaDumps' in sources: + # evidence = WikipediaDumpsretriever(query) + + if 'WikiPedia' in sources: + evidence = Wikipediaretriever(claim, sources) + + answer_set, answer_output = QAprediction(claim, evidence, sources) + + docs_html = "" + if len(evidence) > 0: + docs_html = [] + for i, evi in enumerate(evidence, 1): + docs_html.append(make_html_source(evi, i)) + docs_html = "".join(docs_html) + else: + print("No documents found") + + url_of_evidence = "" + output_language = "English" + output_query = claim + history[-1] = (claim, answer_set) + history = [tuple(x) for x in history] + + ############################################################ + evi_list = [] + for evi in evidence: + title_str = evi.metadata['title'] + evi_str = evi.metadata['evidence'] + evi_list.append([title_str, evi_str]) + + try: + # Log answer on Azure Blob Storage + # IF AZURE_ISSAVE=TRUE, save the logs into the Azure share client. + if bool(os.environ["AZURE_ISSAVE"]): + timestamp = str(datetime.now().timestamp()) + # timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + file = timestamp + ".json" + logs = { + "user_id": str(user_id), + "claim": claim, + "sources": sources, + "evidence": evi_list, + "url": url_of_evidence, + "answer": answer_output, + "time": timestamp, + } + log_on_azure(file, logs, azure_share_client) + except Exception as e: + print(f"Error logging on Azure Blob Storage: {e}") + raise gr.Error( + f"AVeriTeC Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)") + ########## + + return history, docs_html, output_query, output_language + + +def main(): + init_prompt = """ + Hello, I am a fact-checking assistant designed to help you find appropriate evidence to predict the veracity of claims. + + What do you want to fact-check? 
+ """ + + with gr.Blocks(title="AVeriTeC fact-checker", css="style.css", theme=theme, elem_id="main-component") as demo: + with gr.Tab("AVeriTeC"): + with gr.Row(elem_id="chatbot-row"): + with gr.Column(scale=2): + chatbot = gr.Chatbot( + value=[(None, init_prompt)], + show_copy_button=True, show_label=False, elem_id="chatbot", layout="panel", + avatar_images=(None, "assets/averitec.png") + ) # avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"), + + with gr.Row(elem_id="input-message"): + textbox = gr.Textbox(placeholder="Ask me what claim do you want to check!", show_label=False, + scale=7, lines=1, interactive=True, elem_id="input-textbox") + # submit = gr.Button("",elem_id = "submit-button",scale = 1,interactive = True,icon = "https://static-00.iconduck.com/assets.00/settings-icon-2048x2046-cw28eevx.png") + + with gr.Column(scale=1, variant="panel", elem_id="right-panel"): + with gr.Tabs() as tabs: + with gr.TabItem("Examples", elem_id="tab-examples", id=0): + examples_hidden = gr.Textbox(visible=False) + first_key = list(CLAIMS_Type.keys())[0] + dropdown_samples = gr.Dropdown(CLAIMS_Type.keys(), value=first_key, interactive=True, + show_label=True, + label="Select claim type", + elem_id="dropdown-samples") + + samples = [] + for i, key in enumerate(CLAIMS_Type.keys()): + examples_visible = True if i == 0 else False + + with gr.Row(visible=examples_visible) as group_examples: + examples_questions = gr.Examples( + CLAIMS_Type[key], + [examples_hidden], + examples_per_page=8, + run_on_click=False, + elem_id=f"examples{i}", + api_name=f"examples{i}", + # label = "Click on the example question or enter your own", + # cache_examples=True, + ) + + samples.append(group_examples) + + with gr.Tab("Sources", elem_id="tab-citations", id=1): + sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox") + docs_textbox = gr.State("") + + with gr.Tab("Configuration", elem_id="tab-config", id=2): + gr.Markdown("Reminder: We currently only support fact-checking in English!") + + # dropdown_sources = gr.Radio( + # ["AVeriTeC", "WikiPediaDumps", "Google", "WikiPediaAPI"], + # label="Select source", + # value="WikiPediaAPI", + # interactive=True, + # ) + + dropdown_sources = gr.Radio( + ["Google", "WikiPedia"], + label="Select source", + value="WikiPedia", + interactive=True, + ) + + dropdown_retriever = gr.Dropdown( + ["BM25", "BM25s"], + label="Select evidence retriever", + multiselect=False, + value="BM25", + interactive=True, + ) + + output_query = gr.Textbox(label="Query used for retrieval", show_label=True, + elem_id="reformulated-query", lines=2, interactive=False) + output_language = gr.Textbox(label="Language", show_label=True, elem_id="language", lines=1, + interactive=False) + + with gr.Tab("About", elem_classes="max-height other-tabs"): + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown("See more info at [https://fever.ai/task.html](https://fever.ai/task.html)") + + def start_chat(query, history): + history = history + [(query, None)] + history = [tuple(x) for x in history] + return (gr.update(interactive=False), gr.update(selected=1), history) + + def finish_chat(): + return (gr.update(interactive=True, value="")) + + (textbox + .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox") + .then(chat, [textbox, chatbot, dropdown_sources], + [chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_textbox") + .then(finish_chat, None, [textbox], api_name="finish_chat_textbox") + ) + 
+ (examples_hidden + .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, + api_name="start_chat_examples") + .then(chat, [examples_hidden, chatbot, dropdown_sources], + [chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_examples") + .then(finish_chat, None, [textbox], api_name="finish_chat_examples") + ) + + def change_sample_questions(key): + index = list(CLAIMS_Type.keys()).index(key) + visible_bools = [False] * len(samples) + visible_bools[index] = True + return [gr.update(visible=visible_bools[i]) for i in range(len(samples))] + + dropdown_samples.change(change_sample_questions, dropdown_samples, samples) + demo.queue() + + demo.launch(share=True) + + +if __name__ == "__main__": + main() diff --git a/drqa/__init__.py b/drqa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7d281e4860c5888761bac0bf815f1981a75e770b --- /dev/null +++ b/drqa/__init__.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import sys +from pathlib import PosixPath + +if sys.version_info < (3, 5): + raise RuntimeError('DrQA supports Python 3.5 or higher.') + +DATA_DIR = ( + os.getenv('DRQA_DATA') or + os.path.join(PosixPath(__file__).absolute().parents[1].as_posix(), 'data') +) + +from . import tokenizers +from . import reader +from . import retriever +from . import pipeline diff --git a/drqa/__pycache__/__init__.cpython-38.pyc b/drqa/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8064258e6b8fa1ab0513dab86c64a68218acffcc Binary files /dev/null and b/drqa/__pycache__/__init__.cpython-38.pyc differ diff --git a/drqa/pipeline/__init__.py b/drqa/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ddbf9d65b91676ff9e974d92005207c35997e29 --- /dev/null +++ b/drqa/pipeline/__init__.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +from ..tokenizers import CoreNLPTokenizer +from ..retriever import TfidfDocRanker +from ..retriever import DocDB +from .. 
import DATA_DIR + +DEFAULTS = { + 'tokenizer': CoreNLPTokenizer, + 'ranker': TfidfDocRanker, + 'db': DocDB, + 'reader_model': os.path.join(DATA_DIR, 'reader/multitask.mdl'), +} + + +def set_default(key, value): + global DEFAULTS + DEFAULTS[key] = value + + +from .drqa import DrQA diff --git a/drqa/pipeline/__pycache__/__init__.cpython-38.pyc b/drqa/pipeline/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57f4e5d10d8feabc451f2e799b14f1dfb806b6a8 Binary files /dev/null and b/drqa/pipeline/__pycache__/__init__.cpython-38.pyc differ diff --git a/drqa/pipeline/__pycache__/drqa.cpython-38.pyc b/drqa/pipeline/__pycache__/drqa.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69bb524d30ce019d35b6bb097c0a0fa77b3bbf09 Binary files /dev/null and b/drqa/pipeline/__pycache__/drqa.cpython-38.pyc differ diff --git a/drqa/pipeline/drqa.py b/drqa/pipeline/drqa.py new file mode 100644 index 0000000000000000000000000000000000000000..7830e9ee9a46a9d8f0f196c2a269becf1b158235 --- /dev/null +++ b/drqa/pipeline/drqa.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Full DrQA pipeline.""" + +import torch +import regex +import heapq +import math +import time +import logging + +from multiprocessing import Pool as ProcessPool +from multiprocessing.util import Finalize + +from ..reader.vector import batchify +from ..reader.data import ReaderDataset, SortedBatchSampler +from .. import reader +from .. import tokenizers +from . import DEFAULTS + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------------------ +# Multiprocessing functions to fetch and tokenize text +# ------------------------------------------------------------------------------ + +PROCESS_TOK = None +PROCESS_DB = None +PROCESS_CANDS = None + + +def init(tokenizer_class, tokenizer_opts, db_class, db_opts, candidates=None): + global PROCESS_TOK, PROCESS_DB, PROCESS_CANDS + PROCESS_TOK = tokenizer_class(**tokenizer_opts) + Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100) + PROCESS_DB = db_class(**db_opts) + Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100) + PROCESS_CANDS = candidates + + +def fetch_text(doc_id): + global PROCESS_DB + return PROCESS_DB.get_doc_text(doc_id) + + +def tokenize_text(text): + global PROCESS_TOK + return PROCESS_TOK.tokenize(text) + + +# ------------------------------------------------------------------------------ +# Main DrQA pipeline +# ------------------------------------------------------------------------------ + + +class DrQA(object): + # Target size for squashing short paragraphs together. + # 0 = read every paragraph independently + # infty = read all paragraphs together + GROUP_LENGTH = 0 + + def __init__( + self, + reader_model=None, + embedding_file=None, + tokenizer=None, + fixed_candidates=None, + batch_size=128, + cuda=True, + data_parallel=False, + max_loaders=5, + num_workers=None, + db_config=None, + ranker_config=None + ): + """Initialize the pipeline. + + Args: + reader_model: model file from which to load the DocReader. + embedding_file: if given, will expand DocReader dictionary to use + all available pretrained embeddings. + tokenizer: string option to specify tokenizer used on docs. 
+ fixed_candidates: if given, all predictions will be constrated to + the set of candidates contained in the file. One entry per line. + batch_size: batch size when processing paragraphs. + cuda: whether to use the gpu. + data_parallel: whether to use multile gpus. + max_loaders: max number of async data loading workers when reading. + (default is fine). + num_workers: number of parallel CPU processes to use for tokenizing + and post processing resuls. + db_config: config for doc db. + ranker_config: config for ranker. + """ + self.batch_size = batch_size + self.max_loaders = max_loaders + self.fixed_candidates = fixed_candidates is not None + self.cuda = cuda + + logger.info('Initializing document ranker...') + ranker_config = ranker_config or {} + ranker_class = ranker_config.get('class', DEFAULTS['ranker']) + ranker_opts = ranker_config.get('options', {}) + self.ranker = ranker_class(**ranker_opts) + + logger.info('Initializing document reader...') + reader_model = reader_model or DEFAULTS['reader_model'] + self.reader = reader.DocReader.load(reader_model, normalize=False) + if embedding_file: + logger.info('Expanding dictionary...') + words = reader.utils.index_embedding_words(embedding_file) + added = self.reader.expand_dictionary(words) + self.reader.load_embeddings(added, embedding_file) + if cuda: + self.reader.cuda() + if data_parallel: + self.reader.parallelize() + + if not tokenizer: + tok_class = DEFAULTS['tokenizer'] + else: + tok_class = tokenizers.get_class(tokenizer) + annotators = tokenizers.get_annotators_for_model(self.reader) + tok_opts = {'annotators': annotators} + + # ElasticSearch is also used as backend if used as ranker + if hasattr(self.ranker, 'es'): + db_config = ranker_config + db_class = ranker_class + db_opts = ranker_opts + else: + db_config = db_config or {} + db_class = db_config.get('class', DEFAULTS['db']) + db_opts = db_config.get('options', {}) + + logger.info('Initializing tokenizers and document retrievers...') + self.num_workers = num_workers + self.processes = ProcessPool( + num_workers, + initializer=init, + initargs=(tok_class, tok_opts, db_class, db_opts, fixed_candidates) + ) + + def _split_doc(self, doc): + """Given a doc, split it into chunks (by paragraph).""" + curr = [] + curr_len = 0 + for split in regex.split(r'\n+', doc): + split = split.strip() + if len(split) == 0: + continue + # Maybe group paragraphs together until we hit a length limit + if len(curr) > 0 and curr_len + len(split) > self.GROUP_LENGTH: + yield ' '.join(curr) + curr = [] + curr_len = 0 + curr.append(split) + curr_len += len(split) + if len(curr) > 0: + yield ' '.join(curr) + + def _get_loader(self, data, num_loaders): + """Return a pytorch data iterator for provided examples.""" + dataset = ReaderDataset(data, self.reader) + sampler = SortedBatchSampler( + dataset.lengths(), + self.batch_size, + shuffle=False + ) + loader = torch.utils.data.DataLoader( + dataset, + batch_size=self.batch_size, + sampler=sampler, + num_workers=num_loaders, + collate_fn=batchify, + pin_memory=self.cuda, + ) + return loader + + def process(self, query, candidates=None, top_n=1, n_docs=5, + return_context=False): + """Run a single query.""" + predictions = self.process_batch( + [query], [candidates] if candidates else None, + top_n, n_docs, return_context + ) + return predictions[0] + + def process_batch(self, queries, candidates=None, top_n=1, n_docs=5, + return_context=False): + """Run a batch of queries (more efficient).""" + t0 = time.time() + logger.info('Processing %d queries...' 
% len(queries)) + logger.info('Retrieving top %d docs...' % n_docs) + + # Rank documents for queries. + if len(queries) == 1: + ranked = [self.ranker.closest_docs(queries[0], k=n_docs)] + else: + ranked = self.ranker.batch_closest_docs( + queries, k=n_docs, num_workers=self.num_workers + ) + all_docids, all_doc_scores = zip(*ranked) + + # Flatten document ids and retrieve text from database. + # We remove duplicates for processing efficiency. + flat_docids = list({d for docids in all_docids for d in docids}) + did2didx = {did: didx for didx, did in enumerate(flat_docids)} + doc_texts = self.processes.map(fetch_text, flat_docids) + + # Split and flatten documents. Maintain a mapping from doc (index in + # flat list) to split (index in flat list). + flat_splits = [] + didx2sidx = [] + for text in doc_texts: + splits = self._split_doc(text) + didx2sidx.append([len(flat_splits), -1]) + for split in splits: + flat_splits.append(split) + didx2sidx[-1][1] = len(flat_splits) + + # Push through the tokenizers as fast as possible. + q_tokens = self.processes.map_async(tokenize_text, queries) + s_tokens = self.processes.map_async(tokenize_text, flat_splits) + q_tokens = q_tokens.get() + s_tokens = s_tokens.get() + + # Group into structured example inputs. Examples' ids represent + # mappings to their question, document, and split ids. + examples = [] + for qidx in range(len(queries)): + for rel_didx, did in enumerate(all_docids[qidx]): + start, end = didx2sidx[did2didx[did]] + for sidx in range(start, end): + if (len(q_tokens[qidx].words()) > 0 and + len(s_tokens[sidx].words()) > 0): + examples.append({ + 'id': (qidx, rel_didx, sidx), + 'question': q_tokens[qidx].words(), + 'qlemma': q_tokens[qidx].lemmas(), + 'document': s_tokens[sidx].words(), + 'lemma': s_tokens[sidx].lemmas(), + 'pos': s_tokens[sidx].pos(), + 'ner': s_tokens[sidx].entities(), + }) + + logger.info('Reading %d paragraphs...' % len(examples)) + + # Push all examples through the document reader. + # We decode argmax start/end indices asychronously on CPU. + result_handles = [] + num_loaders = min(self.max_loaders, math.floor(len(examples) / 1e3)) + for batch in self._get_loader(examples, num_loaders): + if candidates or self.fixed_candidates: + batch_cands = [] + for ex_id in batch[-1]: + batch_cands.append({ + 'input': s_tokens[ex_id[2]], + 'cands': candidates[ex_id[0]] if candidates else None + }) + handle = self.reader.predict( + batch, batch_cands, async_pool=self.processes + ) + else: + handle = self.reader.predict(batch, async_pool=self.processes) + result_handles.append((handle, batch[-1], batch[0].size(0))) + + # Iterate through the predictions, and maintain priority queues for + # top scored answers for each question in the batch. + queues = [[] for _ in range(len(queries))] + for result, ex_ids, batch_size in result_handles: + s, e, score = result.get() + for i in range(batch_size): + # We take the top prediction per split. + if len(score[i]) > 0: + item = (score[i][0], ex_ids[i], s[i][0], e[i][0]) + queue = queues[ex_ids[i][0]] + if len(queue) < top_n: + heapq.heappush(queue, item) + else: + heapq.heappushpop(queue, item) + + # Arrange final top prediction data. 
+ all_predictions = [] + for queue in queues: + predictions = [] + while len(queue) > 0: + score, (qidx, rel_didx, sidx), s, e = heapq.heappop(queue) + prediction = { + 'doc_id': all_docids[qidx][rel_didx], + 'span': s_tokens[sidx].slice(s, e + 1).untokenize(), + 'doc_score': float(all_doc_scores[qidx][rel_didx]), + 'span_score': float(score), + } + if return_context: + prediction['context'] = { + 'text': s_tokens[sidx].untokenize(), + 'start': s_tokens[sidx].offsets()[s][0], + 'end': s_tokens[sidx].offsets()[e][1], + } + predictions.append(prediction) + all_predictions.append(predictions[-1::-1]) + + logger.info('Processed %d queries in %.4f (s)' % + (len(queries), time.time() - t0)) + + return all_predictions diff --git a/drqa/reader/__init__.py b/drqa/reader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c6cc94bacc74e65061800b8ee00fd4eb78ef56 --- /dev/null +++ b/drqa/reader/__init__.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +from ..tokenizers import CoreNLPTokenizer +from .. import DATA_DIR + + +DEFAULTS = { + 'tokenizer': CoreNLPTokenizer, + 'model': os.path.join(DATA_DIR, 'reader/single.mdl'), +} + + +def set_default(key, value): + global DEFAULTS + DEFAULTS[key] = value + +from .model import DocReader +from .predictor import Predictor +from . import config +from . import vector +from . import data +from . import utils diff --git a/drqa/reader/__pycache__/__init__.cpython-38.pyc b/drqa/reader/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7538e4dd229b00dab82f8a5a50c03b5e4d270685 Binary files /dev/null and b/drqa/reader/__pycache__/__init__.cpython-38.pyc differ diff --git a/drqa/reader/__pycache__/config.cpython-38.pyc b/drqa/reader/__pycache__/config.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a9a1f0cae7bd1cfe8c72bb4d72e6af6df941147 Binary files /dev/null and b/drqa/reader/__pycache__/config.cpython-38.pyc differ diff --git a/drqa/reader/__pycache__/data.cpython-38.pyc b/drqa/reader/__pycache__/data.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b72e5a48fd606cbc29aae9b7b341edc47e27c53 Binary files /dev/null and b/drqa/reader/__pycache__/data.cpython-38.pyc differ diff --git a/drqa/reader/__pycache__/layers.cpython-38.pyc b/drqa/reader/__pycache__/layers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..457836a1d2c2823c327952bbee8b64f944c05c73 Binary files /dev/null and b/drqa/reader/__pycache__/layers.cpython-38.pyc differ diff --git a/drqa/reader/__pycache__/model.cpython-38.pyc b/drqa/reader/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e486a981ffe2eb326ed03ec54358912ca0db51f9 Binary files /dev/null and b/drqa/reader/__pycache__/model.cpython-38.pyc differ diff --git a/drqa/reader/__pycache__/predictor.cpython-38.pyc b/drqa/reader/__pycache__/predictor.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..757189fb0cddd9a58600c57421a5088943953752 Binary files /dev/null and b/drqa/reader/__pycache__/predictor.cpython-38.pyc differ diff --git a/drqa/reader/__pycache__/rnn_reader.cpython-38.pyc b/drqa/reader/__pycache__/rnn_reader.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..5377e66e076949b6ac6b5b01d70151789de24387 Binary files /dev/null and b/drqa/reader/__pycache__/rnn_reader.cpython-38.pyc differ diff --git a/drqa/reader/__pycache__/utils.cpython-38.pyc b/drqa/reader/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf4748aa1b5c99b7d9f4483d0db9cb88c6d94b45 Binary files /dev/null and b/drqa/reader/__pycache__/utils.cpython-38.pyc differ diff --git a/drqa/reader/__pycache__/vector.cpython-38.pyc b/drqa/reader/__pycache__/vector.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9e2bd6baafc0a3ad43e1b46e70f0c6aaa53d113 Binary files /dev/null and b/drqa/reader/__pycache__/vector.cpython-38.pyc differ diff --git a/drqa/reader/config.py b/drqa/reader/config.py new file mode 100644 index 0000000000000000000000000000000000000000..4d49b191f99b93da233da5db0bd4b6287aabf6f9 --- /dev/null +++ b/drqa/reader/config.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Model architecture/optimization options for DrQA document reader.""" + +import argparse +import logging + +logger = logging.getLogger(__name__) + +# Index of arguments concerning the core model architecture +MODEL_ARCHITECTURE = { + 'model_type', 'embedding_dim', 'hidden_size', 'doc_layers', + 'question_layers', 'rnn_type', 'concat_rnn_layers', 'question_merge', + 'use_qemb', 'use_in_question', 'use_pos', 'use_ner', 'use_lemma', 'use_tf' +} + +# Index of arguments concerning the model optimizer/training +MODEL_OPTIMIZER = { + 'fix_embeddings', 'optimizer', 'learning_rate', 'momentum', 'weight_decay', + 'rnn_padding', 'dropout_rnn', 'dropout_rnn_output', 'dropout_emb', + 'max_len', 'grad_clipping', 'tune_partial' +} + + +def str2bool(v): + return v.lower() in ('yes', 'true', 't', '1', 'y') + + +def add_model_args(parser): + parser.register('type', 'bool', str2bool) + + # Model architecture + model = parser.add_argument_group('DrQA Reader Model Architecture') + model.add_argument('--model-type', type=str, default='rnn', + help='Model architecture type') + model.add_argument('--embedding-dim', type=int, default=300, + help='Embedding size if embedding_file is not given') + model.add_argument('--hidden-size', type=int, default=128, + help='Hidden size of RNN units') + model.add_argument('--doc-layers', type=int, default=3, + help='Number of encoding layers for document') + model.add_argument('--question-layers', type=int, default=3, + help='Number of encoding layers for question') + model.add_argument('--rnn-type', type=str, default='lstm', + help='RNN type: LSTM, GRU, or RNN') + + # Model specific details + detail = parser.add_argument_group('DrQA Reader Model Details') + detail.add_argument('--concat-rnn-layers', type='bool', default=True, + help='Combine hidden states from each encoding layer') + detail.add_argument('--question-merge', type=str, default='self_attn', + help='The way of computing the question representation') + detail.add_argument('--use-qemb', type='bool', default=True, + help='Whether to use weighted question embeddings') + detail.add_argument('--use-in-question', type='bool', default=True, + help='Whether to use in_question_* features') + detail.add_argument('--use-pos', type='bool', default=True, + help='Whether to use pos features') + detail.add_argument('--use-ner', type='bool', 
default=True, + help='Whether to use ner features') + detail.add_argument('--use-lemma', type='bool', default=True, + help='Whether to use lemma features') + detail.add_argument('--use-tf', type='bool', default=True, + help='Whether to use term frequency features') + + # Optimization details + optim = parser.add_argument_group('DrQA Reader Optimization') + optim.add_argument('--dropout-emb', type=float, default=0.4, + help='Dropout rate for word embeddings') + optim.add_argument('--dropout-rnn', type=float, default=0.4, + help='Dropout rate for RNN states') + optim.add_argument('--dropout-rnn-output', type='bool', default=True, + help='Whether to dropout the RNN output') + optim.add_argument('--optimizer', type=str, default='adamax', + help='Optimizer: sgd or adamax') + optim.add_argument('--learning-rate', type=float, default=0.1, + help='Learning rate for SGD only') + optim.add_argument('--grad-clipping', type=float, default=10, + help='Gradient clipping') + optim.add_argument('--weight-decay', type=float, default=0, + help='Weight decay factor') + optim.add_argument('--momentum', type=float, default=0, + help='Momentum factor') + optim.add_argument('--fix-embeddings', type='bool', default=True, + help='Keep word embeddings fixed (use pretrained)') + optim.add_argument('--tune-partial', type=int, default=0, + help='Backprop through only the top N question words') + optim.add_argument('--rnn-padding', type='bool', default=False, + help='Explicitly account for padding in RNN encoding') + optim.add_argument('--max-len', type=int, default=15, + help='The max span allowed during decoding') + + +def get_model_args(args): + """Filter args for model ones. + + From a args Namespace, return a new Namespace with *only* the args specific + to the model architecture or optimization. (i.e. the ones defined here.) + """ + global MODEL_ARCHITECTURE, MODEL_OPTIMIZER + required_args = MODEL_ARCHITECTURE | MODEL_OPTIMIZER + arg_values = {k: v for k, v in vars(args).items() if k in required_args} + return argparse.Namespace(**arg_values) + + +def override_model_args(old_args, new_args): + """Set args to new parameters. + + Decide which model args to keep and which to override when resolving a set + of saved args and new args. + + We keep the new optimation, but leave the model architecture alone. + """ + global MODEL_OPTIMIZER + old_args, new_args = vars(old_args), vars(new_args) + for k in old_args.keys(): + if k in new_args and old_args[k] != new_args[k]: + if k in MODEL_OPTIMIZER: + logger.info('Overriding saved %s: %s --> %s' % + (k, old_args[k], new_args[k])) + old_args[k] = new_args[k] + else: + logger.info('Keeping saved %s: %s' % (k, old_args[k])) + return argparse.Namespace(**old_args) diff --git a/drqa/reader/data.py b/drqa/reader/data.py new file mode 100644 index 0000000000000000000000000000000000000000..a9e0215814add331aec33c1ca757020765fbdef5 --- /dev/null +++ b/drqa/reader/data.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Data processing/loading helpers.""" + +import numpy as np +import logging +import unicodedata + +from torch.utils.data import Dataset +from torch.utils.data.sampler import Sampler +from .vector import vectorize + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------------------ +# Dictionary class for tokens. 
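+# Token ids 0 and 1 are reserved for the null (padding) and unknown tokens.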
+# ------------------------------------------------------------------------------ + + +class Dictionary(object): + NULL = '' + UNK = '' + START = 2 + + @staticmethod + def normalize(token): + return unicodedata.normalize('NFD', token) + + def __init__(self): + self.tok2ind = {self.NULL: 0, self.UNK: 1} + self.ind2tok = {0: self.NULL, 1: self.UNK} + + def __len__(self): + return len(self.tok2ind) + + def __iter__(self): + return iter(self.tok2ind) + + def __contains__(self, key): + if type(key) == int: + return key in self.ind2tok + elif type(key) == str: + return self.normalize(key) in self.tok2ind + + def __getitem__(self, key): + if type(key) == int: + return self.ind2tok.get(key, self.UNK) + if type(key) == str: + return self.tok2ind.get(self.normalize(key), + self.tok2ind.get(self.UNK)) + + def __setitem__(self, key, item): + if type(key) == int and type(item) == str: + self.ind2tok[key] = item + elif type(key) == str and type(item) == int: + self.tok2ind[key] = item + else: + raise RuntimeError('Invalid (key, item) types.') + + def add(self, token): + token = self.normalize(token) + if token not in self.tok2ind: + index = len(self.tok2ind) + self.tok2ind[token] = index + self.ind2tok[index] = token + + def tokens(self): + """Get dictionary tokens. + + Return all the words indexed by this dictionary, except for special + tokens. + """ + tokens = [k for k in self.tok2ind.keys() + if k not in {'', ''}] + return tokens + + +# ------------------------------------------------------------------------------ +# PyTorch dataset class for SQuAD (and SQuAD-like) data. +# ------------------------------------------------------------------------------ + + +class ReaderDataset(Dataset): + + def __init__(self, examples, model, single_answer=False): + self.model = model + self.examples = examples + self.single_answer = single_answer + + def __len__(self): + return len(self.examples) + + def __getitem__(self, index): + return vectorize(self.examples[index], self.model, self.single_answer) + + def lengths(self): + return [(len(ex['document']), len(ex['question'])) + for ex in self.examples] + + +# ------------------------------------------------------------------------------ +# PyTorch sampler returning batched of sorted lengths (by doc and question). +# ------------------------------------------------------------------------------ + + +class SortedBatchSampler(Sampler): + + def __init__(self, lengths, batch_size, shuffle=True): + self.lengths = lengths + self.batch_size = batch_size + self.shuffle = shuffle + + def __iter__(self): + lengths = np.array( + [(-l[0], -l[1], np.random.random()) for l in self.lengths], + dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float_)] + ) + indices = np.argsort(lengths, order=('l1', 'l2', 'rand')) + batches = [indices[i:i + self.batch_size] + for i in range(0, len(indices), self.batch_size)] + if self.shuffle: + np.random.shuffle(batches) + return iter([i for batch in batches for i in batch]) + + def __len__(self): + return len(self.lengths) diff --git a/drqa/reader/layers.py b/drqa/reader/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..71a967d806a04160ef6bce9741001c33eecca62e --- /dev/null +++ b/drqa/reader/layers.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
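+# Encoder and attention building blocks used by the RnnDocReader network.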
+"""Definitions of model layers/NN modules""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# ------------------------------------------------------------------------------ +# Modules +# ------------------------------------------------------------------------------ + + +class StackedBRNN(nn.Module): + """Stacked Bi-directional RNNs. + + Differs from standard PyTorch library in that it has the option to save + and concat the hidden states between layers. (i.e. the output hidden size + for each sequence input is num_layers * hidden_size). + """ + + def __init__(self, input_size, hidden_size, num_layers, + dropout_rate=0, dropout_output=False, rnn_type=nn.LSTM, + concat_layers=False, padding=False): + super(StackedBRNN, self).__init__() + self.padding = padding + self.dropout_output = dropout_output + self.dropout_rate = dropout_rate + self.num_layers = num_layers + self.concat_layers = concat_layers + self.rnns = nn.ModuleList() + for i in range(num_layers): + input_size = input_size if i == 0 else 2 * hidden_size + self.rnns.append(rnn_type(input_size, hidden_size, + num_layers=1, + bidirectional=True)) + + def forward(self, x, x_mask): + """Encode either padded or non-padded sequences. + + Can choose to either handle or ignore variable length sequences. + Always handle padding in eval. + + Args: + x: batch * len * hdim + x_mask: batch * len (1 for padding, 0 for true) + Output: + x_encoded: batch * len * hdim_encoded + """ + if x_mask.data.sum() == 0: + # No padding necessary. + output = self._forward_unpadded(x, x_mask) + elif self.padding or not self.training: + # Pad if we care or if its during eval. + output = self._forward_padded(x, x_mask) + else: + # We don't care. + output = self._forward_unpadded(x, x_mask) + + return output.contiguous() + + def _forward_unpadded(self, x, x_mask): + """Faster encoding that ignores any padding.""" + # Transpose batch and sequence dims + x = x.transpose(0, 1) + + # Encode all layers + outputs = [x] + for i in range(self.num_layers): + rnn_input = outputs[-1] + + # Apply dropout to hidden input + if self.dropout_rate > 0: + rnn_input = F.dropout(rnn_input, + p=self.dropout_rate, + training=self.training) + # Forward + rnn_output = self.rnns[i](rnn_input)[0] + outputs.append(rnn_output) + + # Concat hidden layers + if self.concat_layers: + output = torch.cat(outputs[1:], 2) + else: + output = outputs[-1] + + # Transpose back + output = output.transpose(0, 1) + + # Dropout on output layer + if self.dropout_output and self.dropout_rate > 0: + output = F.dropout(output, + p=self.dropout_rate, + training=self.training) + return output + + def _forward_padded(self, x, x_mask): + """Slower (significantly), but more precise, encoding that handles + padding. 
+ """ + # Compute sorted sequence lengths + lengths = x_mask.data.eq(0).long().sum(1).squeeze() + _, idx_sort = torch.sort(lengths, dim=0, descending=True) + _, idx_unsort = torch.sort(idx_sort, dim=0) + lengths = list(lengths[idx_sort]) + + # Sort x + x = x.index_select(0, idx_sort) + + # Transpose batch and sequence dims + x = x.transpose(0, 1) + + # Pack it up + rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths) + + # Encode all layers + outputs = [rnn_input] + for i in range(self.num_layers): + rnn_input = outputs[-1] + + # Apply dropout to input + if self.dropout_rate > 0: + dropout_input = F.dropout(rnn_input.data, + p=self.dropout_rate, + training=self.training) + rnn_input = nn.utils.rnn.PackedSequence(dropout_input, + rnn_input.batch_sizes) + outputs.append(self.rnns[i](rnn_input)[0]) + + # Unpack everything + for i, o in enumerate(outputs[1:], 1): + outputs[i] = nn.utils.rnn.pad_packed_sequence(o)[0] + + # Concat hidden layers or take final + if self.concat_layers: + output = torch.cat(outputs[1:], 2) + else: + output = outputs[-1] + + # Transpose and unsort + output = output.transpose(0, 1) + output = output.index_select(0, idx_unsort) + + # Pad up to original batch sequence length + if output.size(1) != x_mask.size(1): + padding = torch.zeros(output.size(0), + x_mask.size(1) - output.size(1), + output.size(2)).type(output.data.type()) + output = torch.cat([output, padding], 1) + + # Dropout on output layer + if self.dropout_output and self.dropout_rate > 0: + output = F.dropout(output, + p=self.dropout_rate, + training=self.training) + return output + + +class SeqAttnMatch(nn.Module): + """Given sequences X and Y, match sequence Y to each element in X. + + * o_i = sum(alpha_j * y_j) for i in X + * alpha_j = softmax(y_j * x_i) + """ + + def __init__(self, input_size, identity=False): + super(SeqAttnMatch, self).__init__() + if not identity: + self.linear = nn.Linear(input_size, input_size) + else: + self.linear = None + + def forward(self, x, y, y_mask): + """ + Args: + x: batch * len1 * hdim + y: batch * len2 * hdim + y_mask: batch * len2 (1 for padding, 0 for true) + Output: + matched_seq: batch * len1 * hdim + """ + # Project vectors + if self.linear: + x_proj = self.linear(x.view(-1, x.size(2))).view(x.size()) + x_proj = F.relu(x_proj) + y_proj = self.linear(y.view(-1, y.size(2))).view(y.size()) + y_proj = F.relu(y_proj) + else: + x_proj = x + y_proj = y + + # Compute scores + scores = x_proj.bmm(y_proj.transpose(2, 1)) + + # Mask padding + y_mask = y_mask.unsqueeze(1).expand(scores.size()) + scores.data.masked_fill_(y_mask.data, -float('inf')) + + # Normalize with softmax + alpha_flat = F.softmax(scores.view(-1, y.size(1)), dim=-1) + alpha = alpha_flat.view(-1, x.size(1), y.size(1)) + + # Take weighted average + matched_seq = alpha.bmm(y) + return matched_seq + + +class BilinearSeqAttn(nn.Module): + """A bilinear attention layer over a sequence X w.r.t y: + + * o_i = softmax(x_i'Wy) for x_i in X. + + Optionally don't normalize output weights. + """ + + def __init__(self, x_size, y_size, identity=False, normalize=True): + super(BilinearSeqAttn, self).__init__() + self.normalize = normalize + + # If identity is true, we just use a dot product without transformation. 
+ if not identity: + self.linear = nn.Linear(y_size, x_size) + else: + self.linear = None + + def forward(self, x, y, x_mask): + """ + Args: + x: batch * len * hdim1 + y: batch * hdim2 + x_mask: batch * len (1 for padding, 0 for true) + Output: + alpha = batch * len + """ + Wy = self.linear(y) if self.linear is not None else y + xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2) + xWy.data.masked_fill_(x_mask.data, -float('inf')) + if self.normalize: + if self.training: + # In training we output log-softmax for NLL + alpha = F.log_softmax(xWy, dim=-1) + else: + # ...Otherwise 0-1 probabilities + alpha = F.softmax(xWy, dim=-1) + else: + alpha = xWy.exp() + return alpha + + +class LinearSeqAttn(nn.Module): + """Self attention over a sequence: + + * o_i = softmax(Wx_i) for x_i in X. + """ + + def __init__(self, input_size): + super(LinearSeqAttn, self).__init__() + self.linear = nn.Linear(input_size, 1) + + def forward(self, x, x_mask): + """ + Args: + x: batch * len * hdim + x_mask: batch * len (1 for padding, 0 for true) + Output: + alpha: batch * len + """ + x_flat = x.view(-1, x.size(-1)) + scores = self.linear(x_flat).view(x.size(0), x.size(1)) + scores.data.masked_fill_(x_mask.data, -float('inf')) + alpha = F.softmax(scores, dim=-1) + return alpha + + +# ------------------------------------------------------------------------------ +# Functional +# ------------------------------------------------------------------------------ + + +def uniform_weights(x, x_mask): + """Return uniform weights over non-masked x (a sequence of vectors). + + Args: + x: batch * len * hdim + x_mask: batch * len (1 for padding, 0 for true) + Output: + x_avg: batch * hdim + """ + alpha = torch.ones(x.size(0), x.size(1)) + if x.data.is_cuda: + alpha = alpha.cuda() + alpha = alpha * x_mask.eq(0).float() + alpha = alpha / alpha.sum(1).expand(alpha.size()) + return alpha + + +def weighted_avg(x, weights): + """Return a weighted average of x (a sequence of vectors). + + Args: + x: batch * len * hdim + weights: batch * len, sum(dim = 1) = 1 + Output: + x_avg: batch * hdim + """ + return weights.unsqueeze(1).bmm(x).squeeze(1) diff --git a/drqa/reader/model.py b/drqa/reader/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c52fca0d067f6c06b55422a0e0a8507c8982a94a --- /dev/null +++ b/drqa/reader/model.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""DrQA Document Reader model""" + +import torch +import torch.optim as optim +import torch.nn.functional as F +import numpy as np +import logging +import copy + +from .config import override_model_args +from .rnn_reader import RnnDocReader + +logger = logging.getLogger(__name__) + + +class DocReader(object): + """High level model that handles intializing the underlying network + architecture, saving, updating examples, and predicting examples. + """ + + # -------------------------------------------------------------------------- + # Initialization + # -------------------------------------------------------------------------- + + def __init__(self, args, word_dict, feature_dict, + state_dict=None, normalize=True): + # Book-keeping. 
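+ # Vocabulary and feature counts are recorded on args so they are saved
+ # together with model checkpoints.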
+ self.args = args + self.word_dict = word_dict + self.args.vocab_size = len(word_dict) + self.feature_dict = feature_dict + self.args.num_features = len(feature_dict) + self.updates = 0 + self.use_cuda = False + self.parallel = False + + # Building network. If normalize if false, scores are not normalized + # 0-1 per paragraph (no softmax). + if args.model_type == 'rnn': + self.network = RnnDocReader(args, normalize) + else: + raise RuntimeError('Unsupported model: %s' % args.model_type) + + # Load saved state + if state_dict: + # Load buffer separately + if 'fixed_embedding' in state_dict: + fixed_embedding = state_dict.pop('fixed_embedding') + self.network.load_state_dict(state_dict) + self.network.register_buffer('fixed_embedding', fixed_embedding) + else: + self.network.load_state_dict(state_dict) + + def expand_dictionary(self, words): + """Add words to the DocReader dictionary if they do not exist. The + underlying embedding matrix is also expanded (with random embeddings). + + Args: + words: iterable of tokens to add to the dictionary. + Output: + added: set of tokens that were added. + """ + to_add = {self.word_dict.normalize(w) for w in words + if w not in self.word_dict} + + # Add words to dictionary and expand embedding layer + if len(to_add) > 0: + logger.info('Adding %d new words to dictionary...' % len(to_add)) + for w in to_add: + self.word_dict.add(w) + self.args.vocab_size = len(self.word_dict) + logger.info('New vocab size: %d' % len(self.word_dict)) + + old_embedding = self.network.embedding.weight.data + self.network.embedding = torch.nn.Embedding(self.args.vocab_size, + self.args.embedding_dim, + padding_idx=0) + new_embedding = self.network.embedding.weight.data + new_embedding[:old_embedding.size(0)] = old_embedding + + # Return added words + return to_add + + def load_embeddings(self, words, embedding_file): + """Load pretrained embeddings for a given list of words, if they exist. + + Args: + words: iterable of tokens. Only those that are indexed in the + dictionary are kept. + embedding_file: path to text file of embeddings, space separated. + """ + words = {w for w in words if w in self.word_dict} + logger.info('Loading pre-trained embeddings for %d words from %s' % + (len(words), embedding_file)) + embedding = self.network.embedding.weight.data + + # When normalized, some words are duplicated. (Average the embeddings). + vec_counts = {} + with open(embedding_file) as f: + # Skip first line if of form count/dim. + line = f.readline().rstrip().split(' ') + if len(line) != 2: + f.seek(0) + for line in f: + parsed = line.rstrip().split(' ') + assert(len(parsed) == embedding.size(1) + 1) + w = self.word_dict.normalize(parsed[0]) + if w in words: + vec = torch.Tensor([float(i) for i in parsed[1:]]) + if w not in vec_counts: + vec_counts[w] = 1 + embedding[self.word_dict[w]].copy_(vec) + else: + logging.warning( + 'WARN: Duplicate embedding found for %s' % w + ) + vec_counts[w] = vec_counts[w] + 1 + embedding[self.word_dict[w]].add_(vec) + + for w, c in vec_counts.items(): + embedding[self.word_dict[w]].div_(c) + + logger.info('Loaded %d embeddings (%.2f%%)' % + (len(vec_counts), 100 * len(vec_counts) / len(words))) + + def tune_embeddings(self, words): + """Unfix the embeddings of a list of words. This is only relevant if + only some of the embeddings are being tuned (tune_partial = N). + + Shuffles the N specified words to the front of the dictionary, and saves + the original vectors of the other N + 1:vocab words in a fixed buffer. 
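+ The buffer is copied back over those rows by reset_parameters() after each
+ optimizer update.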
+ + Args: + words: iterable of tokens contained in dictionary. + """ + words = {w for w in words if w in self.word_dict} + + if len(words) == 0: + logger.warning('Tried to tune embeddings, but no words given!') + return + + if len(words) == len(self.word_dict): + logger.warning('Tuning ALL embeddings in dictionary') + return + + # Shuffle words and vectors + embedding = self.network.embedding.weight.data + for idx, swap_word in enumerate(words, self.word_dict.START): + # Get current word + embedding for this index + curr_word = self.word_dict[idx] + curr_emb = embedding[idx].clone() + old_idx = self.word_dict[swap_word] + + # Swap embeddings + dictionary indices + embedding[idx].copy_(embedding[old_idx]) + embedding[old_idx].copy_(curr_emb) + self.word_dict[swap_word] = idx + self.word_dict[idx] = swap_word + self.word_dict[curr_word] = old_idx + self.word_dict[old_idx] = curr_word + + # Save the original, fixed embeddings + self.network.register_buffer( + 'fixed_embedding', embedding[idx + 1:].clone() + ) + + def init_optimizer(self, state_dict=None): + """Initialize an optimizer for the free parameters of the network. + + Args: + state_dict: network parameters + """ + if self.args.fix_embeddings: + for p in self.network.embedding.parameters(): + p.requires_grad = False + parameters = [p for p in self.network.parameters() if p.requires_grad] + if self.args.optimizer == 'sgd': + self.optimizer = optim.SGD(parameters, self.args.learning_rate, + momentum=self.args.momentum, + weight_decay=self.args.weight_decay) + elif self.args.optimizer == 'adamax': + self.optimizer = optim.Adamax(parameters, + weight_decay=self.args.weight_decay) + else: + raise RuntimeError('Unsupported optimizer: %s' % + self.args.optimizer) + + # -------------------------------------------------------------------------- + # Learning + # -------------------------------------------------------------------------- + + def update(self, ex): + """Forward a batch of examples; step the optimizer to update weights.""" + if not self.optimizer: + raise RuntimeError('No optimizer set.') + + # Train mode + self.network.train() + + # Transfer to GPU + if self.use_cuda: + inputs = [e if e is None else e.cuda(non_blocking=True) + for e in ex[:5]] + target_s = ex[5].cuda(non_blocking=True) + target_e = ex[6].cuda(non_blocking=True) + else: + inputs = [e if e is None else e for e in ex[:5]] + target_s = ex[5] + target_e = ex[6] + + # Run forward + score_s, score_e = self.network(*inputs) + + # Compute loss and accuracies + loss = F.nll_loss(score_s, target_s) + F.nll_loss(score_e, target_e) + + # Clear gradients and run backward + self.optimizer.zero_grad() + loss.backward() + + # Clip gradients + torch.nn.utils.clip_grad_norm_(self.network.parameters(), + self.args.grad_clipping) + + # Update parameters + self.optimizer.step() + self.updates += 1 + + # Reset any partially fixed parameters (e.g. 
rare words) + self.reset_parameters() + + return loss.item(), ex[0].size(0) + + def reset_parameters(self): + """Reset any partially fixed parameters to original states.""" + + # Reset fixed embeddings to original value + if self.args.tune_partial > 0: + if self.parallel: + embedding = self.network.module.embedding.weight.data + fixed_embedding = self.network.module.fixed_embedding + else: + embedding = self.network.embedding.weight.data + fixed_embedding = self.network.fixed_embedding + + # Embeddings to fix are the last indices + offset = embedding.size(0) - fixed_embedding.size(0) + if offset >= 0: + embedding[offset:] = fixed_embedding + + # -------------------------------------------------------------------------- + # Prediction + # -------------------------------------------------------------------------- + + def predict(self, ex, candidates=None, top_n=1, async_pool=None): + """Forward a batch of examples only to get predictions. + + Args: + ex: the batch + candidates: batch * variable length list of string answer options. + The model will only consider exact spans contained in this list. + top_n: Number of predictions to return per batch element. + async_pool: If provided, non-gpu post-processing will be offloaded + to this CPU process pool. + Output: + pred_s: batch * top_n predicted start indices + pred_e: batch * top_n predicted end indices + pred_score: batch * top_n prediction scores + + If async_pool is given, these will be AsyncResult handles. + """ + # Eval mode + self.network.eval() + + # Transfer to GPU + if self.use_cuda: + inputs = [e if e is None else e.cuda(non_blocking=True) + for e in ex[:5]] + else: + inputs = [e for e in ex[:5]] + + # Run forward + with torch.no_grad(): + score_s, score_e = self.network(*inputs) + + # Decode predictions + score_s = score_s.data.cpu() + score_e = score_e.data.cpu() + if candidates: + args = (score_s, score_e, candidates, top_n, self.args.max_len) + if async_pool: + return async_pool.apply_async(self.decode_candidates, args) + else: + return self.decode_candidates(*args) + else: + args = (score_s, score_e, top_n, self.args.max_len) + if async_pool: + return async_pool.apply_async(self.decode, args) + else: + return self.decode(*args) + + @staticmethod + def decode(score_s, score_e, top_n=1, max_len=None): + """Take argmax of constrained score_s * score_e. + + Args: + score_s: independent start predictions + score_e: independent end predictions + top_n: number of top scored pairs to take + max_len: max span length to consider + """ + pred_s = [] + pred_e = [] + pred_score = [] + max_len = max_len or score_s.size(1) + for i in range(score_s.size(0)): + # Outer product of scores to get full p_s * p_e matrix + scores = torch.ger(score_s[i], score_e[i]) + + # Zero out negative length and over-length span scores + scores.triu_().tril_(max_len - 1) + + # Take argmax or top n + scores = scores.numpy() + scores_flat = scores.flatten() + if top_n == 1: + idx_sort = [np.argmax(scores_flat)] + elif len(scores_flat) < top_n: + idx_sort = np.argsort(-scores_flat) + else: + idx = np.argpartition(-scores_flat, top_n)[0:top_n] + idx_sort = idx[np.argsort(-scores_flat[idx])] + s_idx, e_idx = np.unravel_index(idx_sort, scores.shape) + pred_s.append(s_idx) + pred_e.append(e_idx) + pred_score.append(scores_flat[idx_sort]) + return pred_s, pred_e, pred_score + + @staticmethod + def decode_candidates(score_s, score_e, candidates, top_n=1, max_len=None): + """Take argmax of constrained score_s * score_e. 
Except only consider + spans that are in the candidates list. + """ + pred_s = [] + pred_e = [] + pred_score = [] + for i in range(score_s.size(0)): + # Extract original tokens stored with candidates + tokens = candidates[i]['input'] + cands = candidates[i]['cands'] + + if not cands: + # try getting from globals? (multiprocessing in pipeline mode) + from ..pipeline.drqa import PROCESS_CANDS + cands = PROCESS_CANDS + if not cands: + raise RuntimeError('No candidates given.') + + # Score all valid candidates found in text. + # Brute force get all ngrams and compare against the candidate list. + max_len = max_len or len(tokens) + scores, s_idx, e_idx = [], [], [] + for s, e in tokens.ngrams(n=max_len, as_strings=False): + span = tokens.slice(s, e).untokenize() + if span in cands or span.lower() in cands: + # Match! Record its score. + scores.append(score_s[i][s] * score_e[i][e - 1]) + s_idx.append(s) + e_idx.append(e - 1) + + if len(scores) == 0: + # No candidates present + pred_s.append([]) + pred_e.append([]) + pred_score.append([]) + else: + # Rank found candidates + scores = np.array(scores) + s_idx = np.array(s_idx) + e_idx = np.array(e_idx) + + idx_sort = np.argsort(-scores)[0:top_n] + pred_s.append(s_idx[idx_sort]) + pred_e.append(e_idx[idx_sort]) + pred_score.append(scores[idx_sort]) + return pred_s, pred_e, pred_score + + # -------------------------------------------------------------------------- + # Saving and loading + # -------------------------------------------------------------------------- + + def save(self, filename): + if self.parallel: + network = self.network.module + else: + network = self.network + state_dict = copy.copy(network.state_dict()) + if 'fixed_embedding' in state_dict: + state_dict.pop('fixed_embedding') + params = { + 'state_dict': state_dict, + 'word_dict': self.word_dict, + 'feature_dict': self.feature_dict, + 'args': self.args, + } + try: + torch.save(params, filename) + except BaseException: + logger.warning('WARN: Saving failed... continuing anyway.') + + def checkpoint(self, filename, epoch): + if self.parallel: + network = self.network.module + else: + network = self.network + params = { + 'state_dict': network.state_dict(), + 'word_dict': self.word_dict, + 'feature_dict': self.feature_dict, + 'args': self.args, + 'epoch': epoch, + 'optimizer': self.optimizer.state_dict(), + } + try: + torch.save(params, filename) + except BaseException: + logger.warning('WARN: Saving failed... 
continuing anyway.') + + @staticmethod + def load(filename, new_args=None, normalize=True): + logger.info('Loading model %s' % filename) + saved_params = torch.load( + filename, map_location=lambda storage, loc: storage + ) + word_dict = saved_params['word_dict'] + feature_dict = saved_params['feature_dict'] + state_dict = saved_params['state_dict'] + args = saved_params['args'] + if new_args: + args = override_model_args(args, new_args) + return DocReader(args, word_dict, feature_dict, state_dict, normalize) + + @staticmethod + def load_checkpoint(filename, normalize=True): + logger.info('Loading model %s' % filename) + saved_params = torch.load( + filename, map_location=lambda storage, loc: storage + ) + word_dict = saved_params['word_dict'] + feature_dict = saved_params['feature_dict'] + state_dict = saved_params['state_dict'] + epoch = saved_params['epoch'] + optimizer = saved_params['optimizer'] + args = saved_params['args'] + model = DocReader(args, word_dict, feature_dict, state_dict, normalize) + model.init_optimizer(optimizer) + return model, epoch + + # -------------------------------------------------------------------------- + # Runtime + # -------------------------------------------------------------------------- + + def cuda(self): + self.use_cuda = True + self.network = self.network.cuda() + + def cpu(self): + self.use_cuda = False + self.network = self.network.cpu() + + def parallelize(self): + """Use data parallel to copy the model across several gpus. + This will take all gpus visible with CUDA_VISIBLE_DEVICES. + """ + self.parallel = True + self.network = torch.nn.DataParallel(self.network) diff --git a/drqa/reader/predictor.py b/drqa/reader/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..1f8f7c1cae8490166c895f51ef571e121f3045e9 --- /dev/null +++ b/drqa/reader/predictor.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""DrQA Document Reader predictor""" + +import logging + +from multiprocessing import Pool as ProcessPool +from multiprocessing.util import Finalize + +from .vector import vectorize, batchify +from .model import DocReader +from . import DEFAULTS, utils +from .. import tokenizers + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------------------ +# Tokenize + annotate +# ------------------------------------------------------------------------------ + +PROCESS_TOK = None + + +def init(tokenizer_class, annotators): + global PROCESS_TOK + PROCESS_TOK = tokenizer_class(annotators=annotators) + Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100) + + +def tokenize(text): + global PROCESS_TOK + return PROCESS_TOK.tokenize(text) + + +# ------------------------------------------------------------------------------ +# Predictor class. +# ------------------------------------------------------------------------------ + + +class Predictor(object): + """Load a pretrained DocReader model and predict inputs on the fly.""" + + def __init__(self, model=None, tokenizer=None, normalize=True, + embedding_file=None, num_workers=None): + """ + Args: + model: path to saved model file. + tokenizer: option string to select tokenizer class. + normalize: squash output score to 0-1 probabilities with a softmax. 
+ embedding_file: if provided, will expand dictionary to use all + available pretrained vectors in this file. + num_workers: number of CPU processes to use to preprocess batches. + """ + logger.info('Initializing model...') + self.model = DocReader.load(model or DEFAULTS['model'], + normalize=normalize) + + if embedding_file: + logger.info('Expanding dictionary...') + words = utils.index_embedding_words(embedding_file) + added = self.model.expand_dictionary(words) + self.model.load_embeddings(added, embedding_file) + + logger.info('Initializing tokenizer...') + annotators = tokenizers.get_annotators_for_model(self.model) + if not tokenizer: + tokenizer_class = DEFAULTS['tokenizer'] + else: + tokenizer_class = tokenizers.get_class(tokenizer) + + if num_workers is None or num_workers > 0: + self.workers = ProcessPool( + num_workers, + initializer=init, + initargs=(tokenizer_class, annotators), + ) + else: + self.workers = None + self.tokenizer = tokenizer_class(annotators=annotators) + + def predict(self, document, question, candidates=None, top_n=1): + """Predict a single document - question pair.""" + results = self.predict_batch([(document, question, candidates,)], top_n) + return results[0] + + def predict_batch(self, batch, top_n=1): + """Predict a batch of document - question pairs.""" + documents, questions, candidates = [], [], [] + for b in batch: + documents.append(b[0]) + questions.append(b[1]) + candidates.append(b[2] if len(b) == 3 else None) + candidates = candidates if any(candidates) else None + + # Tokenize the inputs, perhaps multi-processed. + if self.workers: + q_tokens = self.workers.map_async(tokenize, questions) + d_tokens = self.workers.map_async(tokenize, documents) + q_tokens = list(q_tokens.get()) + d_tokens = list(d_tokens.get()) + else: + q_tokens = list(map(self.tokenizer.tokenize, questions)) + d_tokens = list(map(self.tokenizer.tokenize, documents)) + + examples = [] + for i in range(len(questions)): + examples.append({ + 'id': i, + 'question': q_tokens[i].words(), + 'qlemma': q_tokens[i].lemmas(), + 'document': d_tokens[i].words(), + 'lemma': d_tokens[i].lemmas(), + 'pos': d_tokens[i].pos(), + 'ner': d_tokens[i].entities(), + }) + + # Stick document tokens in candidates for decoding + if candidates: + candidates = [{'input': d_tokens[i], 'cands': candidates[i]} + for i in range(len(candidates))] + + # Build the batch and run it through the model + batch_exs = batchify([vectorize(e, self.model) for e in examples]) + s, e, score = self.model.predict(batch_exs, candidates, top_n) + + # Retrieve the predicted spans + results = [] + for i in range(len(s)): + predictions = [] + for j in range(len(s[i])): + span = d_tokens[i].slice(s[i][j], e[i][j] + 1).untokenize() + predictions.append((span, score[i][j].item())) + results.append(predictions) + return results + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() diff --git a/drqa/reader/rnn_reader.py b/drqa/reader/rnn_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..326dd2666b15c4e6403ccbe3ac7a62c1074642f2 --- /dev/null +++ b/drqa/reader/rnn_reader.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Implementation of the RNN based DrQA reader.""" + +import torch +import torch.nn as nn +from . 
import layers + + +# ------------------------------------------------------------------------------ +# Network +# ------------------------------------------------------------------------------ + + +class RnnDocReader(nn.Module): + RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN} + + def __init__(self, args, normalize=True): + super(RnnDocReader, self).__init__() + # Store config + self.args = args + + # Word embeddings (+1 for padding) + self.embedding = nn.Embedding(args.vocab_size, + args.embedding_dim, + padding_idx=0) + + # Projection for attention weighted question + if args.use_qemb: + self.qemb_match = layers.SeqAttnMatch(args.embedding_dim) + + # Input size to RNN: word emb + question emb + manual features + doc_input_size = args.embedding_dim + args.num_features + if args.use_qemb: + doc_input_size += args.embedding_dim + + # RNN document encoder + self.doc_rnn = layers.StackedBRNN( + input_size=doc_input_size, + hidden_size=args.hidden_size, + num_layers=args.doc_layers, + dropout_rate=args.dropout_rnn, + dropout_output=args.dropout_rnn_output, + concat_layers=args.concat_rnn_layers, + rnn_type=self.RNN_TYPES[args.rnn_type], + padding=args.rnn_padding, + ) + + # RNN question encoder + self.question_rnn = layers.StackedBRNN( + input_size=args.embedding_dim, + hidden_size=args.hidden_size, + num_layers=args.question_layers, + dropout_rate=args.dropout_rnn, + dropout_output=args.dropout_rnn_output, + concat_layers=args.concat_rnn_layers, + rnn_type=self.RNN_TYPES[args.rnn_type], + padding=args.rnn_padding, + ) + + # Output sizes of rnn encoders + doc_hidden_size = 2 * args.hidden_size + question_hidden_size = 2 * args.hidden_size + if args.concat_rnn_layers: + doc_hidden_size *= args.doc_layers + question_hidden_size *= args.question_layers + + # Question merging + if args.question_merge not in ['avg', 'self_attn']: + raise NotImplementedError('merge_mode = %s' % args.merge_mode) + if args.question_merge == 'self_attn': + self.self_attn = layers.LinearSeqAttn(question_hidden_size) + + # Bilinear attention for span start/end + self.start_attn = layers.BilinearSeqAttn( + doc_hidden_size, + question_hidden_size, + normalize=normalize, + ) + self.end_attn = layers.BilinearSeqAttn( + doc_hidden_size, + question_hidden_size, + normalize=normalize, + ) + + def forward(self, x1, x1_f, x1_mask, x2, x2_mask): + """Inputs: + x1 = document word indices [batch * len_d] + x1_f = document word features indices [batch * len_d * nfeat] + x1_mask = document padding mask [batch * len_d] + x2 = question word indices [batch * len_q] + x2_mask = question padding mask [batch * len_q] + """ + # Embed both document and question + x1_emb = self.embedding(x1) + x2_emb = self.embedding(x2) + + # Dropout on embeddings + if self.args.dropout_emb > 0: + x1_emb = nn.functional.dropout(x1_emb, p=self.args.dropout_emb, + training=self.training) + x2_emb = nn.functional.dropout(x2_emb, p=self.args.dropout_emb, + training=self.training) + + # Form document encoding inputs + drnn_input = [x1_emb] + + # Add attention-weighted question representation + if self.args.use_qemb: + x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask) + drnn_input.append(x2_weighted_emb) + + # Add manual features + if self.args.num_features > 0: + drnn_input.append(x1_f) + + # Encode document with RNN + doc_hiddens = self.doc_rnn(torch.cat(drnn_input, 2), x1_mask) + + # Encode question with RNN + merge hiddens + question_hiddens = self.question_rnn(x2_emb, x2_mask) + if self.args.question_merge == 'avg': + q_merge_weights = 
layers.uniform_weights(question_hiddens, x2_mask) + elif self.args.question_merge == 'self_attn': + q_merge_weights = self.self_attn(question_hiddens, x2_mask) + question_hidden = layers.weighted_avg(question_hiddens, q_merge_weights) + + # Predict start and end positions + start_scores = self.start_attn(doc_hiddens, question_hidden, x1_mask) + end_scores = self.end_attn(doc_hiddens, question_hidden, x1_mask) + return start_scores, end_scores diff --git a/drqa/reader/utils.py b/drqa/reader/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..66eb542600e8e85382773f77d2fc5683669c0c73 --- /dev/null +++ b/drqa/reader/utils.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""DrQA reader utilities.""" + +import json +import time +import logging +import string +import regex as re + +from collections import Counter +from .data import Dictionary + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------------------ +# Data loading +# ------------------------------------------------------------------------------ + + +def load_data(args, filename, skip_no_answer=False): + """Load examples from preprocessed file. + One example per line, JSON encoded. + """ + # Load JSON lines + with open(filename) as f: + examples = [json.loads(line) for line in f] + + # Make case insensitive? + if args.uncased_question or args.uncased_doc: + for ex in examples: + if args.uncased_question: + ex['question'] = [w.lower() for w in ex['question']] + if args.uncased_doc: + ex['document'] = [w.lower() for w in ex['document']] + + # Skip unparsed (start/end) examples + if skip_no_answer: + examples = [ex for ex in examples if len(ex['answers']) > 0] + + return examples + + +def load_text(filename): + """Load the paragraphs only of a SQuAD dataset. Store as qid -> text.""" + # Load JSON file + with open(filename) as f: + examples = json.load(f)['data'] + + texts = {} + for article in examples: + for paragraph in article['paragraphs']: + for qa in paragraph['qas']: + texts[qa['id']] = paragraph['context'] + return texts + + +def load_answers(filename): + """Load the answers only of a SQuAD dataset. 
Store as qid -> [answers].""" + # Load JSON file + with open(filename) as f: + examples = json.load(f)['data'] + + ans = {} + for article in examples: + for paragraph in article['paragraphs']: + for qa in paragraph['qas']: + ans[qa['id']] = list(map(lambda x: x['text'], qa['answers'])) + return ans + + +# ------------------------------------------------------------------------------ +# Dictionary building +# ------------------------------------------------------------------------------ + + +def index_embedding_words(embedding_file): + """Put all the words in embedding_file into a set.""" + words = set() + with open(embedding_file) as f: + for line in f: + w = Dictionary.normalize(line.rstrip().split(' ')[0]) + words.add(w) + return words + + +def load_words(args, examples): + """Iterate and index all the words in examples (documents + questions).""" + def _insert(iterable): + for w in iterable: + w = Dictionary.normalize(w) + if valid_words and w not in valid_words: + continue + words.add(w) + + if args.restrict_vocab and args.embedding_file: + logger.info('Restricting to words in %s' % args.embedding_file) + valid_words = index_embedding_words(args.embedding_file) + logger.info('Num words in set = %d' % len(valid_words)) + else: + valid_words = None + + words = set() + for ex in examples: + _insert(ex['question']) + _insert(ex['document']) + return words + + +def build_word_dict(args, examples): + """Return a dictionary from question and document words in + provided examples. + """ + word_dict = Dictionary() + for w in load_words(args, examples): + word_dict.add(w) + return word_dict + + +def top_question_words(args, examples, word_dict): + """Count and return the most common question words in provided examples.""" + word_count = Counter() + for ex in examples: + for w in ex['question']: + w = Dictionary.normalize(w) + if w in word_dict: + word_count.update([w]) + return word_count.most_common(args.tune_partial) + + +def build_feature_dict(args, examples): + """Index features (one hot) from fields in examples and options.""" + def _insert(feature): + if feature not in feature_dict: + feature_dict[feature] = len(feature_dict) + + feature_dict = {} + + # Exact match features + if args.use_in_question: + _insert('in_question') + _insert('in_question_uncased') + if args.use_lemma: + _insert('in_question_lemma') + + # Part of speech tag features + if args.use_pos: + for ex in examples: + for w in ex['pos']: + _insert('pos=%s' % w) + + # Named entity tag features + if args.use_ner: + for ex in examples: + for w in ex['ner']: + _insert('ner=%s' % w) + + # Term frequency feature + if args.use_tf: + _insert('tf') + return feature_dict + + +# ------------------------------------------------------------------------------ +# Evaluation. Follows official evalutation script for v1.1 of the SQuAD dataset. 
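+# Includes answer normalization, exact match, token-level F1, and regex matching.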
+# ------------------------------------------------------------------------------ + + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + def remove_articles(text): + return re.sub(r'\b(a|an|the)\b', ' ', text) + + def white_space_fix(text): + return ' '.join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return ''.join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(prediction, ground_truth): + """Compute the geometric mean of precision and recall for answer tokens.""" + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def exact_match_score(prediction, ground_truth): + """Check if the prediction is a (soft) exact match with the ground truth.""" + return normalize_answer(prediction) == normalize_answer(ground_truth) + + +def regex_match_score(prediction, pattern): + """Check if the prediction matches the given regular expression.""" + try: + compiled = re.compile( + pattern, + flags=re.IGNORECASE + re.UNICODE + re.MULTILINE + ) + except BaseException: + logger.warn('Regular expression failed to compile: %s' % pattern) + return False + return compiled.match(prediction) is not None + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + """Given a prediction and multiple valid answers, return the score of + the best prediction-answer_n pair given a metric function. + """ + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +# ------------------------------------------------------------------------------ +# Utility classes +# ------------------------------------------------------------------------------ + + +class AverageMeter(object): + """Computes and stores the average and current value.""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +class Timer(object): + """Computes elapsed time.""" + + def __init__(self): + self.running = True + self.total = 0 + self.start = time.time() + + def reset(self): + self.running = True + self.total = 0 + self.start = time.time() + return self + + def resume(self): + if not self.running: + self.running = True + self.start = time.time() + return self + + def stop(self): + if self.running: + self.running = False + self.total += time.time() - self.start + return self + + def time(self): + if self.running: + return self.total + time.time() - self.start + return self.total diff --git a/drqa/reader/vector.py b/drqa/reader/vector.py new file mode 100644 index 0000000000000000000000000000000000000000..1a721c7589b92524b470aaea1ce22bfa70ed6b46 --- /dev/null +++ b/drqa/reader/vector.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Functions for putting examples into torch format.""" + +from collections import Counter +import torch + + +def vectorize(ex, model, single_answer=False): + """Torchify a single example.""" + args = model.args + word_dict = model.word_dict + feature_dict = model.feature_dict + + # Index words + document = torch.LongTensor([word_dict[w] for w in ex['document']]) + question = torch.LongTensor([word_dict[w] for w in ex['question']]) + + # Create extra features vector + if len(feature_dict) > 0: + features = torch.zeros(len(ex['document']), len(feature_dict)) + else: + features = None + + # f_{exact_match} + if args.use_in_question: + q_words_cased = {w for w in ex['question']} + q_words_uncased = {w.lower() for w in ex['question']} + q_lemma = {w for w in ex['qlemma']} if args.use_lemma else None + for i in range(len(ex['document'])): + if ex['document'][i] in q_words_cased: + features[i][feature_dict['in_question']] = 1.0 + if ex['document'][i].lower() in q_words_uncased: + features[i][feature_dict['in_question_uncased']] = 1.0 + if q_lemma and ex['lemma'][i] in q_lemma: + features[i][feature_dict['in_question_lemma']] = 1.0 + + # f_{token} (POS) + if args.use_pos: + for i, w in enumerate(ex['pos']): + f = 'pos=%s' % w + if f in feature_dict: + features[i][feature_dict[f]] = 1.0 + + # f_{token} (NER) + if args.use_ner: + for i, w in enumerate(ex['ner']): + f = 'ner=%s' % w + if f in feature_dict: + features[i][feature_dict[f]] = 1.0 + + # f_{token} (TF) + if args.use_tf: + counter = Counter([w.lower() for w in ex['document']]) + l = len(ex['document']) + for i, w in enumerate(ex['document']): + features[i][feature_dict['tf']] = counter[w.lower()] * 1.0 / l + + # Maybe return without target + if 'answers' not in ex: + return document, features, question, ex['id'] + + # ...or with target(s) (might still be empty if answers is empty) + if single_answer: + assert(len(ex['answers']) > 0) + start = torch.LongTensor(1).fill_(ex['answers'][0][0]) + end = torch.LongTensor(1).fill_(ex['answers'][0][1]) + else: + start = [a[0] for a in ex['answers']] + end = [a[1] for a in ex['answers']] + + return document, features, question, start, end, ex['id'] + + +def batchify(batch): + """Gather a batch of individual examples into one batch.""" + NUM_INPUTS = 3 + NUM_TARGETS = 2 + NUM_EXTRA = 1 + + ids = [ex[-1] for ex in batch] + docs = [ex[0] for ex in batch] + features = [ex[1] for ex in batch] + questions = [ex[2] for ex in batch] + + # Batch documents and features + max_length = max([d.size(0) for d in docs]) + x1 = torch.LongTensor(len(docs), max_length).zero_() + x1_mask = torch.ByteTensor(len(docs), max_length).fill_(1) + if features[0] is None: + x1_f = None + else: + x1_f = torch.zeros(len(docs), max_length, features[0].size(1)) + for i, d in enumerate(docs): + x1[i, :d.size(0)].copy_(d) + x1_mask[i, :d.size(0)].fill_(0) + if x1_f is not None: + x1_f[i, :d.size(0)].copy_(features[i]) + + # Batch questions + max_length = max([q.size(0) for q in questions]) + x2 = torch.LongTensor(len(questions), max_length).zero_() + x2_mask = torch.ByteTensor(len(questions), max_length).fill_(1) + for i, q in enumerate(questions): + x2[i, :q.size(0)].copy_(q) + x2_mask[i, :q.size(0)].fill_(0) + + # Maybe return without targets + if len(batch[0]) == NUM_INPUTS + NUM_EXTRA: + return x1, x1_f, x1_mask, x2, x2_mask, ids + + elif len(batch[0]) == NUM_INPUTS + NUM_EXTRA + NUM_TARGETS: + # 
...Otherwise add targets + if torch.is_tensor(batch[0][3]): + y_s = torch.cat([ex[3] for ex in batch]) + y_e = torch.cat([ex[4] for ex in batch]) + else: + y_s = [ex[3] for ex in batch] + y_e = [ex[4] for ex in batch] + else: + raise RuntimeError('Incorrect number of inputs per example.') + + return x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids diff --git a/drqa/retriever/__init__.py b/drqa/retriever/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13b4166edc7bea8b51fdfc248dffcd3859a7f4df --- /dev/null +++ b/drqa/retriever/__init__.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +from .. import DATA_DIR + +DEFAULTS = { + 'db_path': os.path.join(DATA_DIR, 'wikipedia/docs.db'), + 'tfidf_path': os.path.join( + DATA_DIR, + 'wikipedia/docs-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz' + ), + 'elastic_url': 'localhost:9200' +} + + +def set_default(key, value): + global DEFAULTS + DEFAULTS[key] = value + + +def get_class(name): + if name == 'tfidf': + return TfidfDocRanker + if name == 'sqlite': + return DocDB + if name == 'elasticsearch': + return ElasticDocRanker + raise RuntimeError('Invalid retriever class: %s' % name) + + +from .doc_db import DocDB +from .tfidf_doc_ranker import TfidfDocRanker +from .elastic_doc_ranker import ElasticDocRanker diff --git a/drqa/retriever/__pycache__/__init__.cpython-38.pyc b/drqa/retriever/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd16cdb0bade54f9f9f4ba79d0e686006a88134d Binary files /dev/null and b/drqa/retriever/__pycache__/__init__.cpython-38.pyc differ diff --git a/drqa/retriever/__pycache__/doc_db.cpython-38.pyc b/drqa/retriever/__pycache__/doc_db.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee67d549cf21212978b158bc3681bd9618f21361 Binary files /dev/null and b/drqa/retriever/__pycache__/doc_db.cpython-38.pyc differ diff --git a/drqa/retriever/__pycache__/elastic_doc_ranker.cpython-38.pyc b/drqa/retriever/__pycache__/elastic_doc_ranker.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb6f8effd22ddd42ee3c7bd8cdd57d233e3f013c Binary files /dev/null and b/drqa/retriever/__pycache__/elastic_doc_ranker.cpython-38.pyc differ diff --git a/drqa/retriever/__pycache__/tfidf_doc_ranker.cpython-38.pyc b/drqa/retriever/__pycache__/tfidf_doc_ranker.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e95b15af679a981e1655ba49d262c31c55fbfbc Binary files /dev/null and b/drqa/retriever/__pycache__/tfidf_doc_ranker.cpython-38.pyc differ diff --git a/drqa/retriever/__pycache__/utils.cpython-38.pyc b/drqa/retriever/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4b5b6dfe381946e8f9663bec5c3d5632eb710f6 Binary files /dev/null and b/drqa/retriever/__pycache__/utils.cpython-38.pyc differ diff --git a/drqa/retriever/doc_db.py b/drqa/retriever/doc_db.py new file mode 100644 index 0000000000000000000000000000000000000000..3d1a451a37be375cdad1909536c6c38920700d34 --- /dev/null +++ b/drqa/retriever/doc_db.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Documents, in a sqlite database.""" + +import sqlite3 +from . import utils +from . import DEFAULTS + + +class DocDB(object): + """Sqlite backed document storage. + + Implements get_doc_text(doc_id). + """ + + def __init__(self, db_path=None): + self.path = db_path or DEFAULTS['db_path'] + self.connection = sqlite3.connect(self.path, check_same_thread=False) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def path(self): + """Return the path to the file that backs this database.""" + return self.path + + def close(self): + """Close the connection to the database.""" + self.connection.close() + + def get_doc_ids(self): + """Fetch all ids of docs stored in the db.""" + cursor = self.connection.cursor() + cursor.execute("SELECT id FROM documents") + results = [r[0] for r in cursor.fetchall()] + cursor.close() + return results + + def get_doc_text(self, doc_id): + """Fetch the raw text of the doc for 'doc_id'.""" + cursor = self.connection.cursor() + cursor.execute( + "SELECT text FROM documents WHERE id = ?", + (utils.normalize(doc_id), ) + # (doc_id, ) + ) + result = cursor.fetchone() + cursor.close() + return result if result is None else result[0] + + + def get_doc_title(self, doc_id): + """Fetch the raw text of the doc for 'doc_id'.""" + cursor = self.connection.cursor() + cursor.execute( + "SELECT title FROM documents WHERE id = ?", + (utils.normalize(doc_id),) + # (doc_id, ) + ) + result = cursor.fetchone() + cursor.close() + return result if result is None else result[0] + + def get_doc_intro(self, doc_id): + """Fetch the raw text of the doc for 'doc_id'.""" + cursor = self.connection.cursor() + cursor.execute( + "SELECT intro FROM documents WHERE id = ?", # intro: the introduction of Wikipedia page + (utils.normalize(doc_id),) + # (doc_id, ) + ) + result = cursor.fetchone() + cursor.close() + return result if result is None else result[0] diff --git a/drqa/retriever/elastic_doc_ranker.py b/drqa/retriever/elastic_doc_ranker.py new file mode 100644 index 0000000000000000000000000000000000000000..41d1bddd7bc07558aff0ffb7781df46978511129 --- /dev/null +++ b/drqa/retriever/elastic_doc_ranker.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Rank documents with an ElasticSearch index""" + +import logging +import scipy.sparse as sp + +from multiprocessing.pool import ThreadPool +from functools import partial +from elasticsearch import Elasticsearch + +from . import utils +from . import DEFAULTS +from .. import tokenizers + +logger = logging.getLogger(__name__) + + +class ElasticDocRanker(object): + """ Connect to an ElasticSearch index. 
+ Score pairs based on Elasticsearch + """ + + def __init__(self, elastic_url=None, elastic_index=None, elastic_fields=None, elastic_field_doc_name=None, strict=True, elastic_field_content=None): + """ + Args: + elastic_url: URL of the ElasticSearch server containing port + elastic_index: Index name of ElasticSearch + elastic_fields: Fields of the Elasticsearch index to search in + elastic_field_doc_name: Field containing the name of the document (index) + strict: fail on empty queries or continue (and return empty result) + elastic_field_content: Field containing the content of document in plaint text + """ + # Load from disk + elastic_url = elastic_url or DEFAULTS['elastic_url'] + logger.info('Connecting to %s' % elastic_url) + self.es = Elasticsearch(hosts=elastic_url) + self.elastic_index = elastic_index + self.elastic_fields = elastic_fields + self.elastic_field_doc_name = elastic_field_doc_name + self.elastic_field_content = elastic_field_content + self.strict = strict + + # Elastic Ranker + + def get_doc_index(self, doc_id): + """Convert doc_id --> doc_index""" + field_index = self.elastic_field_doc_name + if isinstance(field_index, list): + field_index = '.'.join(field_index) + result = self.es.search(index=self.elastic_index, body={'query':{'match': + {field_index: doc_id}}}) + return result['hits']['hits'][0]['_id'] + + + def get_doc_id(self, doc_index): + """Convert doc_index --> doc_id""" + result = self.es.search(index=self.elastic_index, body={'query': { 'match': {"_id": doc_index}}}) + source = result['hits']['hits'][0]['_source'] + return utils.get_field(source, self.elastic_field_doc_name) + + def closest_docs(self, query, k=1): + """Closest docs by using ElasticSearch + """ + results = self.es.search(index=self.elastic_index, body={'size':k ,'query': + {'multi_match': { + 'query': query, + 'type': 'most_fields', + 'fields': self.elastic_fields}}}) + hits = results['hits']['hits'] + doc_ids = [utils.get_field(row['_source'], self.elastic_field_doc_name) for row in hits] + doc_scores = [row['_score'] for row in hits] + return doc_ids, doc_scores + + def batch_closest_docs(self, queries, k=1, num_workers=None): + """Process a batch of closest_docs requests multithreaded. + Note: we can use plain threads here as scipy is outside of the GIL. + """ + with ThreadPool(num_workers) as threads: + closest_docs = partial(self.closest_docs, k=k) + results = threads.map(closest_docs, queries) + return results + + # Elastic DB + + def __enter__(self): + return self + + def close(self): + """Close the connection to the database.""" + self.es = None + + def get_doc_ids(self): + """Fetch all ids of docs stored in the db.""" + results = self.es.search(index= self.elastic_index, body={ + "query": {"match_all": {}}}) + doc_ids = [utils.get_field(result['_source'], self.elastic_field_doc_name) for result in results['hits']['hits']] + return doc_ids + + def get_doc_text(self, doc_id): + """Fetch the raw text of the doc for 'doc_id'.""" + idx = self.get_doc_index(doc_id) + result = self.es.get(index=self.elastic_index, doc_type='_doc', id=idx) + return result if result is None else result['_source'][self.elastic_field_content] + diff --git a/drqa/retriever/tfidf_doc_ranker.py b/drqa/retriever/tfidf_doc_ranker.py new file mode 100644 index 0000000000000000000000000000000000000000..551ed879c13e78dc4029475e1e501e10ce1829ec --- /dev/null +++ b/drqa/retriever/tfidf_doc_ranker.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Rank documents with TF-IDF scores""" + +import logging +import numpy as np +import scipy.sparse as sp + +from multiprocessing.pool import ThreadPool +from functools import partial + +from . import utils +from . import DEFAULTS +from .. import tokenizers + +logger = logging.getLogger(__name__) + + +class TfidfDocRanker(object): + """Loads a pre-weighted inverted index of token/document terms. + Scores new queries by taking sparse dot products. + """ + + def __init__(self, tfidf_path=None, strict=True): + """ + Args: + tfidf_path: path to saved model file + strict: fail on empty queries or continue (and return empty result) + """ + # Load from disk + tfidf_path = tfidf_path or DEFAULTS['tfidf_path'] + logger.info('Loading %s' % tfidf_path) + matrix, metadata = utils.load_sparse_csr(tfidf_path) + self.doc_mat = matrix + self.ngrams = metadata['ngram'] + self.hash_size = metadata['hash_size'] + self.tokenizer = tokenizers.get_class(metadata['tokenizer'])() + self.doc_freqs = metadata['doc_freqs'].squeeze() + self.doc_dict = metadata['doc_dict'] + self.num_docs = len(self.doc_dict[0]) + self.strict = strict + + def get_doc_index(self, doc_id): + """Convert doc_id --> doc_index""" + return self.doc_dict[0][doc_id] + + def get_doc_id(self, doc_index): + """Convert doc_index --> doc_id""" + return self.doc_dict[1][doc_index] + + def closest_docs(self, query, k=1): + """Closest docs by dot product between query and documents + in tfidf weighted word vector space. + """ + spvec = self.text2spvec(query) + res = spvec * self.doc_mat + + if len(res.data) <= k: + o_sort = np.argsort(-res.data) + else: + o = np.argpartition(-res.data, k)[0:k] + o_sort = o[np.argsort(-res.data[o])] + + doc_scores = res.data[o_sort] + doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]] + return doc_ids, doc_scores + + def batch_closest_docs(self, queries, k=1, num_workers=None): + """Process a batch of closest_docs requests multithreaded. + Note: we can use plain threads here as scipy is outside of the GIL. + """ + with ThreadPool(num_workers) as threads: + closest_docs = partial(self.closest_docs, k=k) + results = threads.map(closest_docs, queries) + return results + + def parse(self, query): + """Parse the query into tokens (either ngrams or tokens).""" + tokens = self.tokenizer.tokenize(query) + return tokens.ngrams(n=self.ngrams, uncased=True, + filter_fn=utils.filter_ngram) + + def text2spvec(self, query): + """Create a sparse tfidf-weighted word vector from query. 
+ + tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) + """ + # Get hashed ngrams + words = self.parse(utils.normalize(query)) + wids = [utils.hash(w, self.hash_size) for w in words] + + if len(wids) == 0: + if self.strict: + raise RuntimeError('No valid word in: %s' % query) + else: + logger.warning('No valid word in: %s' % query) + return sp.csr_matrix((1, self.hash_size)) + + # Count TF + wids_unique, wids_counts = np.unique(wids, return_counts=True) + tfs = np.log1p(wids_counts) + + # Count IDF + Ns = self.doc_freqs[wids_unique] + idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5)) + idfs[idfs < 0] = 0 + + # TF-IDF + data = np.multiply(tfs, idfs) + + # One row, sparse csr matrix + indptr = np.array([0, len(wids_unique)]) + spvec = sp.csr_matrix( + (data, wids_unique, indptr), shape=(1, self.hash_size) + ) + + return spvec diff --git a/drqa/retriever/utils.py b/drqa/retriever/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c08c555c9d6b7929100f60266358d0da6423fa8c --- /dev/null +++ b/drqa/retriever/utils.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Various retriever utilities.""" + +import regex +import unicodedata +import numpy as np +import scipy.sparse as sp +from sklearn.utils import murmurhash3_32 + + +# ------------------------------------------------------------------------------ +# Sparse matrix saving/loading helpers. +# ------------------------------------------------------------------------------ + + +def save_sparse_csr(filename, matrix, metadata=None): + data = { + 'data': matrix.data, + 'indices': matrix.indices, + 'indptr': matrix.indptr, + 'shape': matrix.shape, + 'metadata': metadata, + } + np.savez(filename, **data) + + +def load_sparse_csr(filename): + loader = np.load(filename, allow_pickle=True) + matrix = sp.csr_matrix((loader['data'], loader['indices'], + loader['indptr']), shape=loader['shape']) + return matrix, loader['metadata'].item(0) if 'metadata' in loader else None + + +# ------------------------------------------------------------------------------ +# Token hashing. +# ------------------------------------------------------------------------------ + + +def hash(token, num_buckets): + """Unsigned 32 bit murmurhash for feature hashing.""" + return murmurhash3_32(token, positive=True) % num_buckets + + +# ------------------------------------------------------------------------------ +# Text cleaning. 
+# ------------------------------------------------------------------------------ + + +STOPWORDS = { + 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', + 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', + 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their', + 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', + 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', + 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', + 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', + 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', + 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', + 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', + 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', + 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', + 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', + 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've', + 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven', + 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren', + 'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``" +} + + +def normalize(text): + """Resolve different type of unicode encodings.""" + return unicodedata.normalize('NFD', text) + + +def filter_word(text): + """Take out english stopwords, punctuation, and compound endings.""" + text = normalize(text) + if regex.match(r'^\p{P}+$', text): + return True + if text.lower() in STOPWORDS: + return True + return False + + +def filter_ngram(gram, mode='any'): + """Decide whether to keep or discard an n-gram. + + Args: + gram: list of tokens (length N) + mode: Option to throw out ngram if + 'any': any single token passes filter_word + 'all': all tokens pass filter_word + 'ends': book-ended by filterable tokens + """ + filtered = [filter_word(w) for w in gram] + if mode == 'any': + return any(filtered) + elif mode == 'all': + return all(filtered) + elif mode == 'ends': + return filtered[0] or filtered[-1] + else: + raise ValueError('Invalid mode: %s' % mode) + +def get_field(d, field_list): + """get the subfield associated to a list of elastic fields + E.g. ['file', 'filename'] to d['file']['filename'] + """ + if isinstance(field_list, str): + return d[field_list] + else: + idx = d.copy() + for field in field_list: + idx = idx[field] + return idx diff --git a/drqa/tokenizers/__init__.py b/drqa/tokenizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d8616047064bdc3e70ef2cf8fc3509364f95b6a --- /dev/null +++ b/drqa/tokenizers/__init__.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os + +DEFAULTS = { + 'corenlp_classpath': os.getenv('CLASSPATH') +} + + +def set_default(key, value): + global DEFAULTS + DEFAULTS[key] = value + + +from .corenlp_tokenizer import CoreNLPTokenizer +from .regexp_tokenizer import RegexpTokenizer +from .simple_tokenizer import SimpleTokenizer + +# Spacy is optional +try: + from .spacy_tokenizer import SpacyTokenizer +except ImportError: + pass + + +def get_class(name): + if name == 'spacy': + return SpacyTokenizer + if name == 'corenlp': + return CoreNLPTokenizer + if name == 'regexp': + return RegexpTokenizer + if name == 'simple': + return SimpleTokenizer + + raise RuntimeError('Invalid tokenizer: %s' % name) + + +def get_annotators_for_args(args): + annotators = set() + if args.use_pos: + annotators.add('pos') + if args.use_lemma: + annotators.add('lemma') + if args.use_ner: + annotators.add('ner') + return annotators + + +def get_annotators_for_model(model): + return get_annotators_for_args(model.args) diff --git a/drqa/tokenizers/__pycache__/__init__.cpython-38.pyc b/drqa/tokenizers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6038417baf1360eceea0a0cdb790b860834257b Binary files /dev/null and b/drqa/tokenizers/__pycache__/__init__.cpython-38.pyc differ diff --git a/drqa/tokenizers/__pycache__/corenlp_tokenizer.cpython-38.pyc b/drqa/tokenizers/__pycache__/corenlp_tokenizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93161290b2662df9610b5ea11ece7a8d888db8d5 Binary files /dev/null and b/drqa/tokenizers/__pycache__/corenlp_tokenizer.cpython-38.pyc differ diff --git a/drqa/tokenizers/__pycache__/regexp_tokenizer.cpython-38.pyc b/drqa/tokenizers/__pycache__/regexp_tokenizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cc7d05c1da9129bc7dc9e73d6af499fff7ef9f9 Binary files /dev/null and b/drqa/tokenizers/__pycache__/regexp_tokenizer.cpython-38.pyc differ diff --git a/drqa/tokenizers/__pycache__/simple_tokenizer.cpython-38.pyc b/drqa/tokenizers/__pycache__/simple_tokenizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92bd1e7ff012c3f0d16145187b77a0cc09a225db Binary files /dev/null and b/drqa/tokenizers/__pycache__/simple_tokenizer.cpython-38.pyc differ diff --git a/drqa/tokenizers/__pycache__/spacy_tokenizer.cpython-38.pyc b/drqa/tokenizers/__pycache__/spacy_tokenizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97d019553b1cbb918f481cb57a94b17e3f368e67 Binary files /dev/null and b/drqa/tokenizers/__pycache__/spacy_tokenizer.cpython-38.pyc differ diff --git a/drqa/tokenizers/__pycache__/tokenizer.cpython-38.pyc b/drqa/tokenizers/__pycache__/tokenizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc84891b0e751430423d5e89b9c27763f0bd9a7a Binary files /dev/null and b/drqa/tokenizers/__pycache__/tokenizer.cpython-38.pyc differ diff --git a/drqa/tokenizers/corenlp_tokenizer.py b/drqa/tokenizers/corenlp_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..10dece7432d7caf30f3845f37446a5ce353dd853 --- /dev/null +++ b/drqa/tokenizers/corenlp_tokenizer.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Simple wrapper around the Stanford CoreNLP pipeline. 
+ +Serves commands to a java subprocess running the jar. Requires java 8. +""" + +import copy +import json +import pexpect + +from .tokenizer import Tokens, Tokenizer +from . import DEFAULTS + + +class CoreNLPTokenizer(Tokenizer): + + def __init__(self, **kwargs): + """ + Args: + annotators: set that can include pos, lemma, and ner. + classpath: Path to the corenlp directory of jars + mem: Java heap memory + """ + self.classpath = (kwargs.get('classpath') or + DEFAULTS['corenlp_classpath']) + self.annotators = copy.deepcopy(kwargs.get('annotators', set())) + self.mem = kwargs.get('mem', '2g') + self._launch() + + def _launch(self): + """Start the CoreNLP jar with pexpect.""" + annotators = ['tokenize', 'ssplit'] + if 'ner' in self.annotators: + annotators.extend(['pos', 'lemma', 'ner']) + elif 'lemma' in self.annotators: + annotators.extend(['pos', 'lemma']) + elif 'pos' in self.annotators: + annotators.extend(['pos']) + annotators = ','.join(annotators) + options = ','.join(['untokenizable=noneDelete', + 'invertible=true']) + cmd = ['java', '-mx' + self.mem, '-cp', '"%s"' % self.classpath, + 'edu.stanford.nlp.pipeline.StanfordCoreNLP', '-annotators', + annotators, '-tokenize.options', options, + '-outputFormat', 'json', '-prettyPrint', 'false'] + + # We use pexpect to keep the subprocess alive and feed it commands. + # Because we don't want to get hit by the max terminal buffer size, + # we turn off canonical input processing to have unlimited bytes. + self.corenlp = pexpect.spawn('/bin/bash', maxread=100000, timeout=60) + self.corenlp.setecho(False) + self.corenlp.sendline('stty -icanon') + self.corenlp.sendline(' '.join(cmd)) + self.corenlp.delaybeforesend = 0 + self.corenlp.delayafterread = 0 + self.corenlp.expect_exact('NLP>', searchwindowsize=100) + + @staticmethod + def _convert(token): + if token == '-LRB-': + return '(' + if token == '-RRB-': + return ')' + if token == '-LSB-': + return '[' + if token == '-RSB-': + return ']' + if token == '-LCB-': + return '{' + if token == '-RCB-': + return '}' + return token + + def tokenize(self, text): + # Since we're feeding text to the commandline, we're waiting on seeing + # the NLP> prompt. Hacky! + if 'NLP>' in text: + raise RuntimeError('Bad token (NLP>) in text!') + + # Sending q will cause the process to quit -- manually override + if text.lower().strip() == 'q': + token = text.strip() + index = text.index(token) + data = [(token, text[index:], (index, index + 1), 'NN', 'q', 'O')] + return Tokens(data, self.annotators) + + # Minor cleanup before tokenizing. 
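+ # The CoreNLP jar is driven one request per line over the pexpect pipe, so embedded newlines must be collapsed before sending.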
+ clean_text = text.replace('\n', ' ') + + self.corenlp.sendline(clean_text.encode('utf-8')) + self.corenlp.expect_exact('NLP>', searchwindowsize=100) + + # Skip to start of output (may have been stderr logging messages) + output = self.corenlp.before + start = output.find(b'{"sentences":') + output = json.loads(output[start:].decode('utf-8')) + + data = [] + tokens = [t for s in output['sentences'] for t in s['tokens']] + for i in range(len(tokens)): + # Get whitespace + start_ws = tokens[i]['characterOffsetBegin'] + if i + 1 < len(tokens): + end_ws = tokens[i + 1]['characterOffsetBegin'] + else: + end_ws = tokens[i]['characterOffsetEnd'] + + data.append(( + self._convert(tokens[i]['word']), + text[start_ws: end_ws], + (tokens[i]['characterOffsetBegin'], + tokens[i]['characterOffsetEnd']), + tokens[i].get('pos', None), + tokens[i].get('lemma', None), + tokens[i].get('ner', None) + )) + return Tokens(data, self.annotators) diff --git a/drqa/tokenizers/regexp_tokenizer.py b/drqa/tokenizers/regexp_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..94858ec71845c13bc5a1e049845e8ab3225cc894 --- /dev/null +++ b/drqa/tokenizers/regexp_tokenizer.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Regex based tokenizer that emulates the Stanford/NLTK PTB tokenizers. + +However it is purely in Python, supports robust untokenization, unicode, +and requires minimal dependencies. +""" + +import regex +import logging +from .tokenizer import Tokens, Tokenizer + +logger = logging.getLogger(__name__) + +
+class RegexpTokenizer(Tokenizer): + DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*' + TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)' + r'\.(?=\p{Z})') + ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)' + ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++' + HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM) + NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't" + CONTRACTION1 = r"can(?=not\b)" + CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b" + START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})' + START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})' + END_DQUOTE = r'(?<!\p{Z})(\'\'|["\u0094\u201D\u00BB])(?=\p{Z}|$)' + END_SQUOTE = r'(?<!\p{Z})[\'\u0092\u2019\u203A](?=\p{Z}|$)' + DASH = r'--|[\u0096\u0097\u2013\u2014\u2015]' + ELLIPSES = r'\.\.\.|\u2026' + PUNCT = r'\p{P}' + NON_WS = r'[^\p{Z}\p{C}]' +
+ def __init__(self, **kwargs): + """ + Args: + annotators: None or empty set (only tokenizes). + substitutions: if true, normalizes some token types (e.g. quotes). + """ + self._regexp = regex.compile( + '(?P<digit>%s)|(?P<title>%s)|(?P<abbr>%s)|(?P<neg>%s)|(?P<hyph>%s)|' + '(?P<contr1>%s)|(?P<alphanum>%s)|(?P<contr2>%s)|(?P<sdquote>%s)|' + '(?P<edquote>%s)|(?P<ssquote>%s)|(?P<esquote>%s)|(?P<dash>%s)|' + '(?P<ellipses>%s)|(?P<punct>%s)|(?P<nonws>%s)' % + (self.DIGIT, self.TITLE, self.ABBRV, self.NEGATION, self.HYPHEN, + self.CONTRACTION1, self.ALPHA_NUM, self.CONTRACTION2, + self.START_DQUOTE, self.END_DQUOTE, self.START_SQUOTE, + self.END_SQUOTE, self.DASH, self.ELLIPSES, self.PUNCT, + self.NON_WS), + flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE + ) + if len(kwargs.get('annotators', {})) > 0: + logger.warning('%s only tokenizes! 
Skipping annotators: %s' % + (type(self).__name__, kwargs.get('annotators'))) + self.annotators = set() + self.substitutions = kwargs.get('substitutions', True) + + def tokenize(self, text): + data = [] + matches = [m for m in self._regexp.finditer(text)] + for i in range(len(matches)): + # Get text + token = matches[i].group() + + # Make normalizations for special token types + if self.substitutions: + groups = matches[i].groupdict() + if groups['sdquote']: + token = "``" + elif groups['edquote']: + token = "''" + elif groups['ssquote']: + token = "`" + elif groups['esquote']: + token = "'" + elif groups['dash']: + token = '--' + elif groups['ellipses']: + token = '...' + + # Get whitespace + span = matches[i].span() + start_ws = span[0] + if i + 1 < len(matches): + end_ws = matches[i + 1].span()[0] + else: + end_ws = span[1] + + # Format data + data.append(( + token, + text[start_ws: end_ws], + span, + )) + return Tokens(data, self.annotators) diff --git a/drqa/tokenizers/simple_tokenizer.py b/drqa/tokenizers/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..bad601390457c08125c44007714eb40e36ac7ee1 --- /dev/null +++ b/drqa/tokenizers/simple_tokenizer.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Basic tokenizer that splits text into alpha-numeric tokens and +non-whitespace tokens. +""" + +import regex +import logging +from .tokenizer import Tokens, Tokenizer + +logger = logging.getLogger(__name__) + + +class SimpleTokenizer(Tokenizer): + ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' + NON_WS = r'[^\p{Z}\p{C}]' + + def __init__(self, **kwargs): + """ + Args: + annotators: None or empty set (only tokenizes). + """ + self._regexp = regex.compile( + '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), + flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE + ) + if len(kwargs.get('annotators', {})) > 0: + logger.warning('%s only tokenizes! Skipping annotators: %s' % + (type(self).__name__, kwargs.get('annotators'))) + self.annotators = set() + + def tokenize(self, text): + data = [] + matches = [m for m in self._regexp.finditer(text)] + for i in range(len(matches)): + # Get text + token = matches[i].group() + + # Get whitespace + span = matches[i].span() + start_ws = span[0] + if i + 1 < len(matches): + end_ws = matches[i + 1].span()[0] + else: + end_ws = span[1] + + # Format data + data.append(( + token, + text[start_ws: end_ws], + span, + )) + return Tokens(data, self.annotators) diff --git a/drqa/tokenizers/spacy_tokenizer.py b/drqa/tokenizers/spacy_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..fc05d59afde21541e3a392674388263b61326382 --- /dev/null +++ b/drqa/tokenizers/spacy_tokenizer.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Tokenizer that is backed by spaCy (spacy.io). + +Requires spaCy package and the spaCy english model. +""" + +import spacy +import copy +from .tokenizer import Tokens, Tokenizer + + +class SpacyTokenizer(Tokenizer): + + def __init__(self, **kwargs): + """ + Args: + annotators: set that can include pos, lemma, and ner. + model: spaCy model to use (either path, or keyword like 'en'). 
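+ Note: this wrapper targets the legacy spaCy 1.x pipeline API (spacy.load(..., parser=False), nlp.tagger, nlp.entity) and may need changes for newer spaCy releases.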
+ """ + model = kwargs.get('model', 'en') + self.annotators = copy.deepcopy(kwargs.get('annotators', set())) + nlp_kwargs = {'parser': False} + if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): + nlp_kwargs['tagger'] = False + if 'ner' not in self.annotators: + nlp_kwargs['entity'] = False + self.nlp = spacy.load(model, **nlp_kwargs) + + def tokenize(self, text): + # We don't treat new lines as tokens. + clean_text = text.replace('\n', ' ') + tokens = self.nlp.tokenizer(clean_text) + if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): + self.nlp.tagger(tokens) + if 'ner' in self.annotators: + self.nlp.entity(tokens) + + data = [] + for i in range(len(tokens)): + # Get whitespace + start_ws = tokens[i].idx + if i + 1 < len(tokens): + end_ws = tokens[i + 1].idx + else: + end_ws = tokens[i].idx + len(tokens[i].text) + + data.append(( + tokens[i].text, + text[start_ws: end_ws], + (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), + tokens[i].tag_, + tokens[i].lemma_, + tokens[i].ent_type_, + )) + + # Set special option for non-entity tag: '' vs 'O' in spaCy + return Tokens(data, self.annotators, opts={'non_ent': ''}) diff --git a/drqa/tokenizers/tokenizer.py b/drqa/tokenizers/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e4e588c3055414a55e06981a1818cb0394ca990a --- /dev/null +++ b/drqa/tokenizers/tokenizer.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Base tokenizer/tokens classes and utilities.""" + +import copy + + +class Tokens(object): + """A class to represent a list of tokenized text.""" + TEXT = 0 + TEXT_WS = 1 + SPAN = 2 + POS = 3 + LEMMA = 4 + NER = 5 + + def __init__(self, data, annotators, opts=None): + self.data = data + self.annotators = annotators + self.opts = opts or {} + + def __len__(self): + """The number of tokens.""" + return len(self.data) + + def slice(self, i=None, j=None): + """Return a view of the list of tokens from [i, j).""" + new_tokens = copy.copy(self) + new_tokens.data = self.data[i: j] + return new_tokens + + def untokenize(self): + """Returns the original text (with whitespace reinserted).""" + return ''.join([t[self.TEXT_WS] for t in self.data]).strip() + + def words(self, uncased=False): + """Returns a list of the text of each token + + Args: + uncased: lower cases text + """ + if uncased: + return [t[self.TEXT].lower() for t in self.data] + else: + return [t[self.TEXT] for t in self.data] + + def offsets(self): + """Returns a list of [start, end) character offsets of each token.""" + return [t[self.SPAN] for t in self.data] + + def pos(self): + """Returns a list of part-of-speech tags of each token. + Returns None if this annotation was not included. + """ + if 'pos' not in self.annotators: + return None + return [t[self.POS] for t in self.data] + + def lemmas(self): + """Returns a list of the lemmatized text of each token. + Returns None if this annotation was not included. + """ + if 'lemma' not in self.annotators: + return None + return [t[self.LEMMA] for t in self.data] + + def entities(self): + """Returns a list of named-entity-recognition tags of each token. + Returns None if this annotation was not included. 
+ """ + if 'ner' not in self.annotators: + return None + return [t[self.NER] for t in self.data] + + def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): + """Returns a list of all ngrams from length 1 to n. + + Args: + n: upper limit of ngram length + uncased: lower cases text + filter_fn: user function that takes in an ngram list and returns + True or False to keep or not keep the ngram + as_string: return the ngram as a string vs list + """ + def _skip(gram): + if not filter_fn: + return False + return filter_fn(gram) + + words = self.words(uncased) + ngrams = [(s, e + 1) + for s in range(len(words)) + for e in range(s, min(s + n, len(words))) + if not _skip(words[s:e + 1])] + + # Concatenate into strings + if as_strings: + ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] + + return ngrams + + def entity_groups(self): + """Group consecutive entity tokens with the same NER tag.""" + entities = self.entities() + if not entities: + return None + non_ent = self.opts.get('non_ent', 'O') + groups = [] + idx = 0 + while idx < len(entities): + ner_tag = entities[idx] + # Check for entity tag + if ner_tag != non_ent: + # Chomp the sequence + start = idx + while (idx < len(entities) and entities[idx] == ner_tag): + idx += 1 + groups.append((self.slice(start, idx).untokenize(), ner_tag)) + else: + idx += 1 + return groups + + +class Tokenizer(object): + """Base tokenizer class. + Tokenizers implement tokenize, which should return a Tokens class. + """ + def tokenize(self, text): + raise NotImplementedError + + def shutdown(self): + pass + + def __del__(self): + self.shutdown() diff --git a/html2lines.py b/html2lines.py new file mode 100644 index 0000000000000000000000000000000000000000..420948dc393cf0047c7a177a26d8faecbe839c6c --- /dev/null +++ b/html2lines.py @@ -0,0 +1,72 @@ +from distutils.command.config import config +import requests +from time import sleep +import trafilatura +from trafilatura.meta import reset_caches +from trafilatura.settings import DEFAULT_CONFIG +import spacy +import os +os.system("python -m spacy download en_core_web_sm") +nlp = spacy.load('en_core_web_sm') +import sys + +DEFAULT_CONFIG.MAX_FILE_SIZE = 50000 + +def get_page(url): + page = None + for i in range(3): + try: + page = trafilatura.fetch_url(url, config=DEFAULT_CONFIG) + assert page is not None + print("Fetched "+url, file=sys.stderr) + break + except: + sleep(3) + return page + +def url2lines(url): + page = get_page(url) + + if page is None: + return [] + + lines = html2lines(page) + return lines + +def line_correction(lines, max_size=100): + out_lines = [] + for line in lines: + if len(line) < 4: + continue + + if len(line) > max_size: + doc = nlp(line[:5000]) # We split lines into sentences, but for performance we take only the first 5k characters per line + stack = "" + for sent in doc.sents: + if len(stack) > 0: + stack += " " + stack += str(sent).strip() + if len(stack) > max_size: + out_lines.append(stack) + stack = "" + + if len(stack) > 0: + out_lines.append(stack) + else: + out_lines.append(line) + + return out_lines + +def html2lines(page): + out_lines = [] + + if len(page.strip()) == 0 or page is None: + return out_lines + + text = trafilatura.extract(page, config=DEFAULT_CONFIG) + reset_caches() + + if text is None: + return out_lines + + return text.split("\n") # We just spit out the entire page, so need to reformat later. 
diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..94899466fcfdd3b8fba95c0bc3ef68adf2dda2e6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +gradio +nltk +rank_bm25 +accelerate +trafilatura +spacy +pytorch_lightning +transformers==4.29.2 +datasets +leven +scikit-learn +pexpect +elasticsearch +torch +huggingface_hub +google-api-python-client +wikipedia-api +beautifulsoup4 +azure-storage-file-share +azure-storage-blob +bm25s +PyStemmer diff --git a/setup.sh b/setup.sh new file mode 100644 index 0000000000000000000000000000000000000000..ad8bb8ee847c128cbc233e57fa8f1b0d62c84d4e --- /dev/null +++ b/setup.sh @@ -0,0 +1 @@ +python -m spacy download en_core_web_sm \ No newline at end of file diff --git a/style.css b/style.css new file mode 100644 index 0000000000000000000000000000000000000000..a152a297f35b617595cb3ac3e5c8f31991a3f2f6 --- /dev/null +++ b/style.css @@ -0,0 +1,379 @@ + +/* :root { + --user-image: url('https://ih1.redbubble.net/image.4776899543.6215/st,small,507x507-pad,600x600,f8f8f8.jpg'); + } */ + +.warning-box { + background-color: #fff3cd; + border: 1px solid #ffeeba; + border-radius: 4px; + padding: 15px 20px; + font-size: 14px; + color: #856404; + display: inline-block; + margin-bottom: 15px; + } + + +.tip-box { + background-color: #f0f9ff; + border: 1px solid #80d4fa; + border-radius: 4px; + margin-top:20px; + padding: 15px 20px; + font-size: 14px; + display: inline-block; + margin-bottom: 15px; + width: auto; + color:black !important; +} + +body.dark .warning-box * { + color:black !important; +} + + +body.dark .tip-box * { + color:black !important; +} + + +.tip-box-title { + font-weight: bold; + font-size: 14px; + margin-bottom: 5px; +} + +.light-bulb { + display: inline; + margin-right: 5px; +} + +.gr-box {border-color: #d6c37c} + +#hidden-message{ + display:none; +} + +.message{ + font-size:14px !important; +} + + +a { + text-decoration: none; + color: inherit; +} + +.card { + background-color: white; + border-radius: 10px; + box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); + overflow: hidden; + display: flex; + flex-direction: column; + margin:20px; +} + +.card-content { + padding: 20px; +} + +.card-content h2 { + font-size: 14px !important; + font-weight: bold; + margin-bottom: 10px; + margin-top:0px !important; + color:#dc2626!important;; +} + +.card-content p { + font-size: 12px; + margin-bottom: 0; +} + +.card-footer { + background-color: #f4f4f4; + font-size: 10px; + padding: 10px; + display: flex; + justify-content: space-between; + align-items: center; +} + +.card-footer span { + flex-grow: 1; + text-align: left; + color: #999 !important; +} + +.pdf-link { + display: inline-flex; + align-items: center; + margin-left: auto; + text-decoration: none!important; + font-size: 14px; + word-wrap: break-word; /* For IE */ + word-break: break-all; /* For all other browsers */ +} + + + +.message.user{ + /* background-color:#7494b0 !important; */ + border:none; + /* color:white!important; */ +} + +.message.bot{ + /* background-color:#f2f2f7 !important; */ + border:none; +} + +/* .gallery-item > div:hover{ + background-color:#7494b0 !important; + color:white!important; +} + +.gallery-item:hover{ + border:#7494b0 !important; +} + +.gallery-item > div{ + background-color:white !important; + color:#577b9b!important; +} + +.label{ + color:#577b9b!important; +} */ + +/* .paginate{ + color:#577b9b!important; +} */ + + + +/* span[data-testid="block-info"]{ + background:none !important; + color:#577b9b; + } */ + +/* 
Pseudo-element for the circularly cropped picture */ +/* .message.bot::before { + content: ''; + position: absolute; + top: -10px; + left: -10px; + width: 30px; + height: 30px; + background-image: var(--user-image); + background-size: cover; + background-position: center; + border-radius: 50%; + z-index: 10; + } + */ + +label.selected{ + background:none !important; +} + +#submit-button{ + padding:0px !important; +} + + +@media screen and (min-width: 1024px) { + div#tab-examples{ + height:calc(100vh - 190px) !important; + overflow-y: auto; + } + + div#sources-textbox{ + height:calc(100vh - 190px) !important; + overflow-y: auto !important; + } + + div#tab-config{ + height:calc(100vh - 190px) !important; + overflow-y: auto !important; + } + + div#chatbot-row{ + height:calc(100vh - 90px) !important; + } + + div#chatbot{ + height:calc(100vh - 170px) !important; + } + + .max-height{ + height:calc(100vh - 90px) !important; + overflow-y: auto; + } + + /* .tabitem:nth-child(n+3) { + padding-top:30px; + padding-left:40px; + padding-right:40px; + } */ +} + +footer { + visibility: hidden; + display:none !important; +} + + +@media screen and (max-width: 767px) { + /* Your mobile-specific styles go here */ + + div#chatbot{ + height:500px !important; + } + + #submit-button{ + padding:0px !important; + min-width: 80px; + } + + /* This will hide all list items */ + div.tab-nav button { + display: none !important; + } + + /* This will show only the first list item */ + div.tab-nav button:first-child { + display: block !important; + } + + /* This will show only the first list item */ + div.tab-nav button:nth-child(2) { + display: block !important; + } + + #right-panel button{ + display: block !important; + } + + /* ... add other mobile-specific styles ... */ +} + + +body.dark .card{ + background-color: #374151; +} + +body.dark .card-content h2{ + color:#f4dbd3 !important; +} + +body.dark .card-footer { + background-color: #404652; +} + +body.dark .card-footer span { + color:white !important; +} + + +.doc-ref sup{ + color:#dc2626 !important; + margin-right:1px; +} + +.doc-ref_ori{ + color:#dc2626 !important; + margin-right:1px; +} + +.tabitem{ + border:none !important; +} + +.other-tabs > div{ + padding-left:40px; + padding-right:40px; + padding-top:10px; +} + +.gallery-item > div{ + white-space: normal !important; /* Allow the text to wrap */ + word-break: break-word !important; /* Break words to prevent overflow */ + overflow-wrap: break-word !important; /* Break long words if necessary */ + } + +span.chatbot > p > img{ + margin-top:40px !important; + max-height: none !important; + max-width: 80% !important; + border-radius:0px !important; +} + + +.chatbot-caption{ + font-size:11px; + font-style:italic; + color:#508094; +} + +.ai-generated{ + font-size:11px!important; + font-style:italic; + color:#73b8d4 !important; +} + +.card-image > .card-content{ + background-color:#f1f7fa !important; +} + + + +.tab-nav > button.selected{ + color:#4b8ec3; + font-weight:bold; + border:none; +} + +.tab-nav{ + border:none !important; +} + +#input-textbox > label > textarea{ + border-radius:40px; + padding-left:30px; + resize:none; +} + +#input-message > div{ + border:none; +} + +#dropdown-samples{ + /*! border:none !important; */ + /*! 
border-width:0px !important; */ + background:none !important; + +} + +#dropdown-samples > .container > .wrap{ + background-color:white; +} + + +#tab-examples > div > .form{ + border:none; + background:none !important; +} + +.a-doc-ref{ + text-decoration: none !important; +} + + + + + + + diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ddc12a212abbc735fe244784c2bfbb298c37b28d --- /dev/null +++ b/utils.py @@ -0,0 +1,12 @@ +import numpy as np +import random +import string +import uuid + + +def create_user_id(): + """Create user_id + str: String to id user + """ + user_id = str(uuid.uuid4()) + return user_id \ No newline at end of file