diff --git a/README.md b/README.md
index 4718020c6296e8b180cbcffaf5d6d0a179d7653b..66df3a6dfd3ed12f68ec86fa947cb05b0556f6c2 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,13 @@
---
-title: AVeriTeC API
-emoji: 🚀
-colorFrom: blue
-colorTo: gray
+title: AVeriTeC
+emoji: 🏆
+colorFrom: purple
+colorTo: red
sdk: gradio
-sdk_version: 4.38.1
+sdk_version: 4.37.2
app_file: app.py
pinned: false
+license: apache-2.0
---
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa0ac52430050f3a3d58b64c0de4fa45fc5bce0c
--- /dev/null
+++ b/app.py
@@ -0,0 +1,1368 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Created by zd302 at 08/07/2024
+
+import gradio as gr
+import tqdm
+import torch
+import numpy as np
+from time import sleep
+import threading
+import gc
+import os
+import json
+import pytorch_lightning as pl
+from urllib.parse import urlparse
+from accelerate import Accelerator
+
+from transformers import BartTokenizer, BartForConditionalGeneration
+from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
+from transformers import RobertaTokenizer, RobertaForSequenceClassification
+
+from rank_bm25 import BM25Okapi
+# import bm25s
+# import Stemmer # optional: for stemming
+from html2lines import url2lines
+from googleapiclient.discovery import build
+from averitec.models.DualEncoderModule import DualEncoderModule
+from averitec.models.SequenceClassificationModule import SequenceClassificationModule
+from averitec.models.JustificationGenerationModule import JustificationGenerationModule
+from averitec.data.sample_claims import CLAIMS_Type
+
+# ---------------------------------------------------------------------------
+# load .env
+from utils import create_user_id
+user_id = create_user_id()
+
+from datetime import datetime
+from azure.storage.fileshare import ShareServiceClient
+try:
+ from dotenv import load_dotenv
+ load_dotenv()
+except Exception as e:
+ pass
+
+account_url = os.environ["AZURE_ACCOUNT_URL"]
+credential = {
+ "account_key": os.environ['AZURE_ACCOUNT_KEY'],
+ "account_name": os.environ['AZURE_ACCOUNT_NAME']
+}
+
+file_share_name = "averitec"
+azure_service = ShareServiceClient(account_url=account_url, credential=credential)
+azure_share_client = azure_service.get_share_client(file_share_name)
+
+# ---------- Setting ----------
+import requests
+from bs4 import BeautifulSoup
+import wikipediaapi
+wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC (zd302@cam.ac.uk)', 'en')
+
+import nltk
+nltk.download('punkt')
+from nltk import pos_tag, word_tokenize, sent_tokenize
+
+import spacy
+os.system("python -m spacy download en_core_web_sm")
+nlp = spacy.load("en_core_web_sm")
+
+# ---------------------------------------------------------------------------
+# Load sample dict for AVeriTeC search
+# all_samples_dict = json.load(open('averitec/data/all_samples.json', 'r'))
+
+# ---------------------------------------------------------------------------
+# ---------- Load pretrained models ----------
+# ---------- load Evidence retrieval model ----------
+# from drqa import retriever
+# db_class = retriever.get_class('sqlite')
+# doc_db = db_class("averitec/data/wikipedia_dumps/enwiki.db")
+# ranker = retriever.get_class('tfidf')(tfidf_path="averitec/data/wikipedia_dumps/enwiki-tfidf-with-id-title.npz")
+
+# ---------- Load Veracity and Justification prediction model ----------
+print("Loading models ...")
+LABEL = [
+ "Supported",
+ "Refuted",
+ "Not Enough Evidence",
+ "Conflicting Evidence/Cherrypicking",
+]
+# Veracity
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
+veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
+ tokenizer=veracity_tokenizer, model=bert_model).to(device)
+# Justification
+justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
+bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
+best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
+justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
+# ---------------------------------------------------------------------------
+
+
+# Set up Gradio Theme
+theme = gr.themes.Base(
+ primary_hue="blue",
+ secondary_hue="red",
+ font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
+)
+
+# ---------- Setting ----------
+
+class Docs:
+ def __init__(self, metadata=dict(), page_content=""):
+ self.metadata = metadata
+ self.page_content = page_content
+
+
+def make_html_source(source, i):
+ meta = source.metadata
+ content = source.page_content.strip()
+
+    # Minimal HTML card (assumed layout) for the "Sources" tab: evidence label,
+    # text snippet, and a link to the source page.
+    card = f"""
+    <div class="card" id="doc{i}">
+        <h2>{meta['short_name']}: {meta['name']}</h2>
+        <p>{content}</p>
+        <a href="{meta['url']}" target="_blank">Source</a>
+    </div>
+    """
+
+    return card
+
+
+# ----- veracity_prediction -----
+class SequenceClassificationDataLoader(pl.LightningDataModule):
+ def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.data_file = data_file
+ self.batch_size = batch_size
+ self.add_extra_nee = add_extra_nee
+
+ def tokenize_strings(
+ self,
+ source_sentences,
+ max_length=400,
+ pad_to_max_length=False,
+ return_tensors="pt",
+ ):
+ encoded_dict = self.tokenizer(
+ source_sentences,
+ max_length=max_length,
+ padding="max_length" if pad_to_max_length else "longest",
+ truncation=True,
+ return_tensors=return_tensors,
+ )
+
+ input_ids = encoded_dict["input_ids"]
+ attention_masks = encoded_dict["attention_mask"]
+
+ return input_ids, attention_masks
+
+ def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
+ if bool_explanation is not None and len(bool_explanation) > 0:
+ bool_explanation = ", because " + bool_explanation.lower().strip()
+ else:
+ bool_explanation = ""
+ return (
+ "[CLAIM] "
+ + claim.strip()
+ + " [QUESTION] "
+ + question.strip()
+ + " "
+ + answer.strip()
+ + bool_explanation
+ )
+
+
+def averitec_veracity_prediction(claim, qa_evidence):
+ bert_model_name = "bert-base-uncased"
+ tokenizer = BertTokenizer.from_pretrained(bert_model_name)
+ bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4,
+ problem_type="single_label_classification")
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ trained_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
+ tokenizer=tokenizer, model=bert_model).to(device)
+
+ dataLoader = SequenceClassificationDataLoader(
+ tokenizer=tokenizer,
+ data_file="this_is_discontinued",
+ batch_size=32,
+ add_extra_nee=False,
+ )
+
+ evidence_strings = []
+ for evidence in qa_evidence:
+ evidence_strings.append(
+ dataLoader.quadruple_to_string(claim, evidence.metadata["query"], evidence.metadata["answer"], ""))
+
+ if len(evidence_strings) == 0: # If we found no evidence e.g. because google returned 0 pages, just output NEI.
+ pred_label = "Not Enough Evidence"
+ return pred_label
+
+ tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
+ example_support = torch.argmax(
+ trained_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
+
+ has_unanswerable = False
+ has_true = False
+ has_false = False
+
+ for v in example_support:
+ if v == 0:
+ has_true = True
+ if v == 1:
+ has_false = True
+        if v in (2, 3,):  # TODO another hack -- we can't have different labels for train and test, so we do this
+ has_unanswerable = True
+
+ if has_unanswerable:
+ answer = 2
+ elif has_true and not has_false:
+ answer = 0
+ elif not has_true and has_false:
+ answer = 1
+ else:
+ answer = 3
+
+ pred_label = LABEL[answer]
+
+ return pred_label
+
+
+def fever_veracity_prediction(claim, evidence):
+ tokenizer = RobertaTokenizer.from_pretrained('Dzeniks/roberta-fact-check')
+ model = RobertaForSequenceClassification.from_pretrained('Dzeniks/roberta-fact-check')
+
+ evidence_string = ""
+ for evi in evidence:
+ evidence_string += evi.metadata['title'] + evi.metadata['evidence'] + ' '
+
+ input_sequence = tokenizer.encode_plus(claim, evidence_string, return_tensors="pt")
+ with torch.no_grad():
+ prediction = model(**input_sequence)
+
+ label = torch.argmax(prediction[0]).item()
+ pred_label = LABEL[label]
+
+ return pred_label
+
+
+def veracity_prediction(claim, qa_evidence):
+ # bert_model_name = "bert-base-uncased"
+ # tokenizer = BertTokenizer.from_pretrained(bert_model_name)
+ # bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4,
+ # problem_type="single_label_classification")
+ #
+ # device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ # trained_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
+ # tokenizer=tokenizer, model=bert_model).to(device)
+
+ dataLoader = SequenceClassificationDataLoader(
+ tokenizer=veracity_tokenizer,
+ data_file="this_is_discontinued",
+ batch_size=32,
+ add_extra_nee=False,
+ )
+
+ evidence_strings = []
+ for evidence in qa_evidence:
+ evidence_strings.append(
+ dataLoader.quadruple_to_string(claim, evidence.metadata["query"], evidence.metadata["answer"], ""))
+
+ if len(evidence_strings) == 0: # If we found no evidence e.g. because google returned 0 pages, just output NEI.
+ pred_label = "Not Enough Evidence"
+ return pred_label
+
+ tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
+ example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
+
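+    # Aggregate the per-evidence predictions into one claim-level verdict:
+    # any NEI-style prediction wins first, then all-supporting -> Supported,
+    # all-refuting -> Refuted, and a mix of both -> Conflicting/Cherrypicking.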
+ has_unanswerable = False
+ has_true = False
+ has_false = False
+
+ for v in example_support:
+ if v == 0:
+ has_true = True
+ if v == 1:
+ has_false = True
+        if v in (2, 3,):  # TODO another hack -- we can't have different labels for train and test, so we do this
+ has_unanswerable = True
+
+ if has_unanswerable:
+ answer = 2
+ elif has_true and not has_false:
+ answer = 0
+ elif not has_true and has_false:
+ answer = 1
+ else:
+ answer = 3
+
+ pred_label = LABEL[answer]
+
+ return pred_label
+
+
+def extract_claim_str(claim, qa_evidence, verdict_label):
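+    # Linearise the claim, its QA evidence, and the predicted verdict into the
+    # "[CLAIM] ... [EVIDENCE] question? answer. ... [VERDICT] label" string
+    # expected by the BART justification model.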
+ claim_str = "[CLAIM] " + claim + " [EVIDENCE] "
+
+ for evidence in qa_evidence:
+ q_text = evidence.metadata['query'].strip()
+
+ if len(q_text) == 0:
+ continue
+
+ if not q_text[-1] == "?":
+ q_text += "?"
+
+ answer_strings = []
+ answer_strings.append(evidence.metadata['answer'])
+
+ claim_str += q_text
+ for a_text in answer_strings:
+ if a_text:
+ if not a_text[-1] == ".":
+ a_text += "."
+ claim_str += " " + a_text.strip()
+
+ claim_str += " "
+
+ claim_str += " [VERDICT] " + verdict_label
+
+ return claim_str
+
+
+def averitec_justification_generation(claim, qa_evidence, verdict_label):
+ #
+ claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
+    claim_str = claim_str.strip()
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
+ bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
+
+ best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
+ trained_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer,
+ model=bart_model).to(device)
+
+ pred_justification = trained_model.generate(claim_str, device=device)
+
+ return pred_justification.strip()
+
+
+def justification_generation(claim, qa_evidence, verdict_label):
+ #
+ claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
+    claim_str = claim_str.strip()
+
+ # device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ # tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
+ # bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
+ #
+ # best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
+ # trained_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer,
+ # model=bart_model).to(device)
+
+ pred_justification = justification_model.generate(claim_str, device=device)
+
+ return pred_justification.strip()
+
+
+def QAprediction(claim, evidence, sources):
+ parts = []
+ #
+    evidence_title = f"""Retrieved Evidence:
+"""
+ for i, evi in enumerate(evidence, 1):
+ part = f"""Doc {i}"""
+ subpart = f"""{i}"""
+ # subpart = f"""{i}"""
+ subparts = "".join([part, subpart])
+ parts.append(subparts)
+
+ evidence_part = ", ".join(parts)
+
+    prediction_title = f"""Prediction:
+"""
+ # if 'Google' in sources or 'AVeriTeC' in sources:
+ # verdict_label = averitec_veracity_prediction(claim, evidence)
+ # justification_label = averitec_justification_generation(claim, evidence, verdict_label)
+ # # justification_label = "See retrieved docs."
+ # justification_part = f"""Justification: {justification_label}"""
+ # if 'WikiPedia' in sources:
+ # # verdict_label = fever_veracity_prediction(claim, evidence)
+ # justification_label = averitec_justification_generation(claim, evidence, verdict_label)
+ # # justification_label = "See retrieved docs."
+ # justification_part = f"""Justification: {justification_label}"""
+
+ verdict_label = veracity_prediction(claim, evidence)
+ justification_label = justification_generation(claim, evidence, verdict_label)
+ # justification_label = "See retrieved docs."
+ justification_part = f"""Justification: {justification_label}"""
+
+
+    verdict_part = f"""Verdict: {verdict_label}.
+"""
+
+ content_parts = "".join([evidence_title, evidence_part, prediction_title, verdict_part, justification_part])
+ # content_parts = "".join([evidence_title, evidence_part, verdict_title, verdict_part, justification_title, justification_part])
+
+ return content_parts, [verdict_label, justification_label]
+
+
+# ----------GoogleAPIretriever---------
+def generate_reference_corpus(reference_file):
+ with open(reference_file) as f:
+ j = json.load(f)
+ train_examples = j
+
+ all_data_corpus = []
+ tokenized_corpus = []
+
+ for train_example in train_examples:
+ train_claim = train_example["claim"]
+
+ speaker = train_example["speaker"].strip() if train_example["speaker"] is not None and len(
+ train_example["speaker"]) > 1 else "they"
+
+ questions = [q["question"] for q in train_example["questions"]]
+
+ claim_dict_builder = {}
+ claim_dict_builder["claim"] = train_claim
+ claim_dict_builder["speaker"] = speaker
+ claim_dict_builder["questions"] = questions
+
+ tokenized_corpus.append(nltk.word_tokenize(claim_dict_builder["claim"]))
+ all_data_corpus.append(claim_dict_builder)
+
+ return tokenized_corpus, all_data_corpus
+
+
+def doc2prompt(doc):
+ prompt_parts = "Outrageously, " + doc["speaker"] + " claimed that \"" + doc[
+ "claim"].strip() + "\". Criticism includes questions like: "
+ questions = [q.strip() for q in doc["questions"]]
+ return prompt_parts + " ".join(questions)
+
+
+def docs2prompt(top_docs):
+ return "\n\n".join([doc2prompt(d) for d in top_docs])
+
+
+def prompt_question_generation(test_claim, speaker="they", topk=10):
+ #
+ reference_file = "averitec_code/data/train.json"
+ tokenized_corpus, all_data_corpus = generate_reference_corpus(reference_file)
+ bm25 = BM25Okapi(tokenized_corpus)
+
+ # Define the bloom model:
+ accelerator = Accelerator()
+ accel_device = accelerator.device
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
+ model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
+
+ # --------------------------------------------------
+ # test claim
+ s = bm25.get_scores(nltk.word_tokenize(test_claim))
+ top_n = np.argsort(s)[::-1][:topk]
+ docs = [all_data_corpus[i] for i in top_n]
+ # --------------------------------------------------
+
+ prompt = docs2prompt(docs) + "\n\n" + "Outrageously, " + speaker + " claimed that \"" + test_claim.strip() + \
+ "\". Criticism includes questions like: "
+ sentences = [prompt]
+
+ inputs = tokenizer(sentences, padding=True, return_tensors="pt").to(device)
+ outputs = model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2,
+ early_stopping=True)
+
+ tgt_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+ in_len = len(sentences[0])
+ questions_str = tgt_text[in_len:].split("\n")[0]
+
+ qs = questions_str.split("?")
+ qs = [q.strip() + "?" for q in qs if q.strip() and len(q.strip()) < 300]
+
+ #
+ generate_question = [{"question": q, "answers": []} for q in qs]
+
+ return generate_question
+
+
+def check_claim_date(check_date):
+ try:
+ year, month, date = check_date.split("-")
+ except:
+ month, date, year = "01", "01", "2022"
+
+ if len(year) == 2 and int(year) <= 30:
+ year = "20" + year
+ elif len(year) == 2:
+ year = "19" + year
+ elif len(year) == 1:
+ year = "200" + year
+
+ if len(month) == 1:
+ month = "0" + month
+
+ if len(date) == 1:
+ date = "0" + date
+
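+    # Collapse the normalised date into the YYYYMMDD form used later as the
+    # upper bound of Google's date-restricted search window.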
+ sort_date = year + month + date
+
+ return sort_date
+
+
+def string_to_search_query(text, author):
+ parts = word_tokenize(text.strip())
+ tags = pos_tag(parts)
+
+ keep_tags = ["CD", "JJ", "NN", "VB"]
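+    # Keep only content-bearing tokens: cardinal numbers (CD), adjectives (JJ),
+    # nouns (NN*) and verbs (VB*); other parts of speech are dropped from the query.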
+
+ if author is not None:
+ search_string = author.split()
+ else:
+ search_string = []
+
+ for token, tag in zip(parts, tags):
+ for keep_tag in keep_tags:
+ if tag[1].startswith(keep_tag):
+ search_string.append(token)
+
+ search_string = " ".join(search_string)
+ return search_string
+
+
+def google_search(search_term, api_key, cse_id, **kwargs):
+ service = build("customsearch", "v1", developerKey=api_key)
+ res = service.cse().list(q=search_term, cx=cse_id, **kwargs).execute()
+
+ if "items" in res:
+ return res['items']
+ else:
+ return []
+
+
+def get_domain_name(url):
+ if '://' not in url:
+ url = 'http://' + url
+
+ domain = urlparse(url).netloc
+
+ if domain.startswith("www."):
+ return domain[4:]
+ else:
+ return domain
+
+
+def get_and_store(url_link, fp, worker, worker_stack):
+ page_lines = url2lines(url_link)
+
+ with open(fp, "w") as out_f:
+ print("\n".join([url_link] + page_lines), file=out_f)
+
+ worker_stack.append(worker)
+ gc.collect()
+
+
+def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0):
+ search_results = []
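+    # Retry the Custom Search request up to 3 times, sleeping 3 seconds after a
+    # failure to ride out rate limits and transient network errors.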
+ for i in range(3):
+ try:
+ search_results += google_search(
+ search_string,
+ api_key,
+ search_engine_id,
+ num=10,
+ start=0 + 10 * page,
+ sort="date:r:19000101:" + sort_date,
+ dateRestrict=None,
+ gl="US"
+ )
+ break
+ except:
+ sleep(3)
+
+ return search_results
+
+
+def averitec_search(claim, generate_question, speaker="they", check_date="2024-01-01", n_pages=1): # n_pages=3
+ # default config
+ api_key = os.environ["GOOGLE_API_KEY"]
+ search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]
+
+ blacklist = [
+ "jstor.org", # Blacklisted because their pdfs are not labelled as such, and clog up the download
+ "facebook.com", # Blacklisted because only post titles can be scraped, but the scraper doesn't know this,
+ "ftp.cs.princeton.edu", # Blacklisted because it hosts many large NLP corpora that keep showing up
+ "nlp.cs.princeton.edu",
+ "huggingface.co"
+ ]
+
+ blacklist_files = [ # Blacklisted some NLP nonsense that crashes my machine with OOM errors
+ "/glove.",
+ "ftp://ftp.cs.princeton.edu/pub/cs226/autocomplete/words-333333.txt",
+ "https://web.mit.edu/adamrose/Public/googlelist",
+ ]
+
+ # save to folder
+ store_folder = "averitec_code/store/retrieved_docs"
+ #
+ index = 0
+ questions = [q["question"] for q in generate_question]
+
+ # check the date of the claim
+ sort_date = check_claim_date(check_date) # check_date="2022-01-01"
+
+ #
+ search_strings = []
+ search_types = []
+
+ search_string_2 = string_to_search_query(claim, None)
+ search_strings += [search_string_2, claim, ]
+ search_types += ["claim", "claim-noformat", ]
+
+ search_strings += questions
+ search_types += ["question" for _ in questions]
+
+ # start to search
+ search_results = []
+ visited = {}
+ store_counter = 0
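+    # A stack of 10 worker tokens bounds the number of concurrent scraping
+    # threads; get_and_store returns its token to the stack when it finishes.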
+ worker_stack = list(range(10))
+
+ retrieve_evidence = []
+
+ for this_search_string, this_search_type in zip(search_strings, search_types):
+ for page_num in range(n_pages):
+ search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
+ this_search_string, page=page_num)
+
+ for result in search_results:
+ link = str(result["link"])
+ domain = get_domain_name(link)
+
+ if domain in blacklist:
+ continue
+ broken = False
+ for b_file in blacklist_files:
+ if b_file in link:
+ broken = True
+ if broken:
+ continue
+ if link.endswith(".pdf") or link.endswith(".doc"):
+ continue
+
+ store_file_path = ""
+
+ if link in visited:
+ store_file_path = visited[link]
+ else:
+ store_counter += 1
+ store_file_path = store_folder + "/search_result_" + str(index) + "_" + str(
+ store_counter) + ".store"
+ visited[link] = store_file_path
+
+                while len(worker_stack) == 0:  # Wait for a worker to become available. Check every second.
+ sleep(1)
+
+ worker = worker_stack.pop()
+
+ t = threading.Thread(target=get_and_store, args=(link, store_file_path, worker, worker_stack))
+ t.start()
+
+ line = [str(index), claim, link, str(page_num), this_search_string, this_search_type, store_file_path]
+ retrieve_evidence.append(line)
+
+ return retrieve_evidence
+
+
+def claim2prompts(example):
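+    # For every gold QA pair in a training example, yield a (BM25 lookup string,
+    # prompt line) pair; the prompt lines become in-context examples for Bloom.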
+ claim = example["claim"]
+
+ # claim_str = "Claim: " + claim + "||Evidence: "
+ claim_str = "Evidence: "
+
+ for question in example["questions"]:
+ q_text = question["question"].strip()
+ if len(q_text) == 0:
+ continue
+
+ if not q_text[-1] == "?":
+ q_text += "?"
+
+ answer_strings = []
+
+ for a in question["answers"]:
+ if a["answer_type"] in ["Extractive", "Abstractive"]:
+ answer_strings.append(a["answer"])
+ if a["answer_type"] == "Boolean":
+ answer_strings.append(a["answer"] + ", because " + a["boolean_explanation"].lower().strip())
+
+ for a_text in answer_strings:
+ if not a_text[-1] in [".", "!", ":", "?"]:
+ a_text += "."
+
+ # prompt_lookup_str = claim + " " + a_text
+ prompt_lookup_str = a_text
+ this_q_claim_str = claim_str + " " + a_text.strip() + "||Question answered: " + q_text
+ yield (prompt_lookup_str, this_q_claim_str.replace("\n", " ").replace("||", "\n"))
+
+
+def generate_step2_reference_corpus(reference_file):
+ with open(reference_file) as f:
+ train_examples = json.load(f)
+
+ prompt_corpus = []
+ tokenized_corpus = []
+
+ for example in train_examples:
+ for lookup_str, prompt in claim2prompts(example):
+ entry = nltk.word_tokenize(lookup_str)
+ tokenized_corpus.append(entry)
+ prompt_corpus.append(prompt)
+
+ return tokenized_corpus, prompt_corpus
+
+
+def decorate_with_questions(claim, retrieve_evidence, top_k=10): # top_k=100
+ #
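+    # Rank the scraped passages against the claim with BM25, then prompt Bloom
+    # with the nearest training QA pairs to generate a question that each of the
+    # top passages answers.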
+ reference_file = "averitec_code/data/train.json"
+ tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
+ prompt_bm25 = BM25Okapi(tokenized_corpus)
+
+ # Define the bloom model:
+ accelerator = Accelerator()
+ accel_device = accelerator.device
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
+ model = BloomForCausalLM.from_pretrained(
+ "bigscience/bloom-7b1",
+ device_map="auto",
+ torch_dtype=torch.bfloat16,
+ offload_folder="./offload"
+ )
+
+ #
+ tokenized_corpus = []
+ all_data_corpus = []
+
+ for retri_evi in tqdm.tqdm(retrieve_evidence):
+ store_file = retri_evi[-1]
+
+ with open(store_file, 'r') as f:
+ first = True
+ for line in f:
+ line = line.strip()
+
+ if first:
+ first = False
+ location_url = line
+ continue
+
+ if len(line) > 3:
+ entry = nltk.word_tokenize(line)
+ if (location_url, line) not in all_data_corpus:
+ tokenized_corpus.append(entry)
+ all_data_corpus.append((location_url, line))
+
+ if len(tokenized_corpus) == 0:
+ print("")
+
+ bm25 = BM25Okapi(tokenized_corpus)
+ s = bm25.get_scores(nltk.word_tokenize(claim))
+ top_n = np.argsort(s)[::-1][:top_k]
+ docs = [all_data_corpus[i] for i in top_n]
+
+ generate_qa_pairs = []
+ # Then, generate questions for those top 50:
+ for doc in tqdm.tqdm(docs):
+ # prompt_lookup_str = example["claim"] + " " + doc[1]
+ prompt_lookup_str = doc[1]
+
+ prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str))
+ prompt_n = 10
+ prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
+ prompt_docs = [prompt_corpus[i] for i in prompt_top_n]
+
+ claim_prompt = "Evidence: " + doc[1].replace("\n", " ") + "\nQuestion answered: "
+ prompt = "\n\n".join(prompt_docs + [claim_prompt])
+ sentences = [prompt]
+
+ inputs = tokenizer(sentences, padding=True, return_tensors="pt").to(device)
+ outputs = model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2,
+ early_stopping=True)
+
+ tgt_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
+ # We are not allowed to generate more than 250 characters:
+ tgt_text = tgt_text[:250]
+
+ qa_pair = [tgt_text.strip().split("?")[0].replace("\n", " ") + "?", doc[1].replace("\n", " "), doc[0]]
+ generate_qa_pairs.append(qa_pair)
+
+ return generate_qa_pairs
+
+
+def triple_to_string(x):
+ return " ".join([item.strip() for item in x])
+
+
+def rerank_questions(claim, bm25_qas, topk=3):
+ #
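+    # Score each (claim, question, answer) triple with the fine-tuned BERT dual
+    # encoder and keep the top-k highest-scoring QA pairs as evidence.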
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+ bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2,
+ problem_type="single_label_classification") # Must specify single_label for some reason
+ best_checkpoint = "averitec_code/pretrained_models/bert_dual_encoder.ckpt"
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer, model=bert_model).to(
+ device)
+
+ #
+ strs_to_score = []
+ values = []
+
+ for question, answer, source in bm25_qas:
+ str_to_score = triple_to_string([claim, question, answer])
+
+ strs_to_score.append(str_to_score)
+ values.append([question, answer, source])
+
+ if len(bm25_qas) > 0:
+ encoded_dict = tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True,
+ return_tensors="pt").to(device)
+
+ input_ids = encoded_dict['input_ids']
+ attention_masks = encoded_dict['attention_mask']
+
+ scores = torch.softmax(trained_model(input_ids, attention_mask=attention_masks).logits, axis=-1)[:, 1]
+
+ top_n = torch.argsort(scores, descending=True)[:topk]
+ pass_through = [{"question": values[i][0], "answers": values[i][1], "source_url": values[i][2]} for i in top_n]
+ else:
+ pass_through = []
+
+ top3_qa_pairs = pass_through
+
+ return top3_qa_pairs
+
+
+def GoogleAPIretriever(query):
+ # ----- Generate QA pairs using AVeriTeC
+ top3_qa_pairs_path = "averitec_code/top3_qa_pairs1.json"
+ if not os.path.exists(top3_qa_pairs_path):
+ # step 1: generate questions for the query/claim using Bloom
+ generate_question = prompt_question_generation(query)
+ # step 2: retrieve evidence for the generated questions using Google API
+ retrieve_evidence = averitec_search(query, generate_question)
+ # step 3: generate QA pairs for each retrieved document
+ bm25_qa_pairs = decorate_with_questions(query, retrieve_evidence)
+ # step 4: rerank QA pairs
+ top3_qa_pairs = rerank_questions(query, bm25_qa_pairs)
+ else:
+ top3_qa_pairs = json.load(open(top3_qa_pairs_path, 'r'))
+
+ # Add score to metadata
+ results = []
+ for i, qa in enumerate(top3_qa_pairs):
+ metadata = dict()
+
+ metadata['name'] = qa['question']
+ metadata['url'] = qa['source_url']
+ metadata['cached_source_url'] = qa['source_url']
+ metadata['short_name'] = "Evidence {}".format(i + 1)
+ metadata['page_number'] = ""
+ metadata['query'] = qa['question']
+ metadata['answer'] = qa['answers']
+ metadata['page_content'] = "Question: " + qa['question'] + "
" + "Answer: " + qa['answers']
+ page_content = f"""{metadata['page_content']}"""
+ results.append((metadata, page_content))
+
+ return results
+
+
+# ----------GoogleAPIretriever---------
+
+# ----------Wikipediaretriever---------
+def bm25_retriever(query, corpus, topk=3):
+ bm25 = BM25Okapi(corpus)
+ #
+ query_tokens = word_tokenize(query)
+ scores = bm25.get_scores(query_tokens)
+ top_n = np.argsort(scores)[::-1][:topk]
+ top_n_scores = [scores[i] for i in top_n]
+
+ return top_n, top_n_scores
+
+
+def bm25s_retriever(query, corpus, topk=3):
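+    # Alternative retriever based on the `bm25s` library; requires the optional
+    # `bm25s` and `Stemmer` imports at the top of this file to be uncommented.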
+ # optional: create a stemmer
+ stemmer = Stemmer.Stemmer("english")
+ # Tokenize the corpus and only keep the ids (faster and saves memory)
+ corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
+ # Create the BM25 model and index the corpus
+ retriever = bm25s.BM25()
+ retriever.index(corpus_tokens)
+ # Query the corpus
+ query_tokens = bm25s.tokenize(query, stemmer=stemmer)
+ # Get top-k results as a tuple of (doc ids, scores). Both are arrays of shape (n_queries, k)
+ results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=topk)
+ top_n = [corpus.index(res) for res in results[0]]
+ return top_n, scores
+
+
+def find_evidence_from_wikipedia_dumps(claim):
+ #
+ doc = nlp(claim)
+ entities_in_claim = [str(ent).lower() for ent in doc.ents]
+ title2id = ranker.doc_dict[0]
+ wiki_intro, ent_list = [], []
+ for ent in entities_in_claim:
+ if ent in title2id.keys():
+ ids = title2id[ent]
+ introduction = doc_db.get_doc_intro(ids)
+ wiki_intro.append([ent, introduction])
+ # fulltext = doc_db.get_doc_text(ids)
+ # evidence.append([ent, fulltext])
+ ent_list.append(ent)
+
+ if len(wiki_intro) < 5:
+ evidence_tfidf = process_topk(claim, title2id, ent_list, k=5)
+ wiki_intro.extend(evidence_tfidf)
+
+ return wiki_intro, doc
+
+
+def relevant_sentence_retrieval(query, wiki_intro, k):
+ # 1. Create corpus here
+ corpus, sentences = [], []
+ titles = []
+ for i, (title, intro) in enumerate(wiki_intro):
+ sents_in_intro = sent_tokenize(intro)
+
+ for sent in sents_in_intro:
+ corpus.append(word_tokenize(sent))
+ sentences.append(sent)
+ titles.append(title)
+ #
+ # ----- BM25
+ bm25_top_n, bm25_top_n_scores = bm25_retriever(query, corpus, topk=k)
+ bm25_top_n_sents = [sentences[i] for i in bm25_top_n]
+ bm25_top_n_titles = [titles[i] for i in bm25_top_n]
+
+ # ----- BM25s
+ # bm25s_top_n, bm25s_top_n_scores = bm25s_retriever(query, sentences, topk=k) # corpus->sentences
+ # bm25s_top_n_sents = [sentences[i] for i in bm25s_top_n]
+ # bm25s_top_n_titles = [titles[i] for i in bm25s_top_n]
+
+ return bm25_top_n_sents, bm25_top_n_titles
+
+
+def process_topk(query, title2id, ent_list, k=1):
+ doc_names, doc_scores = ranker.closest_docs(query, k)
+ evidence_tfidf = []
+
+ for _name in doc_names:
+ if _name not in ent_list and len(ent_list) < 5:
+ ent_list.append(_name)
+ idx = title2id[_name]
+ introduction = doc_db.get_doc_intro(idx)
+ evidence_tfidf.append([_name, introduction])
+ # fulltext = doc_db.get_doc_text(idx)
+ # evidence_tfidf.append([_name,fulltext])
+
+ return evidence_tfidf
+
+
+def WikipediaDumpsretriever(claim):
+ #
+ # 1. extract relevant wikipedia pages from wikipedia dumps
+ wiki_intro, doc = find_evidence_from_wikipedia_dumps(claim)
+ # wiki_intro = [['trump', "'''Trump''' most commonly refers to:\n* Donald Trump (born 1946), President of the United States from 2017 to 2021 \n* Trump (card games), any playing card given an ad-hoc high rank\n\n'''Trump''' may also refer to:"]]
+
+ # 2. extract relevant sentences from extracted wikipedia pages
+ sents, titles = relevant_sentence_retrieval(claim, wiki_intro, k=3)
+
+ #
+ results = []
+ for i, (sent, title) in enumerate(zip(sents, titles)):
+ metadata = dict()
+ metadata['name'] = claim
+ metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
+ metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
+ metadata['short_name'] = "Evidence {}".format(i + 1)
+ metadata['page_number'] = ""
+ metadata['query'] = sent
+ metadata['title'] = title
+ metadata['evidence'] = sent
+ metadata['answer'] = ""
+ metadata['page_content'] = "Title: " + str(metadata['title']) + "
" + "Evidence: " + metadata[
+ 'evidence']
+ page_content = f"""{metadata['page_content']}"""
+
+ results.append(Docs(metadata, page_content))
+
+ return results
+
+# ----------WikipediaAPIretriever---------
+def clean_str(p):
+ return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")
+
+
+def get_page_obs(page):
+ # find all paragraphs
+ paragraphs = page.split("\n")
+ paragraphs = [p.strip() for p in paragraphs if p.strip()]
+
+ # # find all sentence
+ # sentences = []
+ # for p in paragraphs:
+ # sentences += p.split('. ')
+ # sentences = [s.strip() + '.' for s in sentences if s.strip()]
+ # # return ' '.join(sentences[:5])
+ # return ' '.join(sentences)
+
+ return ' '.join(paragraphs[:5])
+
+
+def search_entity_wikipeida(entity):
+ find_evidence = []
+
+ page_py = wiki_wiki.page(entity)
+ if page_py.exists():
+ introduction = page_py.summary
+
+ find_evidence.append([str(entity), introduction])
+
+ return find_evidence
+
+
+def search_step(entity):
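+    # Query Wikipedia's search page for the entity. If only a list of search
+    # results comes back, recurse into the top matches; otherwise scrape the
+    # page's first paragraphs as evidence.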
+ ent_ = entity.replace(" ", "+")
+ search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}"
+ response_text = requests.get(search_url).text
+ soup = BeautifulSoup(response_text, features="html.parser")
+ result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
+
+ find_evidence = []
+
+ if result_divs: # mismatch
+        # If the Wikipedia page for the entity does not exist, fall back to similar Wikipedia pages.
+ result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
+ similar_titles = result_titles[:5]
+
+ for _t in similar_titles:
+ if len(find_evidence) < 5:
+ _evi = search_step(_t)
+ find_evidence.extend(_evi)
+ else:
+ page = [p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")]
+ if any("may refer to:" in p for p in page):
+ _evi = search_step("[" + entity + "]")
+ find_evidence.extend(_evi)
+ else:
+ # page_py = wiki_wiki.page(entity)
+ #
+ # if page_py.exists():
+ # introduction = page_py.summary
+ # else:
+ page_text = ""
+ for p in page:
+ if len(p.split(" ")) > 2:
+ page_text += clean_str(p)
+ if not p.endswith("\n"):
+ page_text += "\n"
+ introduction = get_page_obs(page_text)
+
+ find_evidence.append([entity, introduction])
+
+ return find_evidence
+
+
+def find_similar_wikipedia(entity, relevant_wikipages):
+    # If fewer than 5 relevant Wikipedia pages were found for the entity, add similar Wikipedia pages.
+ ent_ = entity.replace(" ", "+")
+ search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}&title=Special:Search&profile=advanced&fulltext=1&ns0=1"
+ response_text = requests.get(search_url).text
+ soup = BeautifulSoup(response_text, features="html.parser")
+ result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
+
+ if result_divs:
+ result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
+ similar_titles = result_titles[:5]
+
+ saved_titles = [ent[0] for ent in relevant_wikipages] if relevant_wikipages else relevant_wikipages
+ for _t in similar_titles:
+ if _t not in saved_titles and len(relevant_wikipages) < 5:
+ _evi = search_entity_wikipeida(_t)
+ # _evi = search_step(_t)
+ relevant_wikipages.extend(_evi)
+
+ return relevant_wikipages
+
+
+def find_evidence_from_wikipedia(claim):
+ #
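+    # Use spaCy NER to extract entities from the claim, then collect up to 5
+    # Wikipedia pages per entity (exact page first, similar pages as fallback).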
+ doc = nlp(claim)
+ #
+ wikipedia_page = []
+ for ent in doc.ents:
+ relevant_wikipages = search_entity_wikipeida(ent)
+
+ if len(relevant_wikipages) < 5:
+ relevant_wikipages = find_similar_wikipedia(str(ent), relevant_wikipages)
+
+ wikipedia_page.extend(relevant_wikipages)
+
+ return wikipedia_page
+
+
+def relevant_wikipedia_API_retriever(claim):
+ #
+ doc = nlp(claim)
+
+ wiki_intro = []
+ for ent in doc.ents:
+ page_py = wiki_wiki.page(ent)
+
+ if page_py.exists():
+ introduction = page_py.summary
+ else:
+ introduction = "No documents found."
+
+ wiki_intro.append([str(ent), introduction])
+
+ return wiki_intro, doc
+
+
+def Wikipediaretriever(claim, sources):
+ #
+ # 1. extract relevant wikipedia pages from wikipedia dumps
+ if "Dump" in sources:
+ wikipedia_page = find_evidence_from_wikipedia_dumps(claim)
+ else:
+ wikipedia_page = find_evidence_from_wikipedia(claim)
+ # wiki_intro, doc = relevant_wikipedia_API_retriever(claim)
+
+ # 2. extract relevant sentences from extracted wikipedia pages
+ sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)
+
+ #
+ results = []
+ for i, (sent, title) in enumerate(zip(sents, titles)):
+ metadata = dict()
+ metadata['name'] = claim
+ metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
+        metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
+ metadata['short_name'] = "Evidence {}".format(i + 1)
+ metadata['page_number'] = ""
+ metadata['query'] = sent
+ metadata['title'] = title
+ metadata['evidence'] = sent
+ metadata['answer'] = ""
+ metadata['page_content'] = "Title: " + str(metadata['title']) + "
" + "Evidence: " + metadata['evidence']
+ page_content = f"""{metadata['page_content']}"""
+
+ results.append(Docs(metadata, page_content))
+
+ return results
+
+
+def log_on_azure(file, logs, azure_share_client):
+ logs = json.dumps(logs)
+ file_client = azure_share_client.get_file_client(file)
+ file_client.upload_file(logs)
+
+
+def chat(claim, history, sources):
+ evidence = []
+ # if 'Google' in sources:
+ # evidence = GoogleAPIretriever(query)
+
+ # if 'WikiPediaDumps' in sources:
+ # evidence = WikipediaDumpsretriever(query)
+
+ if 'WikiPedia' in sources:
+ evidence = Wikipediaretriever(claim, sources)
+
+ answer_set, answer_output = QAprediction(claim, evidence, sources)
+
+ docs_html = ""
+ if len(evidence) > 0:
+ docs_html = []
+ for i, evi in enumerate(evidence, 1):
+ docs_html.append(make_html_source(evi, i))
+ docs_html = "".join(docs_html)
+ else:
+ print("No documents found")
+
+ url_of_evidence = ""
+ output_language = "English"
+ output_query = claim
+ history[-1] = (claim, answer_set)
+ history = [tuple(x) for x in history]
+
+ ############################################################
+ evi_list = []
+ for evi in evidence:
+ title_str = evi.metadata['title']
+ evi_str = evi.metadata['evidence']
+ evi_list.append([title_str, evi_str])
+
+ try:
+            # Log the answer on Azure file share storage.
+            # If AZURE_ISSAVE=TRUE, save the logs via the Azure share client.
+            if os.environ.get("AZURE_ISSAVE", "").upper() == "TRUE":
+ timestamp = str(datetime.now().timestamp())
+ # timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ file = timestamp + ".json"
+ logs = {
+ "user_id": str(user_id),
+ "claim": claim,
+ "sources": sources,
+ "evidence": evi_list,
+ "url": url_of_evidence,
+ "answer": answer_output,
+ "time": timestamp,
+ }
+ log_on_azure(file, logs, azure_share_client)
+ except Exception as e:
+ print(f"Error logging on Azure Blob Storage: {e}")
+ raise gr.Error(
+ f"AVeriTeC Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
+ ##########
+
+ return history, docs_html, output_query, output_language
+
+
+def main():
+ init_prompt = """
+ Hello, I am a fact-checking assistant designed to help you find appropriate evidence to predict the veracity of claims.
+
+ What do you want to fact-check?
+ """
+
+ with gr.Blocks(title="AVeriTeC fact-checker", css="style.css", theme=theme, elem_id="main-component") as demo:
+ with gr.Tab("AVeriTeC"):
+ with gr.Row(elem_id="chatbot-row"):
+ with gr.Column(scale=2):
+ chatbot = gr.Chatbot(
+ value=[(None, init_prompt)],
+ show_copy_button=True, show_label=False, elem_id="chatbot", layout="panel",
+ avatar_images=(None, "assets/averitec.png")
+ ) # avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"),
+
+ with gr.Row(elem_id="input-message"):
+                        textbox = gr.Textbox(placeholder="Tell me what claim you want to check!", show_label=False,
+ scale=7, lines=1, interactive=True, elem_id="input-textbox")
+ # submit = gr.Button("",elem_id = "submit-button",scale = 1,interactive = True,icon = "https://static-00.iconduck.com/assets.00/settings-icon-2048x2046-cw28eevx.png")
+
+ with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
+ with gr.Tabs() as tabs:
+ with gr.TabItem("Examples", elem_id="tab-examples", id=0):
+ examples_hidden = gr.Textbox(visible=False)
+ first_key = list(CLAIMS_Type.keys())[0]
+ dropdown_samples = gr.Dropdown(CLAIMS_Type.keys(), value=first_key, interactive=True,
+ show_label=True,
+ label="Select claim type",
+ elem_id="dropdown-samples")
+
+ samples = []
+ for i, key in enumerate(CLAIMS_Type.keys()):
+ examples_visible = True if i == 0 else False
+
+ with gr.Row(visible=examples_visible) as group_examples:
+ examples_questions = gr.Examples(
+ CLAIMS_Type[key],
+ [examples_hidden],
+ examples_per_page=8,
+ run_on_click=False,
+ elem_id=f"examples{i}",
+ api_name=f"examples{i}",
+ # label = "Click on the example question or enter your own",
+ # cache_examples=True,
+ )
+
+ samples.append(group_examples)
+
+ with gr.Tab("Sources", elem_id="tab-citations", id=1):
+ sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
+ docs_textbox = gr.State("")
+
+ with gr.Tab("Configuration", elem_id="tab-config", id=2):
+ gr.Markdown("Reminder: We currently only support fact-checking in English!")
+
+ # dropdown_sources = gr.Radio(
+ # ["AVeriTeC", "WikiPediaDumps", "Google", "WikiPediaAPI"],
+ # label="Select source",
+ # value="WikiPediaAPI",
+ # interactive=True,
+ # )
+
+ dropdown_sources = gr.Radio(
+ ["Google", "WikiPedia"],
+ label="Select source",
+ value="WikiPedia",
+ interactive=True,
+ )
+
+ dropdown_retriever = gr.Dropdown(
+ ["BM25", "BM25s"],
+ label="Select evidence retriever",
+ multiselect=False,
+ value="BM25",
+ interactive=True,
+ )
+
+ output_query = gr.Textbox(label="Query used for retrieval", show_label=True,
+ elem_id="reformulated-query", lines=2, interactive=False)
+ output_language = gr.Textbox(label="Language", show_label=True, elem_id="language", lines=1,
+ interactive=False)
+
+ with gr.Tab("About", elem_classes="max-height other-tabs"):
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown("See more info at [https://fever.ai/task.html](https://fever.ai/task.html)")
+
+ def start_chat(query, history):
+ history = history + [(query, None)]
+ history = [tuple(x) for x in history]
+ return (gr.update(interactive=False), gr.update(selected=1), history)
+
+ def finish_chat():
+ return (gr.update(interactive=True, value=""))
+
+ (textbox
+ .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
+ .then(chat, [textbox, chatbot, dropdown_sources],
+ [chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_textbox")
+ .then(finish_chat, None, [textbox], api_name="finish_chat_textbox")
+ )
+
+ (examples_hidden
+ .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False,
+ api_name="start_chat_examples")
+ .then(chat, [examples_hidden, chatbot, dropdown_sources],
+ [chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_examples")
+ .then(finish_chat, None, [textbox], api_name="finish_chat_examples")
+ )
+
+ def change_sample_questions(key):
+ index = list(CLAIMS_Type.keys()).index(key)
+ visible_bools = [False] * len(samples)
+ visible_bools[index] = True
+ return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
+
+ dropdown_samples.change(change_sample_questions, dropdown_samples, samples)
+ demo.queue()
+
+ demo.launch(share=True)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/drqa/__init__.py b/drqa/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d281e4860c5888761bac0bf815f1981a75e770b
--- /dev/null
+++ b/drqa/__init__.py
@@ -0,0 +1,23 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import sys
+from pathlib import PosixPath
+
+if sys.version_info < (3, 5):
+ raise RuntimeError('DrQA supports Python 3.5 or higher.')
+
+DATA_DIR = (
+ os.getenv('DRQA_DATA') or
+ os.path.join(PosixPath(__file__).absolute().parents[1].as_posix(), 'data')
+)
+
+from . import tokenizers
+from . import reader
+from . import retriever
+from . import pipeline
diff --git a/drqa/__pycache__/__init__.cpython-38.pyc b/drqa/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8064258e6b8fa1ab0513dab86c64a68218acffcc
Binary files /dev/null and b/drqa/__pycache__/__init__.cpython-38.pyc differ
diff --git a/drqa/pipeline/__init__.py b/drqa/pipeline/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ddbf9d65b91676ff9e974d92005207c35997e29
--- /dev/null
+++ b/drqa/pipeline/__init__.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from ..tokenizers import CoreNLPTokenizer
+from ..retriever import TfidfDocRanker
+from ..retriever import DocDB
+from .. import DATA_DIR
+
+DEFAULTS = {
+ 'tokenizer': CoreNLPTokenizer,
+ 'ranker': TfidfDocRanker,
+ 'db': DocDB,
+ 'reader_model': os.path.join(DATA_DIR, 'reader/multitask.mdl'),
+}
+
+
+def set_default(key, value):
+ global DEFAULTS
+ DEFAULTS[key] = value
+
+
+from .drqa import DrQA
diff --git a/drqa/pipeline/__pycache__/__init__.cpython-38.pyc b/drqa/pipeline/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..57f4e5d10d8feabc451f2e799b14f1dfb806b6a8
Binary files /dev/null and b/drqa/pipeline/__pycache__/__init__.cpython-38.pyc differ
diff --git a/drqa/pipeline/__pycache__/drqa.cpython-38.pyc b/drqa/pipeline/__pycache__/drqa.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69bb524d30ce019d35b6bb097c0a0fa77b3bbf09
Binary files /dev/null and b/drqa/pipeline/__pycache__/drqa.cpython-38.pyc differ
diff --git a/drqa/pipeline/drqa.py b/drqa/pipeline/drqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..7830e9ee9a46a9d8f0f196c2a269becf1b158235
--- /dev/null
+++ b/drqa/pipeline/drqa.py
@@ -0,0 +1,312 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Full DrQA pipeline."""
+
+import torch
+import regex
+import heapq
+import math
+import time
+import logging
+
+from multiprocessing import Pool as ProcessPool
+from multiprocessing.util import Finalize
+
+from ..reader.vector import batchify
+from ..reader.data import ReaderDataset, SortedBatchSampler
+from .. import reader
+from .. import tokenizers
+from . import DEFAULTS
+
+logger = logging.getLogger(__name__)
+
+
+# ------------------------------------------------------------------------------
+# Multiprocessing functions to fetch and tokenize text
+# ------------------------------------------------------------------------------
+
+PROCESS_TOK = None
+PROCESS_DB = None
+PROCESS_CANDS = None
+
+
+def init(tokenizer_class, tokenizer_opts, db_class, db_opts, candidates=None):
+ global PROCESS_TOK, PROCESS_DB, PROCESS_CANDS
+ PROCESS_TOK = tokenizer_class(**tokenizer_opts)
+ Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)
+ PROCESS_DB = db_class(**db_opts)
+ Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100)
+ PROCESS_CANDS = candidates
+
+
+def fetch_text(doc_id):
+ global PROCESS_DB
+ return PROCESS_DB.get_doc_text(doc_id)
+
+
+def tokenize_text(text):
+ global PROCESS_TOK
+ return PROCESS_TOK.tokenize(text)
+
+
+# ------------------------------------------------------------------------------
+# Main DrQA pipeline
+# ------------------------------------------------------------------------------
+
+
+class DrQA(object):
+ # Target size for squashing short paragraphs together.
+ # 0 = read every paragraph independently
+ # infty = read all paragraphs together
+ GROUP_LENGTH = 0
+
+ def __init__(
+ self,
+ reader_model=None,
+ embedding_file=None,
+ tokenizer=None,
+ fixed_candidates=None,
+ batch_size=128,
+ cuda=True,
+ data_parallel=False,
+ max_loaders=5,
+ num_workers=None,
+ db_config=None,
+ ranker_config=None
+ ):
+ """Initialize the pipeline.
+
+ Args:
+ reader_model: model file from which to load the DocReader.
+ embedding_file: if given, will expand DocReader dictionary to use
+ all available pretrained embeddings.
+ tokenizer: string option to specify tokenizer used on docs.
+            fixed_candidates: if given, all predictions will be constrained to
+                the set of candidates contained in the file. One entry per line.
+ batch_size: batch size when processing paragraphs.
+ cuda: whether to use the gpu.
+            data_parallel: whether to use multiple GPUs.
+ max_loaders: max number of async data loading workers when reading.
+ (default is fine).
+            num_workers: number of parallel CPU processes to use for tokenizing
+                and post-processing results.
+ db_config: config for doc db.
+ ranker_config: config for ranker.
+ """
+ self.batch_size = batch_size
+ self.max_loaders = max_loaders
+ self.fixed_candidates = fixed_candidates is not None
+ self.cuda = cuda
+
+ logger.info('Initializing document ranker...')
+ ranker_config = ranker_config or {}
+ ranker_class = ranker_config.get('class', DEFAULTS['ranker'])
+ ranker_opts = ranker_config.get('options', {})
+ self.ranker = ranker_class(**ranker_opts)
+
+ logger.info('Initializing document reader...')
+ reader_model = reader_model or DEFAULTS['reader_model']
+ self.reader = reader.DocReader.load(reader_model, normalize=False)
+ if embedding_file:
+ logger.info('Expanding dictionary...')
+ words = reader.utils.index_embedding_words(embedding_file)
+ added = self.reader.expand_dictionary(words)
+ self.reader.load_embeddings(added, embedding_file)
+ if cuda:
+ self.reader.cuda()
+ if data_parallel:
+ self.reader.parallelize()
+
+ if not tokenizer:
+ tok_class = DEFAULTS['tokenizer']
+ else:
+ tok_class = tokenizers.get_class(tokenizer)
+ annotators = tokenizers.get_annotators_for_model(self.reader)
+ tok_opts = {'annotators': annotators}
+
+ # ElasticSearch is also used as backend if used as ranker
+ if hasattr(self.ranker, 'es'):
+ db_config = ranker_config
+ db_class = ranker_class
+ db_opts = ranker_opts
+ else:
+ db_config = db_config or {}
+ db_class = db_config.get('class', DEFAULTS['db'])
+ db_opts = db_config.get('options', {})
+
+ logger.info('Initializing tokenizers and document retrievers...')
+ self.num_workers = num_workers
+ self.processes = ProcessPool(
+ num_workers,
+ initializer=init,
+ initargs=(tok_class, tok_opts, db_class, db_opts, fixed_candidates)
+ )
+
+ def _split_doc(self, doc):
+ """Given a doc, split it into chunks (by paragraph)."""
+ curr = []
+ curr_len = 0
+ for split in regex.split(r'\n+', doc):
+ split = split.strip()
+ if len(split) == 0:
+ continue
+ # Maybe group paragraphs together until we hit a length limit
+ if len(curr) > 0 and curr_len + len(split) > self.GROUP_LENGTH:
+ yield ' '.join(curr)
+ curr = []
+ curr_len = 0
+ curr.append(split)
+ curr_len += len(split)
+ if len(curr) > 0:
+ yield ' '.join(curr)
+
+ def _get_loader(self, data, num_loaders):
+ """Return a pytorch data iterator for provided examples."""
+ dataset = ReaderDataset(data, self.reader)
+ sampler = SortedBatchSampler(
+ dataset.lengths(),
+ self.batch_size,
+ shuffle=False
+ )
+ loader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=self.batch_size,
+ sampler=sampler,
+ num_workers=num_loaders,
+ collate_fn=batchify,
+ pin_memory=self.cuda,
+ )
+ return loader
+
+ def process(self, query, candidates=None, top_n=1, n_docs=5,
+ return_context=False):
+ """Run a single query."""
+ predictions = self.process_batch(
+ [query], [candidates] if candidates else None,
+ top_n, n_docs, return_context
+ )
+ return predictions[0]
+
+ def process_batch(self, queries, candidates=None, top_n=1, n_docs=5,
+ return_context=False):
+ """Run a batch of queries (more efficient)."""
+ t0 = time.time()
+ logger.info('Processing %d queries...' % len(queries))
+ logger.info('Retrieving top %d docs...' % n_docs)
+
+ # Rank documents for queries.
+ if len(queries) == 1:
+ ranked = [self.ranker.closest_docs(queries[0], k=n_docs)]
+ else:
+ ranked = self.ranker.batch_closest_docs(
+ queries, k=n_docs, num_workers=self.num_workers
+ )
+ all_docids, all_doc_scores = zip(*ranked)
+
+ # Flatten document ids and retrieve text from database.
+ # We remove duplicates for processing efficiency.
+ flat_docids = list({d for docids in all_docids for d in docids})
+ did2didx = {did: didx for didx, did in enumerate(flat_docids)}
+ doc_texts = self.processes.map(fetch_text, flat_docids)
+
+ # Split and flatten documents. Maintain a mapping from doc (index in
+ # flat list) to split (index in flat list).
+ flat_splits = []
+ didx2sidx = []
+ for text in doc_texts:
+ splits = self._split_doc(text)
+ didx2sidx.append([len(flat_splits), -1])
+ for split in splits:
+ flat_splits.append(split)
+ didx2sidx[-1][1] = len(flat_splits)
+
+ # Push through the tokenizers as fast as possible.
+ q_tokens = self.processes.map_async(tokenize_text, queries)
+ s_tokens = self.processes.map_async(tokenize_text, flat_splits)
+ q_tokens = q_tokens.get()
+ s_tokens = s_tokens.get()
+
+ # Group into structured example inputs. Examples' ids represent
+ # mappings to their question, document, and split ids.
+ examples = []
+ for qidx in range(len(queries)):
+ for rel_didx, did in enumerate(all_docids[qidx]):
+ start, end = didx2sidx[did2didx[did]]
+ for sidx in range(start, end):
+ if (len(q_tokens[qidx].words()) > 0 and
+ len(s_tokens[sidx].words()) > 0):
+ examples.append({
+ 'id': (qidx, rel_didx, sidx),
+ 'question': q_tokens[qidx].words(),
+ 'qlemma': q_tokens[qidx].lemmas(),
+ 'document': s_tokens[sidx].words(),
+ 'lemma': s_tokens[sidx].lemmas(),
+ 'pos': s_tokens[sidx].pos(),
+ 'ner': s_tokens[sidx].entities(),
+ })
+
+ logger.info('Reading %d paragraphs...' % len(examples))
+
+ # Push all examples through the document reader.
+        # We decode argmax start/end indices asynchronously on CPU.
+ result_handles = []
+ num_loaders = min(self.max_loaders, math.floor(len(examples) / 1e3))
+ for batch in self._get_loader(examples, num_loaders):
+ if candidates or self.fixed_candidates:
+ batch_cands = []
+ for ex_id in batch[-1]:
+ batch_cands.append({
+ 'input': s_tokens[ex_id[2]],
+ 'cands': candidates[ex_id[0]] if candidates else None
+ })
+ handle = self.reader.predict(
+ batch, batch_cands, async_pool=self.processes
+ )
+ else:
+ handle = self.reader.predict(batch, async_pool=self.processes)
+ result_handles.append((handle, batch[-1], batch[0].size(0)))
+
+ # Iterate through the predictions, and maintain priority queues for
+ # top scored answers for each question in the batch.
+ queues = [[] for _ in range(len(queries))]
+ for result, ex_ids, batch_size in result_handles:
+ s, e, score = result.get()
+ for i in range(batch_size):
+ # We take the top prediction per split.
+ if len(score[i]) > 0:
+ item = (score[i][0], ex_ids[i], s[i][0], e[i][0])
+ queue = queues[ex_ids[i][0]]
+ if len(queue) < top_n:
+ heapq.heappush(queue, item)
+ else:
+ heapq.heappushpop(queue, item)
+
+ # Arrange final top prediction data.
+ all_predictions = []
+ for queue in queues:
+ predictions = []
+ while len(queue) > 0:
+ score, (qidx, rel_didx, sidx), s, e = heapq.heappop(queue)
+ prediction = {
+ 'doc_id': all_docids[qidx][rel_didx],
+ 'span': s_tokens[sidx].slice(s, e + 1).untokenize(),
+ 'doc_score': float(all_doc_scores[qidx][rel_didx]),
+ 'span_score': float(score),
+ }
+ if return_context:
+ prediction['context'] = {
+ 'text': s_tokens[sidx].untokenize(),
+ 'start': s_tokens[sidx].offsets()[s][0],
+ 'end': s_tokens[sidx].offsets()[e][1],
+ }
+ predictions.append(prediction)
+ all_predictions.append(predictions[-1::-1])
+
+ logger.info('Processed %d queries in %.4f (s)' %
+ (len(queries), time.time() - t0))
+
+ return all_predictions
diff --git a/drqa/reader/__init__.py b/drqa/reader/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9c6cc94bacc74e65061800b8ee00fd4eb78ef56
--- /dev/null
+++ b/drqa/reader/__init__.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from ..tokenizers import CoreNLPTokenizer
+from .. import DATA_DIR
+
+
+DEFAULTS = {
+ 'tokenizer': CoreNLPTokenizer,
+ 'model': os.path.join(DATA_DIR, 'reader/single.mdl'),
+}
+
+
+def set_default(key, value):
+ global DEFAULTS
+ DEFAULTS[key] = value
+
+from .model import DocReader
+from .predictor import Predictor
+from . import config
+from . import vector
+from . import data
+from . import utils
diff --git a/drqa/reader/__pycache__/__init__.cpython-38.pyc b/drqa/reader/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7538e4dd229b00dab82f8a5a50c03b5e4d270685
Binary files /dev/null and b/drqa/reader/__pycache__/__init__.cpython-38.pyc differ
diff --git a/drqa/reader/__pycache__/config.cpython-38.pyc b/drqa/reader/__pycache__/config.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a9a1f0cae7bd1cfe8c72bb4d72e6af6df941147
Binary files /dev/null and b/drqa/reader/__pycache__/config.cpython-38.pyc differ
diff --git a/drqa/reader/__pycache__/data.cpython-38.pyc b/drqa/reader/__pycache__/data.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b72e5a48fd606cbc29aae9b7b341edc47e27c53
Binary files /dev/null and b/drqa/reader/__pycache__/data.cpython-38.pyc differ
diff --git a/drqa/reader/__pycache__/layers.cpython-38.pyc b/drqa/reader/__pycache__/layers.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..457836a1d2c2823c327952bbee8b64f944c05c73
Binary files /dev/null and b/drqa/reader/__pycache__/layers.cpython-38.pyc differ
diff --git a/drqa/reader/__pycache__/model.cpython-38.pyc b/drqa/reader/__pycache__/model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e486a981ffe2eb326ed03ec54358912ca0db51f9
Binary files /dev/null and b/drqa/reader/__pycache__/model.cpython-38.pyc differ
diff --git a/drqa/reader/__pycache__/predictor.cpython-38.pyc b/drqa/reader/__pycache__/predictor.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..757189fb0cddd9a58600c57421a5088943953752
Binary files /dev/null and b/drqa/reader/__pycache__/predictor.cpython-38.pyc differ
diff --git a/drqa/reader/__pycache__/rnn_reader.cpython-38.pyc b/drqa/reader/__pycache__/rnn_reader.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5377e66e076949b6ac6b5b01d70151789de24387
Binary files /dev/null and b/drqa/reader/__pycache__/rnn_reader.cpython-38.pyc differ
diff --git a/drqa/reader/__pycache__/utils.cpython-38.pyc b/drqa/reader/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf4748aa1b5c99b7d9f4483d0db9cb88c6d94b45
Binary files /dev/null and b/drqa/reader/__pycache__/utils.cpython-38.pyc differ
diff --git a/drqa/reader/__pycache__/vector.cpython-38.pyc b/drqa/reader/__pycache__/vector.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9e2bd6baafc0a3ad43e1b46e70f0c6aaa53d113
Binary files /dev/null and b/drqa/reader/__pycache__/vector.cpython-38.pyc differ
diff --git a/drqa/reader/config.py b/drqa/reader/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d49b191f99b93da233da5db0bd4b6287aabf6f9
--- /dev/null
+++ b/drqa/reader/config.py
@@ -0,0 +1,128 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Model architecture/optimization options for DrQA document reader."""
+
+import argparse
+import logging
+
+logger = logging.getLogger(__name__)
+
+# Index of arguments concerning the core model architecture
+MODEL_ARCHITECTURE = {
+ 'model_type', 'embedding_dim', 'hidden_size', 'doc_layers',
+ 'question_layers', 'rnn_type', 'concat_rnn_layers', 'question_merge',
+ 'use_qemb', 'use_in_question', 'use_pos', 'use_ner', 'use_lemma', 'use_tf'
+}
+
+# Index of arguments concerning the model optimizer/training
+MODEL_OPTIMIZER = {
+ 'fix_embeddings', 'optimizer', 'learning_rate', 'momentum', 'weight_decay',
+ 'rnn_padding', 'dropout_rnn', 'dropout_rnn_output', 'dropout_emb',
+ 'max_len', 'grad_clipping', 'tune_partial'
+}
+
+
+def str2bool(v):
+ return v.lower() in ('yes', 'true', 't', '1', 'y')
+
+
+def add_model_args(parser):
+ parser.register('type', 'bool', str2bool)
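+    # Registering a 'bool' type lets options declared with type='bool' parse
+    # command-line values like 'yes'/'no'/'true'/'1' through str2bool.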
+
+ # Model architecture
+ model = parser.add_argument_group('DrQA Reader Model Architecture')
+ model.add_argument('--model-type', type=str, default='rnn',
+ help='Model architecture type')
+ model.add_argument('--embedding-dim', type=int, default=300,
+ help='Embedding size if embedding_file is not given')
+ model.add_argument('--hidden-size', type=int, default=128,
+ help='Hidden size of RNN units')
+ model.add_argument('--doc-layers', type=int, default=3,
+ help='Number of encoding layers for document')
+ model.add_argument('--question-layers', type=int, default=3,
+ help='Number of encoding layers for question')
+ model.add_argument('--rnn-type', type=str, default='lstm',
+ help='RNN type: LSTM, GRU, or RNN')
+
+ # Model specific details
+ detail = parser.add_argument_group('DrQA Reader Model Details')
+ detail.add_argument('--concat-rnn-layers', type='bool', default=True,
+ help='Combine hidden states from each encoding layer')
+ detail.add_argument('--question-merge', type=str, default='self_attn',
+ help='The way of computing the question representation')
+ detail.add_argument('--use-qemb', type='bool', default=True,
+ help='Whether to use weighted question embeddings')
+ detail.add_argument('--use-in-question', type='bool', default=True,
+ help='Whether to use in_question_* features')
+ detail.add_argument('--use-pos', type='bool', default=True,
+ help='Whether to use pos features')
+ detail.add_argument('--use-ner', type='bool', default=True,
+ help='Whether to use ner features')
+ detail.add_argument('--use-lemma', type='bool', default=True,
+ help='Whether to use lemma features')
+ detail.add_argument('--use-tf', type='bool', default=True,
+ help='Whether to use term frequency features')
+
+ # Optimization details
+ optim = parser.add_argument_group('DrQA Reader Optimization')
+ optim.add_argument('--dropout-emb', type=float, default=0.4,
+ help='Dropout rate for word embeddings')
+ optim.add_argument('--dropout-rnn', type=float, default=0.4,
+ help='Dropout rate for RNN states')
+ optim.add_argument('--dropout-rnn-output', type='bool', default=True,
+ help='Whether to dropout the RNN output')
+ optim.add_argument('--optimizer', type=str, default='adamax',
+ help='Optimizer: sgd or adamax')
+ optim.add_argument('--learning-rate', type=float, default=0.1,
+ help='Learning rate for SGD only')
+ optim.add_argument('--grad-clipping', type=float, default=10,
+ help='Gradient clipping')
+ optim.add_argument('--weight-decay', type=float, default=0,
+ help='Weight decay factor')
+ optim.add_argument('--momentum', type=float, default=0,
+ help='Momentum factor')
+ optim.add_argument('--fix-embeddings', type='bool', default=True,
+ help='Keep word embeddings fixed (use pretrained)')
+ optim.add_argument('--tune-partial', type=int, default=0,
+ help='Backprop through only the top N question words')
+ optim.add_argument('--rnn-padding', type='bool', default=False,
+ help='Explicitly account for padding in RNN encoding')
+ optim.add_argument('--max-len', type=int, default=15,
+ help='The max span allowed during decoding')
+
+
+def get_model_args(args):
+ """Filter args for model ones.
+
+    From an args Namespace, return a new Namespace with *only* the args specific
+ to the model architecture or optimization. (i.e. the ones defined here.)
+ """
+ global MODEL_ARCHITECTURE, MODEL_OPTIMIZER
+ required_args = MODEL_ARCHITECTURE | MODEL_OPTIMIZER
+ arg_values = {k: v for k, v in vars(args).items() if k in required_args}
+ return argparse.Namespace(**arg_values)
+
+
+def override_model_args(old_args, new_args):
+ """Set args to new parameters.
+
+ Decide which model args to keep and which to override when resolving a set
+ of saved args and new args.
+
+    We keep the new optimization args, but leave the model architecture alone.
+ """
+ global MODEL_OPTIMIZER
+ old_args, new_args = vars(old_args), vars(new_args)
+ for k in old_args.keys():
+ if k in new_args and old_args[k] != new_args[k]:
+ if k in MODEL_OPTIMIZER:
+ logger.info('Overriding saved %s: %s --> %s' %
+ (k, old_args[k], new_args[k]))
+ old_args[k] = new_args[k]
+ else:
+ logger.info('Keeping saved %s: %s' % (k, old_args[k]))
+ return argparse.Namespace(**old_args)
diff --git a/drqa/reader/data.py b/drqa/reader/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9e0215814add331aec33c1ca757020765fbdef5
--- /dev/null
+++ b/drqa/reader/data.py
@@ -0,0 +1,131 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Data processing/loading helpers."""
+
+import numpy as np
+import logging
+import unicodedata
+
+from torch.utils.data import Dataset
+from torch.utils.data.sampler import Sampler
+from .vector import vectorize
+
+logger = logging.getLogger(__name__)
+
+
+# ------------------------------------------------------------------------------
+# Dictionary class for tokens.
+# ------------------------------------------------------------------------------
+
+
+class Dictionary(object):
+    NULL = '<NULL>'
+    UNK = '<UNK>'
+ START = 2
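+    # Indices 0 and 1 are reserved for <NULL> and <UNK>; regular tokens start at 2.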
+
+ @staticmethod
+ def normalize(token):
+ return unicodedata.normalize('NFD', token)
+
+ def __init__(self):
+ self.tok2ind = {self.NULL: 0, self.UNK: 1}
+ self.ind2tok = {0: self.NULL, 1: self.UNK}
+
+ def __len__(self):
+ return len(self.tok2ind)
+
+ def __iter__(self):
+ return iter(self.tok2ind)
+
+ def __contains__(self, key):
+ if type(key) == int:
+ return key in self.ind2tok
+ elif type(key) == str:
+ return self.normalize(key) in self.tok2ind
+
+ def __getitem__(self, key):
+ if type(key) == int:
+ return self.ind2tok.get(key, self.UNK)
+ if type(key) == str:
+ return self.tok2ind.get(self.normalize(key),
+ self.tok2ind.get(self.UNK))
+
+ def __setitem__(self, key, item):
+ if type(key) == int and type(item) == str:
+ self.ind2tok[key] = item
+ elif type(key) == str and type(item) == int:
+ self.tok2ind[key] = item
+ else:
+ raise RuntimeError('Invalid (key, item) types.')
+
+ def add(self, token):
+ token = self.normalize(token)
+ if token not in self.tok2ind:
+ index = len(self.tok2ind)
+ self.tok2ind[token] = index
+ self.ind2tok[index] = token
+
+ def tokens(self):
+ """Get dictionary tokens.
+
+ Return all the words indexed by this dictionary, except for special
+ tokens.
+ """
+        tokens = [k for k in self.tok2ind.keys()
+                  if k not in {'<NULL>', '<UNK>'}]
+ return tokens
+
+
+# ------------------------------------------------------------------------------
+# PyTorch dataset class for SQuAD (and SQuAD-like) data.
+# ------------------------------------------------------------------------------
+
+
+class ReaderDataset(Dataset):
+
+ def __init__(self, examples, model, single_answer=False):
+ self.model = model
+ self.examples = examples
+ self.single_answer = single_answer
+
+ def __len__(self):
+ return len(self.examples)
+
+ def __getitem__(self, index):
+ return vectorize(self.examples[index], self.model, self.single_answer)
+
+ def lengths(self):
+ return [(len(ex['document']), len(ex['question']))
+ for ex in self.examples]
+
+
+# ------------------------------------------------------------------------------
+# PyTorch sampler returning batches sorted by length (by doc, then question).
+# ------------------------------------------------------------------------------
+
+
+class SortedBatchSampler(Sampler):
+
+ def __init__(self, lengths, batch_size, shuffle=True):
+ self.lengths = lengths
+ self.batch_size = batch_size
+ self.shuffle = shuffle
+
+ def __iter__(self):
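+        # Sort by descending document length, then question length, with a
+        # random key to break ties, so each batch groups similarly sized
+        # examples and minimizes padding.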
+ lengths = np.array(
+ [(-l[0], -l[1], np.random.random()) for l in self.lengths],
+ dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float_)]
+ )
+ indices = np.argsort(lengths, order=('l1', 'l2', 'rand'))
+ batches = [indices[i:i + self.batch_size]
+ for i in range(0, len(indices), self.batch_size)]
+ if self.shuffle:
+ np.random.shuffle(batches)
+ return iter([i for batch in batches for i in batch])
+
+ def __len__(self):
+ return len(self.lengths)
diff --git a/drqa/reader/layers.py b/drqa/reader/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..71a967d806a04160ef6bce9741001c33eecca62e
--- /dev/null
+++ b/drqa/reader/layers.py
@@ -0,0 +1,311 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Definitions of model layers/NN modules"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# ------------------------------------------------------------------------------
+# Modules
+# ------------------------------------------------------------------------------
+
+
+class StackedBRNN(nn.Module):
+ """Stacked Bi-directional RNNs.
+
+ Differs from standard PyTorch library in that it has the option to save
+    and concat the hidden states between layers (i.e. the output hidden size
+    for each sequence input is num_layers * 2 * hidden_size).
+ """
+
+ def __init__(self, input_size, hidden_size, num_layers,
+ dropout_rate=0, dropout_output=False, rnn_type=nn.LSTM,
+ concat_layers=False, padding=False):
+ super(StackedBRNN, self).__init__()
+ self.padding = padding
+ self.dropout_output = dropout_output
+ self.dropout_rate = dropout_rate
+ self.num_layers = num_layers
+ self.concat_layers = concat_layers
+ self.rnns = nn.ModuleList()
+ for i in range(num_layers):
+ input_size = input_size if i == 0 else 2 * hidden_size
+ self.rnns.append(rnn_type(input_size, hidden_size,
+ num_layers=1,
+ bidirectional=True))
+
+ def forward(self, x, x_mask):
+ """Encode either padded or non-padded sequences.
+
+ Can choose to either handle or ignore variable length sequences.
+ Always handle padding in eval.
+
+ Args:
+ x: batch * len * hdim
+ x_mask: batch * len (1 for padding, 0 for true)
+ Output:
+ x_encoded: batch * len * hdim_encoded
+ """
+ if x_mask.data.sum() == 0:
+ # No padding necessary.
+ output = self._forward_unpadded(x, x_mask)
+ elif self.padding or not self.training:
+            # Pad if we care or if it's during eval.
+ output = self._forward_padded(x, x_mask)
+ else:
+ # We don't care.
+ output = self._forward_unpadded(x, x_mask)
+
+ return output.contiguous()
+
+ def _forward_unpadded(self, x, x_mask):
+ """Faster encoding that ignores any padding."""
+ # Transpose batch and sequence dims
+ x = x.transpose(0, 1)
+
+ # Encode all layers
+ outputs = [x]
+ for i in range(self.num_layers):
+ rnn_input = outputs[-1]
+
+ # Apply dropout to hidden input
+ if self.dropout_rate > 0:
+ rnn_input = F.dropout(rnn_input,
+ p=self.dropout_rate,
+ training=self.training)
+ # Forward
+ rnn_output = self.rnns[i](rnn_input)[0]
+ outputs.append(rnn_output)
+
+ # Concat hidden layers
+ if self.concat_layers:
+ output = torch.cat(outputs[1:], 2)
+ else:
+ output = outputs[-1]
+
+ # Transpose back
+ output = output.transpose(0, 1)
+
+ # Dropout on output layer
+ if self.dropout_output and self.dropout_rate > 0:
+ output = F.dropout(output,
+ p=self.dropout_rate,
+ training=self.training)
+ return output
+
+ def _forward_padded(self, x, x_mask):
+ """Slower (significantly), but more precise, encoding that handles
+ padding.
+ """
+ # Compute sorted sequence lengths
+ lengths = x_mask.data.eq(0).long().sum(1).squeeze()
+ _, idx_sort = torch.sort(lengths, dim=0, descending=True)
+ _, idx_unsort = torch.sort(idx_sort, dim=0)
+ lengths = list(lengths[idx_sort])
+
+ # Sort x
+ x = x.index_select(0, idx_sort)
+
+ # Transpose batch and sequence dims
+ x = x.transpose(0, 1)
+
+ # Pack it up
+ rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths)
+
+ # Encode all layers
+ outputs = [rnn_input]
+ for i in range(self.num_layers):
+ rnn_input = outputs[-1]
+
+ # Apply dropout to input
+ if self.dropout_rate > 0:
+ dropout_input = F.dropout(rnn_input.data,
+ p=self.dropout_rate,
+ training=self.training)
+ rnn_input = nn.utils.rnn.PackedSequence(dropout_input,
+ rnn_input.batch_sizes)
+ outputs.append(self.rnns[i](rnn_input)[0])
+
+ # Unpack everything
+ for i, o in enumerate(outputs[1:], 1):
+ outputs[i] = nn.utils.rnn.pad_packed_sequence(o)[0]
+
+ # Concat hidden layers or take final
+ if self.concat_layers:
+ output = torch.cat(outputs[1:], 2)
+ else:
+ output = outputs[-1]
+
+ # Transpose and unsort
+ output = output.transpose(0, 1)
+ output = output.index_select(0, idx_unsort)
+
+ # Pad up to original batch sequence length
+ if output.size(1) != x_mask.size(1):
+ padding = torch.zeros(output.size(0),
+ x_mask.size(1) - output.size(1),
+ output.size(2)).type(output.data.type())
+ output = torch.cat([output, padding], 1)
+
+ # Dropout on output layer
+ if self.dropout_output and self.dropout_rate > 0:
+ output = F.dropout(output,
+ p=self.dropout_rate,
+ training=self.training)
+ return output
+
+
+class SeqAttnMatch(nn.Module):
+ """Given sequences X and Y, match sequence Y to each element in X.
+
+ * o_i = sum(alpha_j * y_j) for i in X
+ * alpha_j = softmax(y_j * x_i)
+ """
+
+ def __init__(self, input_size, identity=False):
+ super(SeqAttnMatch, self).__init__()
+ if not identity:
+ self.linear = nn.Linear(input_size, input_size)
+ else:
+ self.linear = None
+
+ def forward(self, x, y, y_mask):
+ """
+ Args:
+ x: batch * len1 * hdim
+ y: batch * len2 * hdim
+ y_mask: batch * len2 (1 for padding, 0 for true)
+ Output:
+ matched_seq: batch * len1 * hdim
+ """
+ # Project vectors
+ if self.linear:
+ x_proj = self.linear(x.view(-1, x.size(2))).view(x.size())
+ x_proj = F.relu(x_proj)
+ y_proj = self.linear(y.view(-1, y.size(2))).view(y.size())
+ y_proj = F.relu(y_proj)
+ else:
+ x_proj = x
+ y_proj = y
+
+ # Compute scores
+ scores = x_proj.bmm(y_proj.transpose(2, 1))
+
+ # Mask padding
+ y_mask = y_mask.unsqueeze(1).expand(scores.size())
+ scores.data.masked_fill_(y_mask.data, -float('inf'))
+
+ # Normalize with softmax
+ alpha_flat = F.softmax(scores.view(-1, y.size(1)), dim=-1)
+ alpha = alpha_flat.view(-1, x.size(1), y.size(1))
+
+ # Take weighted average
+ matched_seq = alpha.bmm(y)
+ return matched_seq
+
+
+class BilinearSeqAttn(nn.Module):
+ """A bilinear attention layer over a sequence X w.r.t y:
+
+ * o_i = softmax(x_i'Wy) for x_i in X.
+
+ Optionally don't normalize output weights.
+ """
+
+ def __init__(self, x_size, y_size, identity=False, normalize=True):
+ super(BilinearSeqAttn, self).__init__()
+ self.normalize = normalize
+
+ # If identity is true, we just use a dot product without transformation.
+ if not identity:
+ self.linear = nn.Linear(y_size, x_size)
+ else:
+ self.linear = None
+
+ def forward(self, x, y, x_mask):
+ """
+ Args:
+ x: batch * len * hdim1
+ y: batch * hdim2
+ x_mask: batch * len (1 for padding, 0 for true)
+ Output:
+ alpha = batch * len
+ """
+ Wy = self.linear(y) if self.linear is not None else y
+ xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2)
+ xWy.data.masked_fill_(x_mask.data, -float('inf'))
+ if self.normalize:
+ if self.training:
+ # In training we output log-softmax for NLL
+ alpha = F.log_softmax(xWy, dim=-1)
+ else:
+ # ...Otherwise 0-1 probabilities
+ alpha = F.softmax(xWy, dim=-1)
+ else:
+ alpha = xWy.exp()
+ return alpha
+
+
+class LinearSeqAttn(nn.Module):
+ """Self attention over a sequence:
+
+ * o_i = softmax(Wx_i) for x_i in X.
+ """
+
+ def __init__(self, input_size):
+ super(LinearSeqAttn, self).__init__()
+ self.linear = nn.Linear(input_size, 1)
+
+ def forward(self, x, x_mask):
+ """
+ Args:
+ x: batch * len * hdim
+ x_mask: batch * len (1 for padding, 0 for true)
+ Output:
+ alpha: batch * len
+ """
+ x_flat = x.view(-1, x.size(-1))
+ scores = self.linear(x_flat).view(x.size(0), x.size(1))
+ scores.data.masked_fill_(x_mask.data, -float('inf'))
+ alpha = F.softmax(scores, dim=-1)
+ return alpha
+
+
+# ------------------------------------------------------------------------------
+# Functional
+# ------------------------------------------------------------------------------
+
+
+def uniform_weights(x, x_mask):
+ """Return uniform weights over non-masked x (a sequence of vectors).
+
+ Args:
+ x: batch * len * hdim
+ x_mask: batch * len (1 for padding, 0 for true)
+ Output:
+ x_avg: batch * hdim
+ """
+ alpha = torch.ones(x.size(0), x.size(1))
+ if x.data.is_cuda:
+ alpha = alpha.cuda()
+ alpha = alpha * x_mask.eq(0).float()
+ alpha = alpha / alpha.sum(1).expand(alpha.size())
+ return alpha
+
+
+def weighted_avg(x, weights):
+ """Return a weighted average of x (a sequence of vectors).
+
+ Args:
+ x: batch * len * hdim
+ weights: batch * len, sum(dim = 1) = 1
+ Output:
+ x_avg: batch * hdim
+ """
+ return weights.unsqueeze(1).bmm(x).squeeze(1)
diff --git a/drqa/reader/model.py b/drqa/reader/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c52fca0d067f6c06b55422a0e0a8507c8982a94a
--- /dev/null
+++ b/drqa/reader/model.py
@@ -0,0 +1,482 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""DrQA Document Reader model"""
+
+import torch
+import torch.optim as optim
+import torch.nn.functional as F
+import numpy as np
+import logging
+import copy
+
+from .config import override_model_args
+from .rnn_reader import RnnDocReader
+
+logger = logging.getLogger(__name__)
+
+
+class DocReader(object):
+ """High level model that handles intializing the underlying network
+ architecture, saving, updating examples, and predicting examples.
+ """
+
+ # --------------------------------------------------------------------------
+ # Initialization
+ # --------------------------------------------------------------------------
+
+ def __init__(self, args, word_dict, feature_dict,
+ state_dict=None, normalize=True):
+ # Book-keeping.
+ self.args = args
+ self.word_dict = word_dict
+ self.args.vocab_size = len(word_dict)
+ self.feature_dict = feature_dict
+ self.args.num_features = len(feature_dict)
+ self.updates = 0
+ self.use_cuda = False
+ self.parallel = False
+
+        # Building network. If normalize is false, scores are not normalized
+ # 0-1 per paragraph (no softmax).
+ if args.model_type == 'rnn':
+ self.network = RnnDocReader(args, normalize)
+ else:
+ raise RuntimeError('Unsupported model: %s' % args.model_type)
+
+ # Load saved state
+ if state_dict:
+ # Load buffer separately
+ if 'fixed_embedding' in state_dict:
+ fixed_embedding = state_dict.pop('fixed_embedding')
+ self.network.load_state_dict(state_dict)
+ self.network.register_buffer('fixed_embedding', fixed_embedding)
+ else:
+ self.network.load_state_dict(state_dict)
+
+ def expand_dictionary(self, words):
+ """Add words to the DocReader dictionary if they do not exist. The
+ underlying embedding matrix is also expanded (with random embeddings).
+
+ Args:
+ words: iterable of tokens to add to the dictionary.
+ Output:
+ added: set of tokens that were added.
+ """
+ to_add = {self.word_dict.normalize(w) for w in words
+ if w not in self.word_dict}
+
+ # Add words to dictionary and expand embedding layer
+ if len(to_add) > 0:
+ logger.info('Adding %d new words to dictionary...' % len(to_add))
+ for w in to_add:
+ self.word_dict.add(w)
+ self.args.vocab_size = len(self.word_dict)
+ logger.info('New vocab size: %d' % len(self.word_dict))
+
+ old_embedding = self.network.embedding.weight.data
+ self.network.embedding = torch.nn.Embedding(self.args.vocab_size,
+ self.args.embedding_dim,
+ padding_idx=0)
+ new_embedding = self.network.embedding.weight.data
+ new_embedding[:old_embedding.size(0)] = old_embedding
+
+ # Return added words
+ return to_add
+
+ def load_embeddings(self, words, embedding_file):
+ """Load pretrained embeddings for a given list of words, if they exist.
+
+ Args:
+ words: iterable of tokens. Only those that are indexed in the
+ dictionary are kept.
+ embedding_file: path to text file of embeddings, space separated.
+ """
+ words = {w for w in words if w in self.word_dict}
+ logger.info('Loading pre-trained embeddings for %d words from %s' %
+ (len(words), embedding_file))
+ embedding = self.network.embedding.weight.data
+
+ # When normalized, some words are duplicated. (Average the embeddings).
+ vec_counts = {}
+ with open(embedding_file) as f:
+ # Skip first line if of form count/dim.
+ line = f.readline().rstrip().split(' ')
+ if len(line) != 2:
+ f.seek(0)
+ for line in f:
+ parsed = line.rstrip().split(' ')
+ assert(len(parsed) == embedding.size(1) + 1)
+ w = self.word_dict.normalize(parsed[0])
+ if w in words:
+ vec = torch.Tensor([float(i) for i in parsed[1:]])
+ if w not in vec_counts:
+ vec_counts[w] = 1
+ embedding[self.word_dict[w]].copy_(vec)
+ else:
+ logging.warning(
+ 'WARN: Duplicate embedding found for %s' % w
+ )
+ vec_counts[w] = vec_counts[w] + 1
+ embedding[self.word_dict[w]].add_(vec)
+
+ for w, c in vec_counts.items():
+ embedding[self.word_dict[w]].div_(c)
+
+ logger.info('Loaded %d embeddings (%.2f%%)' %
+ (len(vec_counts), 100 * len(vec_counts) / len(words)))
+
+ def tune_embeddings(self, words):
+ """Unfix the embeddings of a list of words. This is only relevant if
+ only some of the embeddings are being tuned (tune_partial = N).
+
+ Shuffles the N specified words to the front of the dictionary, and saves
+ the original vectors of the other N + 1:vocab words in a fixed buffer.
+
+ Args:
+ words: iterable of tokens contained in dictionary.
+ """
+ words = {w for w in words if w in self.word_dict}
+
+ if len(words) == 0:
+ logger.warning('Tried to tune embeddings, but no words given!')
+ return
+
+ if len(words) == len(self.word_dict):
+ logger.warning('Tuning ALL embeddings in dictionary')
+ return
+
+ # Shuffle words and vectors
+ embedding = self.network.embedding.weight.data
+ for idx, swap_word in enumerate(words, self.word_dict.START):
+ # Get current word + embedding for this index
+ curr_word = self.word_dict[idx]
+ curr_emb = embedding[idx].clone()
+ old_idx = self.word_dict[swap_word]
+
+ # Swap embeddings + dictionary indices
+ embedding[idx].copy_(embedding[old_idx])
+ embedding[old_idx].copy_(curr_emb)
+ self.word_dict[swap_word] = idx
+ self.word_dict[idx] = swap_word
+ self.word_dict[curr_word] = old_idx
+ self.word_dict[old_idx] = curr_word
+
+ # Save the original, fixed embeddings
+ self.network.register_buffer(
+ 'fixed_embedding', embedding[idx + 1:].clone()
+ )
+
+ def init_optimizer(self, state_dict=None):
+ """Initialize an optimizer for the free parameters of the network.
+
+ Args:
+ state_dict: network parameters
+ """
+ if self.args.fix_embeddings:
+ for p in self.network.embedding.parameters():
+ p.requires_grad = False
+ parameters = [p for p in self.network.parameters() if p.requires_grad]
+ if self.args.optimizer == 'sgd':
+ self.optimizer = optim.SGD(parameters, self.args.learning_rate,
+ momentum=self.args.momentum,
+ weight_decay=self.args.weight_decay)
+ elif self.args.optimizer == 'adamax':
+ self.optimizer = optim.Adamax(parameters,
+ weight_decay=self.args.weight_decay)
+ else:
+ raise RuntimeError('Unsupported optimizer: %s' %
+ self.args.optimizer)
+
+ # --------------------------------------------------------------------------
+ # Learning
+ # --------------------------------------------------------------------------
+
+ def update(self, ex):
+ """Forward a batch of examples; step the optimizer to update weights."""
+ if not self.optimizer:
+ raise RuntimeError('No optimizer set.')
+
+ # Train mode
+ self.network.train()
+
+ # Transfer to GPU
+ if self.use_cuda:
+ inputs = [e if e is None else e.cuda(non_blocking=True)
+ for e in ex[:5]]
+ target_s = ex[5].cuda(non_blocking=True)
+ target_e = ex[6].cuda(non_blocking=True)
+ else:
+ inputs = [e if e is None else e for e in ex[:5]]
+ target_s = ex[5]
+ target_e = ex[6]
+
+ # Run forward
+ score_s, score_e = self.network(*inputs)
+
+ # Compute loss and accuracies
+ loss = F.nll_loss(score_s, target_s) + F.nll_loss(score_e, target_e)
+
+ # Clear gradients and run backward
+ self.optimizer.zero_grad()
+ loss.backward()
+
+ # Clip gradients
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(),
+ self.args.grad_clipping)
+
+ # Update parameters
+ self.optimizer.step()
+ self.updates += 1
+
+ # Reset any partially fixed parameters (e.g. rare words)
+ self.reset_parameters()
+
+ return loss.item(), ex[0].size(0)
+
+ def reset_parameters(self):
+ """Reset any partially fixed parameters to original states."""
+
+ # Reset fixed embeddings to original value
+ if self.args.tune_partial > 0:
+ if self.parallel:
+ embedding = self.network.module.embedding.weight.data
+ fixed_embedding = self.network.module.fixed_embedding
+ else:
+ embedding = self.network.embedding.weight.data
+ fixed_embedding = self.network.fixed_embedding
+
+ # Embeddings to fix are the last indices
+ offset = embedding.size(0) - fixed_embedding.size(0)
+ if offset >= 0:
+ embedding[offset:] = fixed_embedding
+
+ # --------------------------------------------------------------------------
+ # Prediction
+ # --------------------------------------------------------------------------
+
+ def predict(self, ex, candidates=None, top_n=1, async_pool=None):
+ """Forward a batch of examples only to get predictions.
+
+ Args:
+ ex: the batch
+ candidates: batch * variable length list of string answer options.
+ The model will only consider exact spans contained in this list.
+ top_n: Number of predictions to return per batch element.
+ async_pool: If provided, non-gpu post-processing will be offloaded
+ to this CPU process pool.
+ Output:
+ pred_s: batch * top_n predicted start indices
+ pred_e: batch * top_n predicted end indices
+ pred_score: batch * top_n prediction scores
+
+ If async_pool is given, these will be AsyncResult handles.
+ """
+ # Eval mode
+ self.network.eval()
+
+ # Transfer to GPU
+ if self.use_cuda:
+ inputs = [e if e is None else e.cuda(non_blocking=True)
+ for e in ex[:5]]
+ else:
+ inputs = [e for e in ex[:5]]
+
+ # Run forward
+ with torch.no_grad():
+ score_s, score_e = self.network(*inputs)
+
+ # Decode predictions
+ score_s = score_s.data.cpu()
+ score_e = score_e.data.cpu()
+ if candidates:
+ args = (score_s, score_e, candidates, top_n, self.args.max_len)
+ if async_pool:
+ return async_pool.apply_async(self.decode_candidates, args)
+ else:
+ return self.decode_candidates(*args)
+ else:
+ args = (score_s, score_e, top_n, self.args.max_len)
+ if async_pool:
+ return async_pool.apply_async(self.decode, args)
+ else:
+ return self.decode(*args)
+
+ @staticmethod
+ def decode(score_s, score_e, top_n=1, max_len=None):
+ """Take argmax of constrained score_s * score_e.
+
+ Args:
+ score_s: independent start predictions
+ score_e: independent end predictions
+ top_n: number of top scored pairs to take
+ max_len: max span length to consider
+ """
+ pred_s = []
+ pred_e = []
+ pred_score = []
+ max_len = max_len or score_s.size(1)
+ for i in range(score_s.size(0)):
+ # Outer product of scores to get full p_s * p_e matrix
+ scores = torch.ger(score_s[i], score_e[i])
+
+ # Zero out negative length and over-length span scores
+ scores.triu_().tril_(max_len - 1)
+
+ # Take argmax or top n
+ scores = scores.numpy()
+ scores_flat = scores.flatten()
+ if top_n == 1:
+ idx_sort = [np.argmax(scores_flat)]
+ elif len(scores_flat) < top_n:
+ idx_sort = np.argsort(-scores_flat)
+ else:
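+                # np.argpartition selects the top_n scores in O(n); only those
+                # few entries are then fully sorted.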
+ idx = np.argpartition(-scores_flat, top_n)[0:top_n]
+ idx_sort = idx[np.argsort(-scores_flat[idx])]
+ s_idx, e_idx = np.unravel_index(idx_sort, scores.shape)
+ pred_s.append(s_idx)
+ pred_e.append(e_idx)
+ pred_score.append(scores_flat[idx_sort])
+ return pred_s, pred_e, pred_score
+
+ @staticmethod
+ def decode_candidates(score_s, score_e, candidates, top_n=1, max_len=None):
+ """Take argmax of constrained score_s * score_e. Except only consider
+ spans that are in the candidates list.
+ """
+ pred_s = []
+ pred_e = []
+ pred_score = []
+ for i in range(score_s.size(0)):
+ # Extract original tokens stored with candidates
+ tokens = candidates[i]['input']
+ cands = candidates[i]['cands']
+
+ if not cands:
+ # try getting from globals? (multiprocessing in pipeline mode)
+ from ..pipeline.drqa import PROCESS_CANDS
+ cands = PROCESS_CANDS
+ if not cands:
+ raise RuntimeError('No candidates given.')
+
+ # Score all valid candidates found in text.
+ # Brute force get all ngrams and compare against the candidate list.
+ max_len = max_len or len(tokens)
+ scores, s_idx, e_idx = [], [], []
+ for s, e in tokens.ngrams(n=max_len, as_strings=False):
+ span = tokens.slice(s, e).untokenize()
+ if span in cands or span.lower() in cands:
+ # Match! Record its score.
+ scores.append(score_s[i][s] * score_e[i][e - 1])
+ s_idx.append(s)
+ e_idx.append(e - 1)
+
+ if len(scores) == 0:
+ # No candidates present
+ pred_s.append([])
+ pred_e.append([])
+ pred_score.append([])
+ else:
+ # Rank found candidates
+ scores = np.array(scores)
+ s_idx = np.array(s_idx)
+ e_idx = np.array(e_idx)
+
+ idx_sort = np.argsort(-scores)[0:top_n]
+ pred_s.append(s_idx[idx_sort])
+ pred_e.append(e_idx[idx_sort])
+ pred_score.append(scores[idx_sort])
+ return pred_s, pred_e, pred_score
+
+ # --------------------------------------------------------------------------
+ # Saving and loading
+ # --------------------------------------------------------------------------
+
+ def save(self, filename):
+ if self.parallel:
+ network = self.network.module
+ else:
+ network = self.network
+ state_dict = copy.copy(network.state_dict())
+ if 'fixed_embedding' in state_dict:
+ state_dict.pop('fixed_embedding')
+ params = {
+ 'state_dict': state_dict,
+ 'word_dict': self.word_dict,
+ 'feature_dict': self.feature_dict,
+ 'args': self.args,
+ }
+ try:
+ torch.save(params, filename)
+ except BaseException:
+ logger.warning('WARN: Saving failed... continuing anyway.')
+
+ def checkpoint(self, filename, epoch):
+ if self.parallel:
+ network = self.network.module
+ else:
+ network = self.network
+ params = {
+ 'state_dict': network.state_dict(),
+ 'word_dict': self.word_dict,
+ 'feature_dict': self.feature_dict,
+ 'args': self.args,
+ 'epoch': epoch,
+ 'optimizer': self.optimizer.state_dict(),
+ }
+ try:
+ torch.save(params, filename)
+ except BaseException:
+ logger.warning('WARN: Saving failed... continuing anyway.')
+
+ @staticmethod
+ def load(filename, new_args=None, normalize=True):
+ logger.info('Loading model %s' % filename)
+ saved_params = torch.load(
+ filename, map_location=lambda storage, loc: storage
+ )
+ word_dict = saved_params['word_dict']
+ feature_dict = saved_params['feature_dict']
+ state_dict = saved_params['state_dict']
+ args = saved_params['args']
+ if new_args:
+ args = override_model_args(args, new_args)
+ return DocReader(args, word_dict, feature_dict, state_dict, normalize)
+
+ @staticmethod
+ def load_checkpoint(filename, normalize=True):
+ logger.info('Loading model %s' % filename)
+ saved_params = torch.load(
+ filename, map_location=lambda storage, loc: storage
+ )
+ word_dict = saved_params['word_dict']
+ feature_dict = saved_params['feature_dict']
+ state_dict = saved_params['state_dict']
+ epoch = saved_params['epoch']
+ optimizer = saved_params['optimizer']
+ args = saved_params['args']
+ model = DocReader(args, word_dict, feature_dict, state_dict, normalize)
+ model.init_optimizer(optimizer)
+ return model, epoch
+
+ # --------------------------------------------------------------------------
+ # Runtime
+ # --------------------------------------------------------------------------
+
+ def cuda(self):
+ self.use_cuda = True
+ self.network = self.network.cuda()
+
+ def cpu(self):
+ self.use_cuda = False
+ self.network = self.network.cpu()
+
+ def parallelize(self):
+ """Use data parallel to copy the model across several gpus.
+ This will take all gpus visible with CUDA_VISIBLE_DEVICES.
+ """
+ self.parallel = True
+ self.network = torch.nn.DataParallel(self.network)
diff --git a/drqa/reader/predictor.py b/drqa/reader/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f8f7c1cae8490166c895f51ef571e121f3045e9
--- /dev/null
+++ b/drqa/reader/predictor.py
@@ -0,0 +1,145 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""DrQA Document Reader predictor"""
+
+import logging
+
+from multiprocessing import Pool as ProcessPool
+from multiprocessing.util import Finalize
+
+from .vector import vectorize, batchify
+from .model import DocReader
+from . import DEFAULTS, utils
+from .. import tokenizers
+
+logger = logging.getLogger(__name__)
+
+
+# ------------------------------------------------------------------------------
+# Tokenize + annotate
+# ------------------------------------------------------------------------------
+
+PROCESS_TOK = None
+
+
+def init(tokenizer_class, annotators):
+ global PROCESS_TOK
+ PROCESS_TOK = tokenizer_class(annotators=annotators)
+ Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)
+
+
+def tokenize(text):
+ global PROCESS_TOK
+ return PROCESS_TOK.tokenize(text)
+
+
+# ------------------------------------------------------------------------------
+# Predictor class.
+# ------------------------------------------------------------------------------
+
+
+class Predictor(object):
+ """Load a pretrained DocReader model and predict inputs on the fly."""
+
+ def __init__(self, model=None, tokenizer=None, normalize=True,
+ embedding_file=None, num_workers=None):
+ """
+ Args:
+ model: path to saved model file.
+ tokenizer: option string to select tokenizer class.
+ normalize: squash output score to 0-1 probabilities with a softmax.
+ embedding_file: if provided, will expand dictionary to use all
+ available pretrained vectors in this file.
+ num_workers: number of CPU processes to use to preprocess batches.
+ """
+ logger.info('Initializing model...')
+ self.model = DocReader.load(model or DEFAULTS['model'],
+ normalize=normalize)
+
+ if embedding_file:
+ logger.info('Expanding dictionary...')
+ words = utils.index_embedding_words(embedding_file)
+ added = self.model.expand_dictionary(words)
+ self.model.load_embeddings(added, embedding_file)
+
+ logger.info('Initializing tokenizer...')
+ annotators = tokenizers.get_annotators_for_model(self.model)
+ if not tokenizer:
+ tokenizer_class = DEFAULTS['tokenizer']
+ else:
+ tokenizer_class = tokenizers.get_class(tokenizer)
+
+ if num_workers is None or num_workers > 0:
+ self.workers = ProcessPool(
+ num_workers,
+ initializer=init,
+ initargs=(tokenizer_class, annotators),
+ )
+ else:
+ self.workers = None
+ self.tokenizer = tokenizer_class(annotators=annotators)
+
+ def predict(self, document, question, candidates=None, top_n=1):
+ """Predict a single document - question pair."""
+ results = self.predict_batch([(document, question, candidates,)], top_n)
+ return results[0]
+
+ def predict_batch(self, batch, top_n=1):
+ """Predict a batch of document - question pairs."""
+ documents, questions, candidates = [], [], []
+ for b in batch:
+ documents.append(b[0])
+ questions.append(b[1])
+ candidates.append(b[2] if len(b) == 3 else None)
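+        # Collapse to None when no example supplied candidates, so prediction
+        # falls back to unconstrained span decoding.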
+ candidates = candidates if any(candidates) else None
+
+ # Tokenize the inputs, perhaps multi-processed.
+ if self.workers:
+ q_tokens = self.workers.map_async(tokenize, questions)
+ d_tokens = self.workers.map_async(tokenize, documents)
+ q_tokens = list(q_tokens.get())
+ d_tokens = list(d_tokens.get())
+ else:
+ q_tokens = list(map(self.tokenizer.tokenize, questions))
+ d_tokens = list(map(self.tokenizer.tokenize, documents))
+
+ examples = []
+ for i in range(len(questions)):
+ examples.append({
+ 'id': i,
+ 'question': q_tokens[i].words(),
+ 'qlemma': q_tokens[i].lemmas(),
+ 'document': d_tokens[i].words(),
+ 'lemma': d_tokens[i].lemmas(),
+ 'pos': d_tokens[i].pos(),
+ 'ner': d_tokens[i].entities(),
+ })
+
+ # Stick document tokens in candidates for decoding
+ if candidates:
+ candidates = [{'input': d_tokens[i], 'cands': candidates[i]}
+ for i in range(len(candidates))]
+
+ # Build the batch and run it through the model
+ batch_exs = batchify([vectorize(e, self.model) for e in examples])
+ s, e, score = self.model.predict(batch_exs, candidates, top_n)
+
+ # Retrieve the predicted spans
+ results = []
+ for i in range(len(s)):
+ predictions = []
+ for j in range(len(s[i])):
+ span = d_tokens[i].slice(s[i][j], e[i][j] + 1).untokenize()
+ predictions.append((span, score[i][j].item()))
+ results.append(predictions)
+ return results
+
+ def cuda(self):
+ self.model.cuda()
+
+ def cpu(self):
+ self.model.cpu()
diff --git a/drqa/reader/rnn_reader.py b/drqa/reader/rnn_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..326dd2666b15c4e6403ccbe3ac7a62c1074642f2
--- /dev/null
+++ b/drqa/reader/rnn_reader.py
@@ -0,0 +1,135 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Implementation of the RNN based DrQA reader."""
+
+import torch
+import torch.nn as nn
+from . import layers
+
+
+# ------------------------------------------------------------------------------
+# Network
+# ------------------------------------------------------------------------------
+
+
+class RnnDocReader(nn.Module):
+ RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}
+
+ def __init__(self, args, normalize=True):
+ super(RnnDocReader, self).__init__()
+ # Store config
+ self.args = args
+
+ # Word embeddings (+1 for padding)
+ self.embedding = nn.Embedding(args.vocab_size,
+ args.embedding_dim,
+ padding_idx=0)
+
+ # Projection for attention weighted question
+ if args.use_qemb:
+ self.qemb_match = layers.SeqAttnMatch(args.embedding_dim)
+
+ # Input size to RNN: word emb + question emb + manual features
+ doc_input_size = args.embedding_dim + args.num_features
+ if args.use_qemb:
+ doc_input_size += args.embedding_dim
+
+ # RNN document encoder
+ self.doc_rnn = layers.StackedBRNN(
+ input_size=doc_input_size,
+ hidden_size=args.hidden_size,
+ num_layers=args.doc_layers,
+ dropout_rate=args.dropout_rnn,
+ dropout_output=args.dropout_rnn_output,
+ concat_layers=args.concat_rnn_layers,
+ rnn_type=self.RNN_TYPES[args.rnn_type],
+ padding=args.rnn_padding,
+ )
+
+ # RNN question encoder
+ self.question_rnn = layers.StackedBRNN(
+ input_size=args.embedding_dim,
+ hidden_size=args.hidden_size,
+ num_layers=args.question_layers,
+ dropout_rate=args.dropout_rnn,
+ dropout_output=args.dropout_rnn_output,
+ concat_layers=args.concat_rnn_layers,
+ rnn_type=self.RNN_TYPES[args.rnn_type],
+ padding=args.rnn_padding,
+ )
+
+ # Output sizes of rnn encoders
+ doc_hidden_size = 2 * args.hidden_size
+ question_hidden_size = 2 * args.hidden_size
+ if args.concat_rnn_layers:
+ doc_hidden_size *= args.doc_layers
+ question_hidden_size *= args.question_layers
+
+ # Question merging
+ if args.question_merge not in ['avg', 'self_attn']:
+            raise NotImplementedError('question_merge = %s' % args.question_merge)
+ if args.question_merge == 'self_attn':
+ self.self_attn = layers.LinearSeqAttn(question_hidden_size)
+
+ # Bilinear attention for span start/end
+ self.start_attn = layers.BilinearSeqAttn(
+ doc_hidden_size,
+ question_hidden_size,
+ normalize=normalize,
+ )
+ self.end_attn = layers.BilinearSeqAttn(
+ doc_hidden_size,
+ question_hidden_size,
+ normalize=normalize,
+ )
+
+ def forward(self, x1, x1_f, x1_mask, x2, x2_mask):
+ """Inputs:
+ x1 = document word indices [batch * len_d]
+ x1_f = document word features indices [batch * len_d * nfeat]
+ x1_mask = document padding mask [batch * len_d]
+ x2 = question word indices [batch * len_q]
+ x2_mask = question padding mask [batch * len_q]
+ """
+ # Embed both document and question
+ x1_emb = self.embedding(x1)
+ x2_emb = self.embedding(x2)
+
+ # Dropout on embeddings
+ if self.args.dropout_emb > 0:
+ x1_emb = nn.functional.dropout(x1_emb, p=self.args.dropout_emb,
+ training=self.training)
+ x2_emb = nn.functional.dropout(x2_emb, p=self.args.dropout_emb,
+ training=self.training)
+
+ # Form document encoding inputs
+ drnn_input = [x1_emb]
+
+ # Add attention-weighted question representation
+ if self.args.use_qemb:
+ x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask)
+ drnn_input.append(x2_weighted_emb)
+
+ # Add manual features
+ if self.args.num_features > 0:
+ drnn_input.append(x1_f)
+
+ # Encode document with RNN
+ doc_hiddens = self.doc_rnn(torch.cat(drnn_input, 2), x1_mask)
+
+ # Encode question with RNN + merge hiddens
+ question_hiddens = self.question_rnn(x2_emb, x2_mask)
+ if self.args.question_merge == 'avg':
+ q_merge_weights = layers.uniform_weights(question_hiddens, x2_mask)
+ elif self.args.question_merge == 'self_attn':
+ q_merge_weights = self.self_attn(question_hiddens, x2_mask)
+ question_hidden = layers.weighted_avg(question_hiddens, q_merge_weights)
+
+ # Predict start and end positions
+ start_scores = self.start_attn(doc_hiddens, question_hidden, x1_mask)
+ end_scores = self.end_attn(doc_hiddens, question_hidden, x1_mask)
+ return start_scores, end_scores
diff --git a/drqa/reader/utils.py b/drqa/reader/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..66eb542600e8e85382773f77d2fc5683669c0c73
--- /dev/null
+++ b/drqa/reader/utils.py
@@ -0,0 +1,288 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""DrQA reader utilities."""
+
+import json
+import time
+import logging
+import string
+import regex as re
+
+from collections import Counter
+from .data import Dictionary
+
+logger = logging.getLogger(__name__)
+
+
+# ------------------------------------------------------------------------------
+# Data loading
+# ------------------------------------------------------------------------------
+
+
+def load_data(args, filename, skip_no_answer=False):
+ """Load examples from preprocessed file.
+ One example per line, JSON encoded.
+ """
+ # Load JSON lines
+ with open(filename) as f:
+ examples = [json.loads(line) for line in f]
+
+ # Make case insensitive?
+ if args.uncased_question or args.uncased_doc:
+ for ex in examples:
+ if args.uncased_question:
+ ex['question'] = [w.lower() for w in ex['question']]
+ if args.uncased_doc:
+ ex['document'] = [w.lower() for w in ex['document']]
+
+ # Skip unparsed (start/end) examples
+ if skip_no_answer:
+ examples = [ex for ex in examples if len(ex['answers']) > 0]
+
+ return examples
+
+
+def load_text(filename):
+ """Load the paragraphs only of a SQuAD dataset. Store as qid -> text."""
+ # Load JSON file
+ with open(filename) as f:
+ examples = json.load(f)['data']
+
+ texts = {}
+ for article in examples:
+ for paragraph in article['paragraphs']:
+ for qa in paragraph['qas']:
+ texts[qa['id']] = paragraph['context']
+ return texts
+
+
+def load_answers(filename):
+ """Load the answers only of a SQuAD dataset. Store as qid -> [answers]."""
+ # Load JSON file
+ with open(filename) as f:
+ examples = json.load(f)['data']
+
+ ans = {}
+ for article in examples:
+ for paragraph in article['paragraphs']:
+ for qa in paragraph['qas']:
+ ans[qa['id']] = list(map(lambda x: x['text'], qa['answers']))
+ return ans
+
+
+# ------------------------------------------------------------------------------
+# Dictionary building
+# ------------------------------------------------------------------------------
+
+
+def index_embedding_words(embedding_file):
+ """Put all the words in embedding_file into a set."""
+ words = set()
+ with open(embedding_file) as f:
+ for line in f:
+ w = Dictionary.normalize(line.rstrip().split(' ')[0])
+ words.add(w)
+ return words
+
+
+def load_words(args, examples):
+ """Iterate and index all the words in examples (documents + questions)."""
+ def _insert(iterable):
+ for w in iterable:
+ w = Dictionary.normalize(w)
+ if valid_words and w not in valid_words:
+ continue
+ words.add(w)
+
+ if args.restrict_vocab and args.embedding_file:
+ logger.info('Restricting to words in %s' % args.embedding_file)
+ valid_words = index_embedding_words(args.embedding_file)
+ logger.info('Num words in set = %d' % len(valid_words))
+ else:
+ valid_words = None
+
+ words = set()
+ for ex in examples:
+ _insert(ex['question'])
+ _insert(ex['document'])
+ return words
+
+
+def build_word_dict(args, examples):
+ """Return a dictionary from question and document words in
+ provided examples.
+ """
+ word_dict = Dictionary()
+ for w in load_words(args, examples):
+ word_dict.add(w)
+ return word_dict
+
+
+def top_question_words(args, examples, word_dict):
+ """Count and return the most common question words in provided examples."""
+ word_count = Counter()
+ for ex in examples:
+ for w in ex['question']:
+ w = Dictionary.normalize(w)
+ if w in word_dict:
+ word_count.update([w])
+ return word_count.most_common(args.tune_partial)
+
+
+def build_feature_dict(args, examples):
+ """Index features (one hot) from fields in examples and options."""
+ def _insert(feature):
+ if feature not in feature_dict:
+ feature_dict[feature] = len(feature_dict)
+
+ feature_dict = {}
+
+ # Exact match features
+ if args.use_in_question:
+ _insert('in_question')
+ _insert('in_question_uncased')
+ if args.use_lemma:
+ _insert('in_question_lemma')
+
+ # Part of speech tag features
+ if args.use_pos:
+ for ex in examples:
+ for w in ex['pos']:
+ _insert('pos=%s' % w)
+
+ # Named entity tag features
+ if args.use_ner:
+ for ex in examples:
+ for w in ex['ner']:
+ _insert('ner=%s' % w)
+
+ # Term frequency feature
+ if args.use_tf:
+ _insert('tf')
+ return feature_dict
+
+
+# ------------------------------------------------------------------------------
+# Evaluation. Follows official evaluation script for v1.1 of the SQuAD dataset.
+# ------------------------------------------------------------------------------
+
+
+def normalize_answer(s):
+ """Lower text and remove punctuation, articles and extra whitespace."""
+ def remove_articles(text):
+ return re.sub(r'\b(a|an|the)\b', ' ', text)
+
+ def white_space_fix(text):
+ return ' '.join(text.split())
+
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return ''.join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def f1_score(prediction, ground_truth):
+ """Compute the geometric mean of precision and recall for answer tokens."""
+ prediction_tokens = normalize_answer(prediction).split()
+ ground_truth_tokens = normalize_answer(ground_truth).split()
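+    # Multiset intersection counts how many tokens the prediction and the
+    # reference share (with multiplicity).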
+ common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+ num_same = sum(common.values())
+ if num_same == 0:
+ return 0
+ precision = 1.0 * num_same / len(prediction_tokens)
+ recall = 1.0 * num_same / len(ground_truth_tokens)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1
+
+
+def exact_match_score(prediction, ground_truth):
+ """Check if the prediction is a (soft) exact match with the ground truth."""
+ return normalize_answer(prediction) == normalize_answer(ground_truth)
+
+
+def regex_match_score(prediction, pattern):
+ """Check if the prediction matches the given regular expression."""
+ try:
+ compiled = re.compile(
+ pattern,
+ flags=re.IGNORECASE + re.UNICODE + re.MULTILINE
+ )
+ except BaseException:
+ logger.warn('Regular expression failed to compile: %s' % pattern)
+ return False
+ return compiled.match(prediction) is not None
+
+
+def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
+ """Given a prediction and multiple valid answers, return the score of
+ the best prediction-answer_n pair given a metric function.
+ """
+ scores_for_ground_truths = []
+ for ground_truth in ground_truths:
+ score = metric_fn(prediction, ground_truth)
+ scores_for_ground_truths.append(score)
+ return max(scores_for_ground_truths)
+
+
+# ------------------------------------------------------------------------------
+# Utility classes
+# ------------------------------------------------------------------------------
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value."""
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+class Timer(object):
+ """Computes elapsed time."""
+
+ def __init__(self):
+ self.running = True
+ self.total = 0
+ self.start = time.time()
+
+ def reset(self):
+ self.running = True
+ self.total = 0
+ self.start = time.time()
+ return self
+
+ def resume(self):
+ if not self.running:
+ self.running = True
+ self.start = time.time()
+ return self
+
+ def stop(self):
+ if self.running:
+ self.running = False
+ self.total += time.time() - self.start
+ return self
+
+ def time(self):
+ if self.running:
+ return self.total + time.time() - self.start
+ return self.total
diff --git a/drqa/reader/vector.py b/drqa/reader/vector.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a721c7589b92524b470aaea1ce22bfa70ed6b46
--- /dev/null
+++ b/drqa/reader/vector.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Functions for putting examples into torch format."""
+
+from collections import Counter
+import torch
+
+
+def vectorize(ex, model, single_answer=False):
+ """Torchify a single example."""
+ args = model.args
+ word_dict = model.word_dict
+ feature_dict = model.feature_dict
+
+ # Index words
+ document = torch.LongTensor([word_dict[w] for w in ex['document']])
+ question = torch.LongTensor([word_dict[w] for w in ex['question']])
+
+ # Create extra features vector
+ if len(feature_dict) > 0:
+ features = torch.zeros(len(ex['document']), len(feature_dict))
+ else:
+ features = None
+
+ # f_{exact_match}
+ if args.use_in_question:
+ q_words_cased = {w for w in ex['question']}
+ q_words_uncased = {w.lower() for w in ex['question']}
+ q_lemma = {w for w in ex['qlemma']} if args.use_lemma else None
+ for i in range(len(ex['document'])):
+ if ex['document'][i] in q_words_cased:
+ features[i][feature_dict['in_question']] = 1.0
+ if ex['document'][i].lower() in q_words_uncased:
+ features[i][feature_dict['in_question_uncased']] = 1.0
+ if q_lemma and ex['lemma'][i] in q_lemma:
+ features[i][feature_dict['in_question_lemma']] = 1.0
+
+ # f_{token} (POS)
+ if args.use_pos:
+ for i, w in enumerate(ex['pos']):
+ f = 'pos=%s' % w
+ if f in feature_dict:
+ features[i][feature_dict[f]] = 1.0
+
+ # f_{token} (NER)
+ if args.use_ner:
+ for i, w in enumerate(ex['ner']):
+ f = 'ner=%s' % w
+ if f in feature_dict:
+ features[i][feature_dict[f]] = 1.0
+
+ # f_{token} (TF)
+ if args.use_tf:
+ counter = Counter([w.lower() for w in ex['document']])
+ l = len(ex['document'])
+ for i, w in enumerate(ex['document']):
+ features[i][feature_dict['tf']] = counter[w.lower()] * 1.0 / l
+
+ # Maybe return without target
+ if 'answers' not in ex:
+ return document, features, question, ex['id']
+
+ # ...or with target(s) (might still be empty if answers is empty)
+ if single_answer:
+ assert(len(ex['answers']) > 0)
+ start = torch.LongTensor(1).fill_(ex['answers'][0][0])
+ end = torch.LongTensor(1).fill_(ex['answers'][0][1])
+ else:
+ start = [a[0] for a in ex['answers']]
+ end = [a[1] for a in ex['answers']]
+
+ return document, features, question, start, end, ex['id']
+
+
+def batchify(batch):
+ """Gather a batch of individual examples into one batch."""
+ NUM_INPUTS = 3
+ NUM_TARGETS = 2
+ NUM_EXTRA = 1
+
+ ids = [ex[-1] for ex in batch]
+ docs = [ex[0] for ex in batch]
+ features = [ex[1] for ex in batch]
+ questions = [ex[2] for ex in batch]
+
+ # Batch documents and features
+ max_length = max([d.size(0) for d in docs])
+ x1 = torch.LongTensor(len(docs), max_length).zero_()
+ x1_mask = torch.ByteTensor(len(docs), max_length).fill_(1)
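+    # Masks start as all ones (padding); the copy loop below zeroes the
+    # positions that hold real tokens, matching the '1 for padding' convention
+    # used in layers.py.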
+ if features[0] is None:
+ x1_f = None
+ else:
+ x1_f = torch.zeros(len(docs), max_length, features[0].size(1))
+ for i, d in enumerate(docs):
+ x1[i, :d.size(0)].copy_(d)
+ x1_mask[i, :d.size(0)].fill_(0)
+ if x1_f is not None:
+ x1_f[i, :d.size(0)].copy_(features[i])
+
+ # Batch questions
+ max_length = max([q.size(0) for q in questions])
+ x2 = torch.LongTensor(len(questions), max_length).zero_()
+ x2_mask = torch.ByteTensor(len(questions), max_length).fill_(1)
+ for i, q in enumerate(questions):
+ x2[i, :q.size(0)].copy_(q)
+ x2_mask[i, :q.size(0)].fill_(0)
+
+ # Maybe return without targets
+ if len(batch[0]) == NUM_INPUTS + NUM_EXTRA:
+ return x1, x1_f, x1_mask, x2, x2_mask, ids
+
+ elif len(batch[0]) == NUM_INPUTS + NUM_EXTRA + NUM_TARGETS:
+ # ...Otherwise add targets
+ if torch.is_tensor(batch[0][3]):
+ y_s = torch.cat([ex[3] for ex in batch])
+ y_e = torch.cat([ex[4] for ex in batch])
+ else:
+ y_s = [ex[3] for ex in batch]
+ y_e = [ex[4] for ex in batch]
+ else:
+ raise RuntimeError('Incorrect number of inputs per example.')
+
+ return x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids
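+
+# A minimal usage sketch (not part of the original file): in DrQA-style training
+# code, `vectorize` is applied per example and `batchify` is typically passed to
+# a torch DataLoader as the collate function, roughly:
+#
+#   loader = torch.utils.data.DataLoader(
+#       dataset,            # a Dataset whose __getitem__ returns vectorize(ex, model)
+#       batch_size=32,
+#       collate_fn=batchify,
+#   )
+#   for x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids in loader:
+#       pass                # feed the batched tensors to the reader model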
diff --git a/drqa/retriever/__init__.py b/drqa/retriever/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13b4166edc7bea8b51fdfc248dffcd3859a7f4df
--- /dev/null
+++ b/drqa/retriever/__init__.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from .. import DATA_DIR
+
+DEFAULTS = {
+ 'db_path': os.path.join(DATA_DIR, 'wikipedia/docs.db'),
+ 'tfidf_path': os.path.join(
+ DATA_DIR,
+ 'wikipedia/docs-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz'
+ ),
+ 'elastic_url': 'localhost:9200'
+}
+
+
+def set_default(key, value):
+ global DEFAULTS
+ DEFAULTS[key] = value
+
+
+def get_class(name):
+ if name == 'tfidf':
+ return TfidfDocRanker
+ if name == 'sqlite':
+ return DocDB
+ if name == 'elasticsearch':
+ return ElasticDocRanker
+ raise RuntimeError('Invalid retriever class: %s' % name)
+
+
+from .doc_db import DocDB
+from .tfidf_doc_ranker import TfidfDocRanker
+from .elastic_doc_ranker import ElasticDocRanker
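+
+
+# Usage sketch (illustrative, not part of the original file): retriever classes
+# are looked up by name and instantiated with their own keyword arguments, e.g.
+#
+#   ranker = get_class('tfidf')(tfidf_path=DEFAULTS['tfidf_path'])
+#   doc_ids, doc_scores = ranker.closest_docs('some query', k=5)
+#   db = get_class('sqlite')(db_path=DEFAULTS['db_path'])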
diff --git a/drqa/retriever/__pycache__/__init__.cpython-38.pyc b/drqa/retriever/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd16cdb0bade54f9f9f4ba79d0e686006a88134d
Binary files /dev/null and b/drqa/retriever/__pycache__/__init__.cpython-38.pyc differ
diff --git a/drqa/retriever/__pycache__/doc_db.cpython-38.pyc b/drqa/retriever/__pycache__/doc_db.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee67d549cf21212978b158bc3681bd9618f21361
Binary files /dev/null and b/drqa/retriever/__pycache__/doc_db.cpython-38.pyc differ
diff --git a/drqa/retriever/__pycache__/elastic_doc_ranker.cpython-38.pyc b/drqa/retriever/__pycache__/elastic_doc_ranker.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cb6f8effd22ddd42ee3c7bd8cdd57d233e3f013c
Binary files /dev/null and b/drqa/retriever/__pycache__/elastic_doc_ranker.cpython-38.pyc differ
diff --git a/drqa/retriever/__pycache__/tfidf_doc_ranker.cpython-38.pyc b/drqa/retriever/__pycache__/tfidf_doc_ranker.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e95b15af679a981e1655ba49d262c31c55fbfbc
Binary files /dev/null and b/drqa/retriever/__pycache__/tfidf_doc_ranker.cpython-38.pyc differ
diff --git a/drqa/retriever/__pycache__/utils.cpython-38.pyc b/drqa/retriever/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4b5b6dfe381946e8f9663bec5c3d5632eb710f6
Binary files /dev/null and b/drqa/retriever/__pycache__/utils.cpython-38.pyc differ
diff --git a/drqa/retriever/doc_db.py b/drqa/retriever/doc_db.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d1a451a37be375cdad1909536c6c38920700d34
--- /dev/null
+++ b/drqa/retriever/doc_db.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Documents, in a sqlite database."""
+
+import sqlite3
+from . import utils
+from . import DEFAULTS
+
+
+class DocDB(object):
+ """Sqlite backed document storage.
+
+ Implements get_doc_text(doc_id).
+ """
+
+ def __init__(self, db_path=None):
+ self.path = db_path or DEFAULTS['db_path']
+ self.connection = sqlite3.connect(self.path, check_same_thread=False)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args):
+ self.close()
+
+ def get_path(self):
+ """Return the path to the file that backs this database."""
+ return self.path
+
+ def close(self):
+ """Close the connection to the database."""
+ self.connection.close()
+
+ def get_doc_ids(self):
+ """Fetch all ids of docs stored in the db."""
+ cursor = self.connection.cursor()
+ cursor.execute("SELECT id FROM documents")
+ results = [r[0] for r in cursor.fetchall()]
+ cursor.close()
+ return results
+
+ def get_doc_text(self, doc_id):
+ """Fetch the raw text of the doc for 'doc_id'."""
+ cursor = self.connection.cursor()
+ cursor.execute(
+ "SELECT text FROM documents WHERE id = ?",
+ (utils.normalize(doc_id), )
+ # (doc_id, )
+ )
+ result = cursor.fetchone()
+ cursor.close()
+ return result if result is None else result[0]
+
+
+ def get_doc_title(self, doc_id):
+ """Fetch the raw text of the doc for 'doc_id'."""
+ cursor = self.connection.cursor()
+ cursor.execute(
+ "SELECT title FROM documents WHERE id = ?",
+ (utils.normalize(doc_id),)
+ # (doc_id, )
+ )
+ result = cursor.fetchone()
+ cursor.close()
+ return result if result is None else result[0]
+
+ def get_doc_intro(self, doc_id):
+ """Fetch the raw text of the doc for 'doc_id'."""
+ cursor = self.connection.cursor()
+ cursor.execute(
+ "SELECT intro FROM documents WHERE id = ?", # intro: the introduction of Wikipedia page
+ (utils.normalize(doc_id),)
+ # (doc_id, )
+ )
+ result = cursor.fetchone()
+ cursor.close()
+ return result if result is None else result[0]
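+
+
+# Usage sketch (illustrative, not part of the original file); the path below is
+# a placeholder:
+#
+#   with DocDB(db_path='path/to/docs.db') as db:
+#       doc_ids = db.get_doc_ids()
+#       text = db.get_doc_text(doc_ids[0])
+#       title = db.get_doc_title(doc_ids[0])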
diff --git a/drqa/retriever/elastic_doc_ranker.py b/drqa/retriever/elastic_doc_ranker.py
new file mode 100644
index 0000000000000000000000000000000000000000..41d1bddd7bc07558aff0ffb7781df46978511129
--- /dev/null
+++ b/drqa/retriever/elastic_doc_ranker.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Rank documents with an ElasticSearch index"""
+
+import logging
+import scipy.sparse as sp
+
+from multiprocessing.pool import ThreadPool
+from functools import partial
+from elasticsearch import Elasticsearch
+
+from . import utils
+from . import DEFAULTS
+from .. import tokenizers
+
+logger = logging.getLogger(__name__)
+
+
+class ElasticDocRanker(object):
+ """ Connect to an ElasticSearch index.
+ Score pairs based on Elasticsearch
+ """
+
+ def __init__(self, elastic_url=None, elastic_index=None, elastic_fields=None, elastic_field_doc_name=None, strict=True, elastic_field_content=None):
+ """
+ Args:
+ elastic_url: URL of the ElasticSearch server, including the port
+ elastic_index: Index name of ElasticSearch
+ elastic_fields: Fields of the Elasticsearch index to search in
+ elastic_field_doc_name: Field containing the name of the document (index)
+ strict: fail on empty queries or continue (and return empty result)
+ elastic_field_content: Field containing the content of the document in plain text
+ """
+ # Load from disk
+ elastic_url = elastic_url or DEFAULTS['elastic_url']
+ logger.info('Connecting to %s' % elastic_url)
+ self.es = Elasticsearch(hosts=elastic_url)
+ self.elastic_index = elastic_index
+ self.elastic_fields = elastic_fields
+ self.elastic_field_doc_name = elastic_field_doc_name
+ self.elastic_field_content = elastic_field_content
+ self.strict = strict
+
+ # Elastic Ranker
+
+ def get_doc_index(self, doc_id):
+ """Convert doc_id --> doc_index"""
+ field_index = self.elastic_field_doc_name
+ if isinstance(field_index, list):
+ field_index = '.'.join(field_index)
+ result = self.es.search(index=self.elastic_index, body={'query':{'match':
+ {field_index: doc_id}}})
+ return result['hits']['hits'][0]['_id']
+
+
+ def get_doc_id(self, doc_index):
+ """Convert doc_index --> doc_id"""
+ result = self.es.search(index=self.elastic_index, body={'query': { 'match': {"_id": doc_index}}})
+ source = result['hits']['hits'][0]['_source']
+ return utils.get_field(source, self.elastic_field_doc_name)
+
+ def closest_docs(self, query, k=1):
+ """Closest docs by using ElasticSearch
+ """
+ results = self.es.search(index=self.elastic_index, body={'size':k ,'query':
+ {'multi_match': {
+ 'query': query,
+ 'type': 'most_fields',
+ 'fields': self.elastic_fields}}})
+ hits = results['hits']['hits']
+ doc_ids = [utils.get_field(row['_source'], self.elastic_field_doc_name) for row in hits]
+ doc_scores = [row['_score'] for row in hits]
+ return doc_ids, doc_scores
+
+ def batch_closest_docs(self, queries, k=1, num_workers=None):
+ """Process a batch of closest_docs requests multithreaded.
+ Note: plain threads are fine here since the Elasticsearch requests are I/O bound.
+ """
+ with ThreadPool(num_workers) as threads:
+ closest_docs = partial(self.closest_docs, k=k)
+ results = threads.map(closest_docs, queries)
+ return results
+
+ # Elastic DB
+
+ def __enter__(self):
+ return self
+
+ def close(self):
+ """Close the connection to the database."""
+ self.es = None
+
+ def get_doc_ids(self):
+ """Fetch all ids of docs stored in the db."""
+ results = self.es.search(index= self.elastic_index, body={
+ "query": {"match_all": {}}})
+ doc_ids = [utils.get_field(result['_source'], self.elastic_field_doc_name) for result in results['hits']['hits']]
+ return doc_ids
+
+ def get_doc_text(self, doc_id):
+ """Fetch the raw text of the doc for 'doc_id'."""
+ idx = self.get_doc_index(doc_id)
+ result = self.es.get(index=self.elastic_index, doc_type='_doc', id=idx)
+ return result if result is None else result['_source'][self.elastic_field_content]
+
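+# Usage sketch (illustrative, not part of the original file); the index and
+# field names below are placeholders:
+#
+#   ranker = ElasticDocRanker(
+#       elastic_url='localhost:9200',
+#       elastic_index='wiki',
+#       elastic_fields=['title', 'text'],
+#       elastic_field_doc_name='title',
+#       elastic_field_content='text',
+#   )
+#   doc_ids, doc_scores = ranker.closest_docs('some query', k=5)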
diff --git a/drqa/retriever/tfidf_doc_ranker.py b/drqa/retriever/tfidf_doc_ranker.py
new file mode 100644
index 0000000000000000000000000000000000000000..551ed879c13e78dc4029475e1e501e10ce1829ec
--- /dev/null
+++ b/drqa/retriever/tfidf_doc_ranker.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Rank documents with TF-IDF scores"""
+
+import logging
+import numpy as np
+import scipy.sparse as sp
+
+from multiprocessing.pool import ThreadPool
+from functools import partial
+
+from . import utils
+from . import DEFAULTS
+from .. import tokenizers
+
+logger = logging.getLogger(__name__)
+
+
+class TfidfDocRanker(object):
+ """Loads a pre-weighted inverted index of token/document terms.
+ Scores new queries by taking sparse dot products.
+ """
+
+ def __init__(self, tfidf_path=None, strict=True):
+ """
+ Args:
+ tfidf_path: path to saved model file
+ strict: fail on empty queries or continue (and return empty result)
+ """
+ # Load from disk
+ tfidf_path = tfidf_path or DEFAULTS['tfidf_path']
+ logger.info('Loading %s' % tfidf_path)
+ matrix, metadata = utils.load_sparse_csr(tfidf_path)
+ self.doc_mat = matrix
+ self.ngrams = metadata['ngram']
+ self.hash_size = metadata['hash_size']
+ self.tokenizer = tokenizers.get_class(metadata['tokenizer'])()
+ self.doc_freqs = metadata['doc_freqs'].squeeze()
+ self.doc_dict = metadata['doc_dict']
+ self.num_docs = len(self.doc_dict[0])
+ self.strict = strict
+
+ def get_doc_index(self, doc_id):
+ """Convert doc_id --> doc_index"""
+ return self.doc_dict[0][doc_id]
+
+ def get_doc_id(self, doc_index):
+ """Convert doc_index --> doc_id"""
+ return self.doc_dict[1][doc_index]
+
+ def closest_docs(self, query, k=1):
+ """Closest docs by dot product between query and documents
+ in tfidf weighted word vector space.
+ """
+ spvec = self.text2spvec(query)
+ res = spvec * self.doc_mat
+
+ if len(res.data) <= k:
+ o_sort = np.argsort(-res.data)
+ else:
+ o = np.argpartition(-res.data, k)[0:k]
+ o_sort = o[np.argsort(-res.data[o])]
+
+ doc_scores = res.data[o_sort]
+ doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]]
+ return doc_ids, doc_scores
+
+ def batch_closest_docs(self, queries, k=1, num_workers=None):
+ """Process a batch of closest_docs requests multithreaded.
+ Note: we can use plain threads here as scipy is outside of the GIL.
+ """
+ with ThreadPool(num_workers) as threads:
+ closest_docs = partial(self.closest_docs, k=k)
+ results = threads.map(closest_docs, queries)
+ return results
+
+ def parse(self, query):
+ """Parse the query into tokens (either ngrams or tokens)."""
+ tokens = self.tokenizer.tokenize(query)
+ return tokens.ngrams(n=self.ngrams, uncased=True,
+ filter_fn=utils.filter_ngram)
+
+ def text2spvec(self, query):
+ """Create a sparse tfidf-weighted word vector from query.
+
+ tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5))
+ """
+ # Get hashed ngrams
+ words = self.parse(utils.normalize(query))
+ wids = [utils.hash(w, self.hash_size) for w in words]
+
+ if len(wids) == 0:
+ if self.strict:
+ raise RuntimeError('No valid word in: %s' % query)
+ else:
+ logger.warning('No valid word in: %s' % query)
+ return sp.csr_matrix((1, self.hash_size))
+
+ # Count TF
+ wids_unique, wids_counts = np.unique(wids, return_counts=True)
+ tfs = np.log1p(wids_counts)
+
+ # Count IDF
+ Ns = self.doc_freqs[wids_unique]
+ idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5))
+ idfs[idfs < 0] = 0
+
+ # TF-IDF
+ data = np.multiply(tfs, idfs)
+
+ # One row, sparse csr matrix
+ indptr = np.array([0, len(wids_unique)])
+ spvec = sp.csr_matrix(
+ (data, wids_unique, indptr), shape=(1, self.hash_size)
+ )
+
+ return spvec
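+
+
+# Usage sketch (illustrative, not part of the original file); the path below is
+# a placeholder:
+#
+#   ranker = TfidfDocRanker(tfidf_path='path/to/docs-tfidf.npz')
+#   doc_ids, doc_scores = ranker.closest_docs('who founded wikipedia', k=5)
+#   batched = ranker.batch_closest_docs(['query one', 'query two'], k=5)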
diff --git a/drqa/retriever/utils.py b/drqa/retriever/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c08c555c9d6b7929100f60266358d0da6423fa8c
--- /dev/null
+++ b/drqa/retriever/utils.py
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Various retriever utilities."""
+
+import regex
+import unicodedata
+import numpy as np
+import scipy.sparse as sp
+from sklearn.utils import murmurhash3_32
+
+
+# ------------------------------------------------------------------------------
+# Sparse matrix saving/loading helpers.
+# ------------------------------------------------------------------------------
+
+
+def save_sparse_csr(filename, matrix, metadata=None):
+ data = {
+ 'data': matrix.data,
+ 'indices': matrix.indices,
+ 'indptr': matrix.indptr,
+ 'shape': matrix.shape,
+ 'metadata': metadata,
+ }
+ np.savez(filename, **data)
+
+
+def load_sparse_csr(filename):
+ loader = np.load(filename, allow_pickle=True)
+ matrix = sp.csr_matrix((loader['data'], loader['indices'],
+ loader['indptr']), shape=loader['shape'])
+ return matrix, loader['metadata'].item(0) if 'metadata' in loader else None
+
+
+# ------------------------------------------------------------------------------
+# Token hashing.
+# ------------------------------------------------------------------------------
+
+
+def hash(token, num_buckets):
+ """Unsigned 32 bit murmurhash for feature hashing."""
+ return murmurhash3_32(token, positive=True) % num_buckets
+
+
+# ------------------------------------------------------------------------------
+# Text cleaning.
+# ------------------------------------------------------------------------------
+
+
+STOPWORDS = {
+ 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your',
+ 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she',
+ 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their',
+ 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that',
+ 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
+ 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an',
+ 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of',
+ 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through',
+ 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down',
+ 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then',
+ 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any',
+ 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor',
+ 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can',
+ 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've',
+ 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven',
+ 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren',
+ 'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``"
+}
+
+
+def normalize(text):
+ """Resolve different type of unicode encodings."""
+ return unicodedata.normalize('NFD', text)
+
+
+def filter_word(text):
+ """Take out english stopwords, punctuation, and compound endings."""
+ text = normalize(text)
+ if regex.match(r'^\p{P}+$', text):
+ return True
+ if text.lower() in STOPWORDS:
+ return True
+ return False
+
+
+def filter_ngram(gram, mode='any'):
+ """Decide whether to keep or discard an n-gram.
+
+ Args:
+ gram: list of tokens (length N)
+ mode: Option to throw out ngram if
+ 'any': any single token passes filter_word
+ 'all': all tokens pass filter_word
+ 'ends': book-ended by filterable tokens
+ """
+ filtered = [filter_word(w) for w in gram]
+ if mode == 'any':
+ return any(filtered)
+ elif mode == 'all':
+ return all(filtered)
+ elif mode == 'ends':
+ return filtered[0] or filtered[-1]
+ else:
+ raise ValueError('Invalid mode: %s' % mode)
+
+def get_field(d, field_list):
+ """get the subfield associated to a list of elastic fields
+ E.g. ['file', 'filename'] to d['file']['filename']
+ """
+ if isinstance(field_list, str):
+ return d[field_list]
+ else:
+ idx = d.copy()
+ for field in field_list:
+ idx = idx[field]
+ return idx
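+
+
+# Examples (illustrative, not part of the original file):
+#   filter_ngram(['of', 'the'], mode='any')     -> True   (discard: all stopwords)
+#   filter_ngram(['prime', 'minister'], 'any')  -> False  (keep)
+#   get_field({'file': {'filename': 'a.txt'}}, ['file', 'filename'])  -> 'a.txt'
+#   get_field({'title': 'Page'}, 'title')       -> 'Page'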
diff --git a/drqa/tokenizers/__init__.py b/drqa/tokenizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d8616047064bdc3e70ef2cf8fc3509364f95b6a
--- /dev/null
+++ b/drqa/tokenizers/__init__.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+DEFAULTS = {
+ 'corenlp_classpath': os.getenv('CLASSPATH')
+}
+
+
+def set_default(key, value):
+ global DEFAULTS
+ DEFAULTS[key] = value
+
+
+from .corenlp_tokenizer import CoreNLPTokenizer
+from .regexp_tokenizer import RegexpTokenizer
+from .simple_tokenizer import SimpleTokenizer
+
+# Spacy is optional
+try:
+ from .spacy_tokenizer import SpacyTokenizer
+except ImportError:
+ pass
+
+
+def get_class(name):
+ if name == 'spacy':
+ return SpacyTokenizer
+ if name == 'corenlp':
+ return CoreNLPTokenizer
+ if name == 'regexp':
+ return RegexpTokenizer
+ if name == 'simple':
+ return SimpleTokenizer
+
+ raise RuntimeError('Invalid tokenizer: %s' % name)
+
+
+def get_annotators_for_args(args):
+ annotators = set()
+ if args.use_pos:
+ annotators.add('pos')
+ if args.use_lemma:
+ annotators.add('lemma')
+ if args.use_ner:
+ annotators.add('ner')
+ return annotators
+
+
+def get_annotators_for_model(model):
+ return get_annotators_for_args(model.args)
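+
+
+# Usage sketch (illustrative, not part of the original file):
+#
+#   tokenizer = get_class('simple')()
+#   tokens = tokenizer.tokenize('The quick brown fox.')
+#   tokens.words(uncased=True)   # -> ['the', 'quick', 'brown', 'fox', '.']
+#   tokens.ngrams(n=2)           # -> unigrams and bigrams as strings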
diff --git a/drqa/tokenizers/__pycache__/__init__.cpython-38.pyc b/drqa/tokenizers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6038417baf1360eceea0a0cdb790b860834257b
Binary files /dev/null and b/drqa/tokenizers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/drqa/tokenizers/__pycache__/corenlp_tokenizer.cpython-38.pyc b/drqa/tokenizers/__pycache__/corenlp_tokenizer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93161290b2662df9610b5ea11ece7a8d888db8d5
Binary files /dev/null and b/drqa/tokenizers/__pycache__/corenlp_tokenizer.cpython-38.pyc differ
diff --git a/drqa/tokenizers/__pycache__/regexp_tokenizer.cpython-38.pyc b/drqa/tokenizers/__pycache__/regexp_tokenizer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2cc7d05c1da9129bc7dc9e73d6af499fff7ef9f9
Binary files /dev/null and b/drqa/tokenizers/__pycache__/regexp_tokenizer.cpython-38.pyc differ
diff --git a/drqa/tokenizers/__pycache__/simple_tokenizer.cpython-38.pyc b/drqa/tokenizers/__pycache__/simple_tokenizer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..92bd1e7ff012c3f0d16145187b77a0cc09a225db
Binary files /dev/null and b/drqa/tokenizers/__pycache__/simple_tokenizer.cpython-38.pyc differ
diff --git a/drqa/tokenizers/__pycache__/spacy_tokenizer.cpython-38.pyc b/drqa/tokenizers/__pycache__/spacy_tokenizer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97d019553b1cbb918f481cb57a94b17e3f368e67
Binary files /dev/null and b/drqa/tokenizers/__pycache__/spacy_tokenizer.cpython-38.pyc differ
diff --git a/drqa/tokenizers/__pycache__/tokenizer.cpython-38.pyc b/drqa/tokenizers/__pycache__/tokenizer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc84891b0e751430423d5e89b9c27763f0bd9a7a
Binary files /dev/null and b/drqa/tokenizers/__pycache__/tokenizer.cpython-38.pyc differ
diff --git a/drqa/tokenizers/corenlp_tokenizer.py b/drqa/tokenizers/corenlp_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..10dece7432d7caf30f3845f37446a5ce353dd853
--- /dev/null
+++ b/drqa/tokenizers/corenlp_tokenizer.py
@@ -0,0 +1,122 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Simple wrapper around the Stanford CoreNLP pipeline.
+
+Serves commands to a java subprocess running the jar. Requires java 8.
+"""
+
+import copy
+import json
+import pexpect
+
+from .tokenizer import Tokens, Tokenizer
+from . import DEFAULTS
+
+
+class CoreNLPTokenizer(Tokenizer):
+
+ def __init__(self, **kwargs):
+ """
+ Args:
+ annotators: set that can include pos, lemma, and ner.
+ classpath: Path to the corenlp directory of jars
+ mem: Java heap memory
+ """
+ self.classpath = (kwargs.get('classpath') or
+ DEFAULTS['corenlp_classpath'])
+ self.annotators = copy.deepcopy(kwargs.get('annotators', set()))
+ self.mem = kwargs.get('mem', '2g')
+ self._launch()
+
+ def _launch(self):
+ """Start the CoreNLP jar with pexpect."""
+ annotators = ['tokenize', 'ssplit']
+ if 'ner' in self.annotators:
+ annotators.extend(['pos', 'lemma', 'ner'])
+ elif 'lemma' in self.annotators:
+ annotators.extend(['pos', 'lemma'])
+ elif 'pos' in self.annotators:
+ annotators.extend(['pos'])
+ annotators = ','.join(annotators)
+ options = ','.join(['untokenizable=noneDelete',
+ 'invertible=true'])
+ cmd = ['java', '-mx' + self.mem, '-cp', '"%s"' % self.classpath,
+ 'edu.stanford.nlp.pipeline.StanfordCoreNLP', '-annotators',
+ annotators, '-tokenize.options', options,
+ '-outputFormat', 'json', '-prettyPrint', 'false']
+
+ # We use pexpect to keep the subprocess alive and feed it commands.
+ # Because we don't want to get hit by the max terminal buffer size,
+ # we turn off canonical input processing to have unlimited bytes.
+ self.corenlp = pexpect.spawn('/bin/bash', maxread=100000, timeout=60)
+ self.corenlp.setecho(False)
+ self.corenlp.sendline('stty -icanon')
+ self.corenlp.sendline(' '.join(cmd))
+ self.corenlp.delaybeforesend = 0
+ self.corenlp.delayafterread = 0
+ self.corenlp.expect_exact('NLP>', searchwindowsize=100)
+
+ @staticmethod
+ def _convert(token):
+ if token == '-LRB-':
+ return '('
+ if token == '-RRB-':
+ return ')'
+ if token == '-LSB-':
+ return '['
+ if token == '-RSB-':
+ return ']'
+ if token == '-LCB-':
+ return '{'
+ if token == '-RCB-':
+ return '}'
+ return token
+
+ def tokenize(self, text):
+ # Since we're feeding text to the commandline, we're waiting on seeing
+ # the NLP> prompt. Hacky!
+ if 'NLP>' in text:
+ raise RuntimeError('Bad token (NLP>) in text!')
+
+ # Sending q will cause the process to quit -- manually override
+ if text.lower().strip() == 'q':
+ token = text.strip()
+ index = text.index(token)
+ data = [(token, text[index:], (index, index + 1), 'NN', 'q', 'O')]
+ return Tokens(data, self.annotators)
+
+ # Minor cleanup before tokenizing.
+ clean_text = text.replace('\n', ' ')
+
+ self.corenlp.sendline(clean_text.encode('utf-8'))
+ self.corenlp.expect_exact('NLP>', searchwindowsize=100)
+
+ # Skip to start of output (may have been stderr logging messages)
+ output = self.corenlp.before
+ start = output.find(b'{"sentences":')
+ output = json.loads(output[start:].decode('utf-8'))
+
+ data = []
+ tokens = [t for s in output['sentences'] for t in s['tokens']]
+ for i in range(len(tokens)):
+ # Get whitespace
+ start_ws = tokens[i]['characterOffsetBegin']
+ if i + 1 < len(tokens):
+ end_ws = tokens[i + 1]['characterOffsetBegin']
+ else:
+ end_ws = tokens[i]['characterOffsetEnd']
+
+ data.append((
+ self._convert(tokens[i]['word']),
+ text[start_ws: end_ws],
+ (tokens[i]['characterOffsetBegin'],
+ tokens[i]['characterOffsetEnd']),
+ tokens[i].get('pos', None),
+ tokens[i].get('lemma', None),
+ tokens[i].get('ner', None)
+ ))
+ return Tokens(data, self.annotators)
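+
+
+# Usage sketch (illustrative, not part of the original file): requires Java 8 and
+# the CoreNLP jars on the classpath (see DEFAULTS['corenlp_classpath']):
+#
+#   tok = CoreNLPTokenizer(annotators={'pos', 'lemma', 'ner'})
+#   tokens = tok.tokenize('Barack Obama visited Paris.')
+#   tokens.entities()   # per-token NER tags such as 'PERSON' or 'O'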
diff --git a/drqa/tokenizers/regexp_tokenizer.py b/drqa/tokenizers/regexp_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..94858ec71845c13bc5a1e049845e8ab3225cc894
--- /dev/null
+++ b/drqa/tokenizers/regexp_tokenizer.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Regex based tokenizer that emulates the Stanford/NLTK PTB tokenizers.
+
+However it is purely in Python, supports robust untokenization, unicode,
+and requires minimal dependencies.
+"""
+
+import regex
+import logging
+from .tokenizer import Tokens, Tokenizer
+
+logger = logging.getLogger(__name__)
+
+
+class RegexpTokenizer(Tokenizer):
+ DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*'
+ TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)'
+ r'\.(?=\p{Z})')
+ ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)'
+ ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++'
+ HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM)
+ NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't"
+ CONTRACTION1 = r"can(?=not\b)"
+ CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b"
+ START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})'
+ START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})'
+ END_DQUOTE = r'(?<!\p{Z})(\'\'|["\u0094\u201D\u00BB])'
+ END_SQUOTE = r'(?<!\p{Z})[\'\u0092\u2019\u203A]'
+ DASH = r'--|[\u0096\u0097\u2013\u2014\u2015]'
+ ELLIPSES = r'\.\.\.|\u2026'
+ PUNCT = r'\p{P}'
+ NON_WS = r'[^\p{Z}\p{C}]'
+
+ def __init__(self, **kwargs):
+ """
+ Args:
+ annotators: None or empty set (only tokenizes).
+ substitutions: if true, normalizes some token types (e.g. quotes).
+ """
+ self._regexp = regex.compile(
+ '(?P<digit>%s)|(?P<title>%s)|(?P<abbr>%s)|(?P<neg>%s)|(?P<hyph>%s)|'
+ '(?P<contr1>%s)|(?P<alphanum>%s)|(?P<contr2>%s)|(?P<sdquote>%s)|'
+ '(?P<edquote>%s)|(?P<ssquote>%s)|(?P<esquote>%s)|(?P<dash>%s)|'
+ '(?P<ellipses>%s)|(?P<punct>%s)|(?P<nonws>%s)' %
+ (self.DIGIT, self.TITLE, self.ABBRV, self.NEGATION, self.HYPHEN,
+ self.CONTRACTION1, self.ALPHA_NUM, self.CONTRACTION2,
+ self.START_DQUOTE, self.END_DQUOTE, self.START_SQUOTE,
+ self.END_SQUOTE, self.DASH, self.ELLIPSES, self.PUNCT,
+ self.NON_WS),
+ flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
+ )
+ if len(kwargs.get('annotators', {})) > 0:
+ logger.warning('%s only tokenizes! Skipping annotators: %s' %
+ (type(self).__name__, kwargs.get('annotators')))
+ self.annotators = set()
+ self.substitutions = kwargs.get('substitutions', True)
+
+ def tokenize(self, text):
+ data = []
+ matches = [m for m in self._regexp.finditer(text)]
+ for i in range(len(matches)):
+ # Get text
+ token = matches[i].group()
+
+ # Make normalizations for special token types
+ if self.substitutions:
+ groups = matches[i].groupdict()
+ if groups['sdquote']:
+ token = "``"
+ elif groups['edquote']:
+ token = "''"
+ elif groups['ssquote']:
+ token = "`"
+ elif groups['esquote']:
+ token = "'"
+ elif groups['dash']:
+ token = '--'
+ elif groups['ellipses']:
+ token = '...'
+
+ # Get whitespace
+ span = matches[i].span()
+ start_ws = span[0]
+ if i + 1 < len(matches):
+ end_ws = matches[i + 1].span()[0]
+ else:
+ end_ws = span[1]
+
+ # Format data
+ data.append((
+ token,
+ text[start_ws: end_ws],
+ span,
+ ))
+ return Tokens(data, self.annotators)
diff --git a/drqa/tokenizers/simple_tokenizer.py b/drqa/tokenizers/simple_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bad601390457c08125c44007714eb40e36ac7ee1
--- /dev/null
+++ b/drqa/tokenizers/simple_tokenizer.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Basic tokenizer that splits text into alpha-numeric tokens and
+non-whitespace tokens.
+"""
+
+import regex
+import logging
+from .tokenizer import Tokens, Tokenizer
+
+logger = logging.getLogger(__name__)
+
+
+class SimpleTokenizer(Tokenizer):
+ ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
+ NON_WS = r'[^\p{Z}\p{C}]'
+
+ def __init__(self, **kwargs):
+ """
+ Args:
+ annotators: None or empty set (only tokenizes).
+ """
+ self._regexp = regex.compile(
+ '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
+ flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
+ )
+ if len(kwargs.get('annotators', {})) > 0:
+ logger.warning('%s only tokenizes! Skipping annotators: %s' %
+ (type(self).__name__, kwargs.get('annotators')))
+ self.annotators = set()
+
+ def tokenize(self, text):
+ data = []
+ matches = [m for m in self._regexp.finditer(text)]
+ for i in range(len(matches)):
+ # Get text
+ token = matches[i].group()
+
+ # Get whitespace
+ span = matches[i].span()
+ start_ws = span[0]
+ if i + 1 < len(matches):
+ end_ws = matches[i + 1].span()[0]
+ else:
+ end_ws = span[1]
+
+ # Format data
+ data.append((
+ token,
+ text[start_ws: end_ws],
+ span,
+ ))
+ return Tokens(data, self.annotators)
diff --git a/drqa/tokenizers/spacy_tokenizer.py b/drqa/tokenizers/spacy_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc05d59afde21541e3a392674388263b61326382
--- /dev/null
+++ b/drqa/tokenizers/spacy_tokenizer.py
@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Tokenizer that is backed by spaCy (spacy.io).
+
+Requires spaCy package and the spaCy english model.
+"""
+
+import spacy
+import copy
+from .tokenizer import Tokens, Tokenizer
+
+
+class SpacyTokenizer(Tokenizer):
+
+ def __init__(self, **kwargs):
+ """
+ Args:
+ annotators: set that can include pos, lemma, and ner.
+ model: spaCy model to use (either path, or keyword like 'en').
+ """
+ model = kwargs.get('model', 'en')
+ self.annotators = copy.deepcopy(kwargs.get('annotators', set()))
+ nlp_kwargs = {'parser': False}
+ if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
+ nlp_kwargs['tagger'] = False
+ if 'ner' not in self.annotators:
+ nlp_kwargs['entity'] = False
+ self.nlp = spacy.load(model, **nlp_kwargs)
+
+ def tokenize(self, text):
+ # We don't treat new lines as tokens.
+ clean_text = text.replace('\n', ' ')
+ tokens = self.nlp.tokenizer(clean_text)
+ if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
+ self.nlp.tagger(tokens)
+ if 'ner' in self.annotators:
+ self.nlp.entity(tokens)
+
+ data = []
+ for i in range(len(tokens)):
+ # Get whitespace
+ start_ws = tokens[i].idx
+ if i + 1 < len(tokens):
+ end_ws = tokens[i + 1].idx
+ else:
+ end_ws = tokens[i].idx + len(tokens[i].text)
+
+ data.append((
+ tokens[i].text,
+ text[start_ws: end_ws],
+ (tokens[i].idx, tokens[i].idx + len(tokens[i].text)),
+ tokens[i].tag_,
+ tokens[i].lemma_,
+ tokens[i].ent_type_,
+ ))
+
+ # Set special option for non-entity tag: '' vs 'O' in spaCy
+ return Tokens(data, self.annotators, opts={'non_ent': ''})
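+
+
+# Note (not part of the original file): this wrapper targets the spaCy 1.x/2.x
+# API (the 'en' model shortcut and calling nlp.tagger / nlp.entity on a Doc);
+# it would need adapting for spaCy 3.x, where pipeline components are configured
+# on the Language object instead.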
diff --git a/drqa/tokenizers/tokenizer.py b/drqa/tokenizers/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4e588c3055414a55e06981a1818cb0394ca990a
--- /dev/null
+++ b/drqa/tokenizers/tokenizer.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Base tokenizer/tokens classes and utilities."""
+
+import copy
+
+
+class Tokens(object):
+ """A class to represent a list of tokenized text."""
+ TEXT = 0
+ TEXT_WS = 1
+ SPAN = 2
+ POS = 3
+ LEMMA = 4
+ NER = 5
+
+ def __init__(self, data, annotators, opts=None):
+ self.data = data
+ self.annotators = annotators
+ self.opts = opts or {}
+
+ def __len__(self):
+ """The number of tokens."""
+ return len(self.data)
+
+ def slice(self, i=None, j=None):
+ """Return a view of the list of tokens from [i, j)."""
+ new_tokens = copy.copy(self)
+ new_tokens.data = self.data[i: j]
+ return new_tokens
+
+ def untokenize(self):
+ """Returns the original text (with whitespace reinserted)."""
+ return ''.join([t[self.TEXT_WS] for t in self.data]).strip()
+
+ def words(self, uncased=False):
+ """Returns a list of the text of each token
+
+ Args:
+ uncased: lower cases text
+ """
+ if uncased:
+ return [t[self.TEXT].lower() for t in self.data]
+ else:
+ return [t[self.TEXT] for t in self.data]
+
+ def offsets(self):
+ """Returns a list of [start, end) character offsets of each token."""
+ return [t[self.SPAN] for t in self.data]
+
+ def pos(self):
+ """Returns a list of part-of-speech tags of each token.
+ Returns None if this annotation was not included.
+ """
+ if 'pos' not in self.annotators:
+ return None
+ return [t[self.POS] for t in self.data]
+
+ def lemmas(self):
+ """Returns a list of the lemmatized text of each token.
+ Returns None if this annotation was not included.
+ """
+ if 'lemma' not in self.annotators:
+ return None
+ return [t[self.LEMMA] for t in self.data]
+
+ def entities(self):
+ """Returns a list of named-entity-recognition tags of each token.
+ Returns None if this annotation was not included.
+ """
+ if 'ner' not in self.annotators:
+ return None
+ return [t[self.NER] for t in self.data]
+
+ def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
+ """Returns a list of all ngrams from length 1 to n.
+
+ Args:
+ n: upper limit of ngram length
+ uncased: lower cases text
+ filter_fn: user function that takes in an ngram list and returns
+ True or False to keep or not keep the ngram
+ as_strings: return each ngram as a string instead of a list
+ """
+ def _skip(gram):
+ if not filter_fn:
+ return False
+ return filter_fn(gram)
+
+ words = self.words(uncased)
+ ngrams = [(s, e + 1)
+ for s in range(len(words))
+ for e in range(s, min(s + n, len(words)))
+ if not _skip(words[s:e + 1])]
+
+ # Concatenate into strings
+ if as_strings:
+ ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]
+
+ return ngrams
+
+ def entity_groups(self):
+ """Group consecutive entity tokens with the same NER tag."""
+ entities = self.entities()
+ if not entities:
+ return None
+ non_ent = self.opts.get('non_ent', 'O')
+ groups = []
+ idx = 0
+ while idx < len(entities):
+ ner_tag = entities[idx]
+ # Check for entity tag
+ if ner_tag != non_ent:
+ # Chomp the sequence
+ start = idx
+ while (idx < len(entities) and entities[idx] == ner_tag):
+ idx += 1
+ groups.append((self.slice(start, idx).untokenize(), ner_tag))
+ else:
+ idx += 1
+ return groups
+
+
+class Tokenizer(object):
+ """Base tokenizer class.
+ Tokenizers implement tokenize, which should return a Tokens class.
+ """
+ def tokenize(self, text):
+ raise NotImplementedError
+
+ def shutdown(self):
+ pass
+
+ def __del__(self):
+ self.shutdown()
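+
+
+# Example (illustrative, not part of the original file): for tokens of the text
+# "New York City", ngrams(n=2, uncased=True) returns
+# ['new', 'new york', 'york', 'york city', 'city'].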
diff --git a/html2lines.py b/html2lines.py
new file mode 100644
index 0000000000000000000000000000000000000000..420948dc393cf0047c7a177a26d8faecbe839c6c
--- /dev/null
+++ b/html2lines.py
@@ -0,0 +1,72 @@
+import requests
+from time import sleep
+import trafilatura
+from trafilatura.meta import reset_caches
+from trafilatura.settings import DEFAULT_CONFIG
+import spacy
+import os
+os.system("python -m spacy download en_core_web_sm")
+nlp = spacy.load('en_core_web_sm')
+import sys
+
+DEFAULT_CONFIG.MAX_FILE_SIZE = 50000
+
+def get_page(url):
+ page = None
+ for i in range(3):
+ try:
+ page = trafilatura.fetch_url(url, config=DEFAULT_CONFIG)
+ assert page is not None
+ print("Fetched "+url, file=sys.stderr)
+ break
+ except Exception:
+ sleep(3)
+ return page
+
+def url2lines(url):
+ page = get_page(url)
+
+ if page is None:
+ return []
+
+ lines = html2lines(page)
+ return lines
+
+def line_correction(lines, max_size=100):
+ out_lines = []
+ for line in lines:
+ if len(line) < 4:
+ continue
+
+ if len(line) > max_size:
+ doc = nlp(line[:5000]) # We split lines into sentences, but for performance we take only the first 5k characters per line
+ stack = ""
+ for sent in doc.sents:
+ if len(stack) > 0:
+ stack += " "
+ stack += str(sent).strip()
+ if len(stack) > max_size:
+ out_lines.append(stack)
+ stack = ""
+
+ if len(stack) > 0:
+ out_lines.append(stack)
+ else:
+ out_lines.append(line)
+
+ return out_lines
+
+def html2lines(page):
+ out_lines = []
+
+ if page is None or len(page.strip()) == 0:
+ return out_lines
+
+ text = trafilatura.extract(page, config=DEFAULT_CONFIG)
+ reset_caches()
+
+ if text is None:
+ return out_lines
+
+ return text.split("\n") # We just spit out the entire page, so need to reformat later.
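+
+
+# Usage sketch (illustrative, not part of the original file):
+#
+#   lines = line_correction(url2lines('https://example.com/article'))
+#   # `lines` is a list of cleaned text chunks; overly long lines are re-split
+#   # on sentence boundaries by line_correction.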
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..94899466fcfdd3b8fba95c0bc3ef68adf2dda2e6
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,22 @@
+gradio
+nltk
+rank_bm25
+accelerate
+trafilatura
+spacy
+pytorch_lightning
+transformers==4.29.2
+datasets
+leven
+scikit-learn
+pexpect
+elasticsearch
+torch
+huggingface_hub
+google-api-python-client
+wikipedia-api
+beautifulsoup4
+azure-storage-file-share
+azure-storage-blob
+bm25s
+PyStemmer
diff --git a/setup.sh b/setup.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ad8bb8ee847c128cbc233e57fa8f1b0d62c84d4e
--- /dev/null
+++ b/setup.sh
@@ -0,0 +1 @@
+python -m spacy download en_core_web_sm
\ No newline at end of file
diff --git a/style.css b/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..a152a297f35b617595cb3ac3e5c8f31991a3f2f6
--- /dev/null
+++ b/style.css
@@ -0,0 +1,379 @@
+
+/* :root {
+ --user-image: url('https://ih1.redbubble.net/image.4776899543.6215/st,small,507x507-pad,600x600,f8f8f8.jpg');
+ } */
+
+.warning-box {
+ background-color: #fff3cd;
+ border: 1px solid #ffeeba;
+ border-radius: 4px;
+ padding: 15px 20px;
+ font-size: 14px;
+ color: #856404;
+ display: inline-block;
+ margin-bottom: 15px;
+ }
+
+
+.tip-box {
+ background-color: #f0f9ff;
+ border: 1px solid #80d4fa;
+ border-radius: 4px;
+ margin-top:20px;
+ padding: 15px 20px;
+ font-size: 14px;
+ display: inline-block;
+ margin-bottom: 15px;
+ width: auto;
+ color:black !important;
+}
+
+body.dark .warning-box * {
+ color:black !important;
+}
+
+
+body.dark .tip-box * {
+ color:black !important;
+}
+
+
+.tip-box-title {
+ font-weight: bold;
+ font-size: 14px;
+ margin-bottom: 5px;
+}
+
+.light-bulb {
+ display: inline;
+ margin-right: 5px;
+}
+
+.gr-box {border-color: #d6c37c}
+
+#hidden-message{
+ display:none;
+}
+
+.message{
+ font-size:14px !important;
+}
+
+
+a {
+ text-decoration: none;
+ color: inherit;
+}
+
+.card {
+ background-color: white;
+ border-radius: 10px;
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
+ overflow: hidden;
+ display: flex;
+ flex-direction: column;
+ margin:20px;
+}
+
+.card-content {
+ padding: 20px;
+}
+
+.card-content h2 {
+ font-size: 14px !important;
+ font-weight: bold;
+ margin-bottom: 10px;
+ margin-top:0px !important;
+ color:#dc2626 !important;
+}
+
+.card-content p {
+ font-size: 12px;
+ margin-bottom: 0;
+}
+
+.card-footer {
+ background-color: #f4f4f4;
+ font-size: 10px;
+ padding: 10px;
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+}
+
+.card-footer span {
+ flex-grow: 1;
+ text-align: left;
+ color: #999 !important;
+}
+
+.pdf-link {
+ display: inline-flex;
+ align-items: center;
+ margin-left: auto;
+ text-decoration: none!important;
+ font-size: 14px;
+ word-wrap: break-word; /* For IE */
+ word-break: break-all; /* For all other browsers */
+}
+
+
+
+.message.user{
+ /* background-color:#7494b0 !important; */
+ border:none;
+ /* color:white!important; */
+}
+
+.message.bot{
+ /* background-color:#f2f2f7 !important; */
+ border:none;
+}
+
+/* .gallery-item > div:hover{
+ background-color:#7494b0 !important;
+ color:white!important;
+}
+
+.gallery-item:hover{
+ border:#7494b0 !important;
+}
+
+.gallery-item > div{
+ background-color:white !important;
+ color:#577b9b!important;
+}
+
+.label{
+ color:#577b9b!important;
+} */
+
+/* .paginate{
+ color:#577b9b!important;
+} */
+
+
+
+/* span[data-testid="block-info"]{
+ background:none !important;
+ color:#577b9b;
+ } */
+
+/* Pseudo-element for the circularly cropped picture */
+/* .message.bot::before {
+ content: '';
+ position: absolute;
+ top: -10px;
+ left: -10px;
+ width: 30px;
+ height: 30px;
+ background-image: var(--user-image);
+ background-size: cover;
+ background-position: center;
+ border-radius: 50%;
+ z-index: 10;
+ }
+ */
+
+label.selected{
+ background:none !important;
+}
+
+#submit-button{
+ padding:0px !important;
+}
+
+
+@media screen and (min-width: 1024px) {
+ div#tab-examples{
+ height:calc(100vh - 190px) !important;
+ overflow-y: auto;
+ }
+
+ div#sources-textbox{
+ height:calc(100vh - 190px) !important;
+ overflow-y: auto !important;
+ }
+
+ div#tab-config{
+ height:calc(100vh - 190px) !important;
+ overflow-y: auto !important;
+ }
+
+ div#chatbot-row{
+ height:calc(100vh - 90px) !important;
+ }
+
+ div#chatbot{
+ height:calc(100vh - 170px) !important;
+ }
+
+ .max-height{
+ height:calc(100vh - 90px) !important;
+ overflow-y: auto;
+ }
+
+ /* .tabitem:nth-child(n+3) {
+ padding-top:30px;
+ padding-left:40px;
+ padding-right:40px;
+ } */
+}
+
+footer {
+ visibility: hidden;
+ display:none !important;
+}
+
+
+@media screen and (max-width: 767px) {
+ /* Your mobile-specific styles go here */
+
+ div#chatbot{
+ height:500px !important;
+ }
+
+ #submit-button{
+ padding:0px !important;
+ min-width: 80px;
+ }
+
+ /* This will hide all list items */
+ div.tab-nav button {
+ display: none !important;
+ }
+
+ /* This will show only the first list item */
+ div.tab-nav button:first-child {
+ display: block !important;
+ }
+
+ /* This will show only the first list item */
+ div.tab-nav button:nth-child(2) {
+ display: block !important;
+ }
+
+ #right-panel button{
+ display: block !important;
+ }
+
+ /* ... add other mobile-specific styles ... */
+}
+
+
+body.dark .card{
+ background-color: #374151;
+}
+
+body.dark .card-content h2{
+ color:#f4dbd3 !important;
+}
+
+body.dark .card-footer {
+ background-color: #404652;
+}
+
+body.dark .card-footer span {
+ color:white !important;
+}
+
+
+.doc-ref sup{
+ color:#dc2626 !important;
+ margin-right:1px;
+}
+
+.doc-ref_ori{
+ color:#dc2626 !important;
+ margin-right:1px;
+}
+
+.tabitem{
+ border:none !important;
+}
+
+.other-tabs > div{
+ padding-left:40px;
+ padding-right:40px;
+ padding-top:10px;
+}
+
+.gallery-item > div{
+ white-space: normal !important; /* Allow the text to wrap */
+ word-break: break-word !important; /* Break words to prevent overflow */
+ overflow-wrap: break-word !important; /* Break long words if necessary */
+ }
+
+span.chatbot > p > img{
+ margin-top:40px !important;
+ max-height: none !important;
+ max-width: 80% !important;
+ border-radius:0px !important;
+}
+
+
+.chatbot-caption{
+ font-size:11px;
+ font-style:italic;
+ color:#508094;
+}
+
+.ai-generated{
+ font-size:11px!important;
+ font-style:italic;
+ color:#73b8d4 !important;
+}
+
+.card-image > .card-content{
+ background-color:#f1f7fa !important;
+}
+
+
+
+.tab-nav > button.selected{
+ color:#4b8ec3;
+ font-weight:bold;
+ border:none;
+}
+
+.tab-nav{
+ border:none !important;
+}
+
+#input-textbox > label > textarea{
+ border-radius:40px;
+ padding-left:30px;
+ resize:none;
+}
+
+#input-message > div{
+ border:none;
+}
+
+#dropdown-samples{
+ /*! border:none !important; */
+ /*! border-width:0px !important; */
+ background:none !important;
+
+}
+
+#dropdown-samples > .container > .wrap{
+ background-color:white;
+}
+
+
+#tab-examples > div > .form{
+ border:none;
+ background:none !important;
+}
+
+.a-doc-ref{
+ text-decoration: none !important;
+}
+
+
+
+
+
+
+
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddc12a212abbc735fe244784c2bfbb298c37b28d
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,12 @@
+import numpy as np
+import random
+import string
+import uuid
+
+
+def create_user_id():
+ """Create user_id
+ str: String to id user
+ """
+ user_id = str(uuid.uuid4())
+ return user_id
\ No newline at end of file