zhenyundeng committed on
Commit
e62781a
·
1 Parent(s): e5c50a7
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +6 -5
  2. app.py +1368 -0
  3. drqa/__init__.py +23 -0
  4. drqa/__pycache__/__init__.cpython-38.pyc +0 -0
  5. drqa/pipeline/__init__.py +27 -0
  6. drqa/pipeline/__pycache__/__init__.cpython-38.pyc +0 -0
  7. drqa/pipeline/__pycache__/drqa.cpython-38.pyc +0 -0
  8. drqa/pipeline/drqa.py +312 -0
  9. drqa/reader/__init__.py +28 -0
  10. drqa/reader/__pycache__/__init__.cpython-38.pyc +0 -0
  11. drqa/reader/__pycache__/config.cpython-38.pyc +0 -0
  12. drqa/reader/__pycache__/data.cpython-38.pyc +0 -0
  13. drqa/reader/__pycache__/layers.cpython-38.pyc +0 -0
  14. drqa/reader/__pycache__/model.cpython-38.pyc +0 -0
  15. drqa/reader/__pycache__/predictor.cpython-38.pyc +0 -0
  16. drqa/reader/__pycache__/rnn_reader.cpython-38.pyc +0 -0
  17. drqa/reader/__pycache__/utils.cpython-38.pyc +0 -0
  18. drqa/reader/__pycache__/vector.cpython-38.pyc +0 -0
  19. drqa/reader/config.py +128 -0
  20. drqa/reader/data.py +131 -0
  21. drqa/reader/layers.py +311 -0
  22. drqa/reader/model.py +482 -0
  23. drqa/reader/predictor.py +145 -0
  24. drqa/reader/rnn_reader.py +135 -0
  25. drqa/reader/utils.py +288 -0
  26. drqa/reader/vector.py +127 -0
  27. drqa/retriever/__init__.py +38 -0
  28. drqa/retriever/__pycache__/__init__.cpython-38.pyc +0 -0
  29. drqa/retriever/__pycache__/doc_db.cpython-38.pyc +0 -0
  30. drqa/retriever/__pycache__/elastic_doc_ranker.cpython-38.pyc +0 -0
  31. drqa/retriever/__pycache__/tfidf_doc_ranker.cpython-38.pyc +0 -0
  32. drqa/retriever/__pycache__/utils.cpython-38.pyc +0 -0
  33. drqa/retriever/doc_db.py +81 -0
  34. drqa/retriever/elastic_doc_ranker.py +109 -0
  35. drqa/retriever/tfidf_doc_ranker.py +121 -0
  36. drqa/retriever/utils.py +120 -0
  37. drqa/tokenizers/__init__.py +56 -0
  38. drqa/tokenizers/__pycache__/__init__.cpython-38.pyc +0 -0
  39. drqa/tokenizers/__pycache__/corenlp_tokenizer.cpython-38.pyc +0 -0
  40. drqa/tokenizers/__pycache__/regexp_tokenizer.cpython-38.pyc +0 -0
  41. drqa/tokenizers/__pycache__/simple_tokenizer.cpython-38.pyc +0 -0
  42. drqa/tokenizers/__pycache__/spacy_tokenizer.cpython-38.pyc +0 -0
  43. drqa/tokenizers/__pycache__/tokenizer.cpython-38.pyc +0 -0
  44. drqa/tokenizers/corenlp_tokenizer.py +122 -0
  45. drqa/tokenizers/regexp_tokenizer.py +100 -0
  46. drqa/tokenizers/simple_tokenizer.py +57 -0
  47. drqa/tokenizers/spacy_tokenizer.py +62 -0
  48. drqa/tokenizers/tokenizer.py +139 -0
  49. html2lines.py +72 -0
  50. requirements.txt +22 -0
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
- title: AVeriTeC API
- emoji: 🚀
- colorFrom: blue
- colorTo: gray
+ title: AVeriTeC
+ emoji: 🏆
+ colorFrom: purple
+ colorTo: red
  sdk: gradio
- sdk_version: 4.38.1
+ sdk_version: 4.37.2
  app_file: app.py
  pinned: false
+ license: apache-2.0
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,1368 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Created by zd302 at 08/07/2024
4
+
5
+ import gradio as gr
6
+ import tqdm
7
+ import torch
8
+ import numpy as np
9
+ from time import sleep
10
+ import threading
11
+ import gc
12
+ import os
13
+ import json
14
+ import pytorch_lightning as pl
15
+ from urllib.parse import urlparse
16
+ from accelerate import Accelerator
17
+
18
+ from transformers import BartTokenizer, BartForConditionalGeneration
19
+ from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
20
+ from transformers import RobertaTokenizer, RobertaForSequenceClassification
21
+
22
+ from rank_bm25 import BM25Okapi
23
+ # import bm25s
24
+ # import Stemmer # optional: for stemming
25
+ from html2lines import url2lines
26
+ from googleapiclient.discovery import build
27
+ from averitec.models.DualEncoderModule import DualEncoderModule
28
+ from averitec.models.SequenceClassificationModule import SequenceClassificationModule
29
+ from averitec.models.JustificationGenerationModule import JustificationGenerationModule
30
+ from averitec.data.sample_claims import CLAIMS_Type
31
+
32
+ # ---------------------------------------------------------------------------
33
+ # load .env
34
+ from utils import create_user_id
35
+ user_id = create_user_id()
36
+
37
+ from datetime import datetime
38
+ from azure.storage.fileshare import ShareServiceClient
39
+ try:
40
+ from dotenv import load_dotenv
41
+ load_dotenv()
42
+ except Exception as e:
43
+ pass
44
+
45
+ account_url = os.environ["AZURE_ACCOUNT_URL"]
46
+ credential = {
47
+ "account_key": os.environ['AZURE_ACCOUNT_KEY'],
48
+ "account_name": os.environ['AZURE_ACCOUNT_NAME']
49
+ }
50
+
51
+ file_share_name = "averitec"
52
+ azure_service = ShareServiceClient(account_url=account_url, credential=credential)
53
+ azure_share_client = azure_service.get_share_client(file_share_name)
54
+
55
+ # ---------- Setting ----------
56
+ import requests
57
+ from bs4 import BeautifulSoup
58
+ import wikipediaapi
59
+ wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC ([email protected])', 'en')
60
+
61
+ import nltk
62
+ nltk.download('punkt')
63
+ from nltk import pos_tag, word_tokenize, sent_tokenize
64
+
65
+ import spacy
66
+ os.system("python -m spacy download en_core_web_sm")
67
+ nlp = spacy.load("en_core_web_sm")
68
+
69
+ # ---------------------------------------------------------------------------
70
+ # Load sample dict for AVeriTeC search
71
+ # all_samples_dict = json.load(open('averitec/data/all_samples.json', 'r'))
72
+
73
+ # ---------------------------------------------------------------------------
74
+ # ---------- Load pretrained models ----------
75
+ # ---------- load Evidence retrieval model ----------
76
+ # from drqa import retriever
77
+ # db_class = retriever.get_class('sqlite')
78
+ # doc_db = db_class("averitec/data/wikipedia_dumps/enwiki.db")
79
+ # ranker = retriever.get_class('tfidf')(tfidf_path="averitec/data/wikipedia_dumps/enwiki-tfidf-with-id-title.npz")
80
+
81
+ # ---------- Load Veracity and Justification prediction model ----------
82
+ print("Loading models ...")
83
+ LABEL = [
84
+ "Supported",
85
+ "Refuted",
86
+ "Not Enough Evidence",
87
+ "Conflicting Evidence/Cherrypicking",
88
+ ]
89
+ # Veracity
90
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
91
+ veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
92
+ bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
93
+ veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
94
+ tokenizer=veracity_tokenizer, model=bert_model).to(device)
95
+ # Justification
96
+ justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
97
+ bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
98
+ best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
99
+ justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
100
+ # ---------------------------------------------------------------------------
101
+
102
+
103
+ # Set up Gradio Theme
104
+ theme = gr.themes.Base(
105
+ primary_hue="blue",
106
+ secondary_hue="red",
107
+ font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
108
+ )
109
+
110
+ # ---------- Setting ----------
111
+
112
+ class Docs:
113
+ def __init__(self, metadata=dict(), page_content=""):
114
+ self.metadata = metadata
115
+ self.page_content = page_content
116
+
117
+
118
+ def make_html_source(source, i):
119
+ meta = source.metadata
120
+ content = source.page_content.strip()
121
+
122
+ card = f"""
123
+ <div class="card" id="doc{i}">
124
+ <div class="card-content">
125
+ <h2>Doc {i} - URL: <a href="{meta['url']}" target="_blank" class="pdf-link">{meta['url']}</a></h2>
126
+ <p>{content}</p>
127
+ </div>
128
+ <div class="card-footer">
129
+ <span>CACHED SOURCE URL:</span>
130
+ <a href="{meta['cached_source_url']}" target="_blank" class="pdf-link">
131
+ <span role="img" aria-label="Open PDF">🔗</span>
132
+ </a>
133
+ </div>
134
+ </div>
135
+ """
136
+
137
+ return card
138
+
139
+
140
+ # ----- veracity_prediction -----
141
+ class SequenceClassificationDataLoader(pl.LightningDataModule):
142
+ def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
143
+ super().__init__()
144
+ self.tokenizer = tokenizer
145
+ self.data_file = data_file
146
+ self.batch_size = batch_size
147
+ self.add_extra_nee = add_extra_nee
148
+
149
+ def tokenize_strings(
150
+ self,
151
+ source_sentences,
152
+ max_length=400,
153
+ pad_to_max_length=False,
154
+ return_tensors="pt",
155
+ ):
156
+ encoded_dict = self.tokenizer(
157
+ source_sentences,
158
+ max_length=max_length,
159
+ padding="max_length" if pad_to_max_length else "longest",
160
+ truncation=True,
161
+ return_tensors=return_tensors,
162
+ )
163
+
164
+ input_ids = encoded_dict["input_ids"]
165
+ attention_masks = encoded_dict["attention_mask"]
166
+
167
+ return input_ids, attention_masks
168
+
169
+ def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
170
+ if bool_explanation is not None and len(bool_explanation) > 0:
171
+ bool_explanation = ", because " + bool_explanation.lower().strip()
172
+ else:
173
+ bool_explanation = ""
174
+ return (
175
+ "[CLAIM] "
176
+ + claim.strip()
177
+ + " [QUESTION] "
178
+ + question.strip()
179
+ + " "
180
+ + answer.strip()
181
+ + bool_explanation
182
+ )
183
+
184
+
185
+ def averitec_veracity_prediction(claim, qa_evidence):
186
+ bert_model_name = "bert-base-uncased"
187
+ tokenizer = BertTokenizer.from_pretrained(bert_model_name)
188
+ bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4,
189
+ problem_type="single_label_classification")
190
+
191
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
192
+ trained_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
193
+ tokenizer=tokenizer, model=bert_model).to(device)
194
+
195
+ dataLoader = SequenceClassificationDataLoader(
196
+ tokenizer=tokenizer,
197
+ data_file="this_is_discontinued",
198
+ batch_size=32,
199
+ add_extra_nee=False,
200
+ )
201
+
202
+ evidence_strings = []
203
+ for evidence in qa_evidence:
204
+ evidence_strings.append(
205
+ dataLoader.quadruple_to_string(claim, evidence.metadata["query"], evidence.metadata["answer"], ""))
206
+
207
+ if len(evidence_strings) == 0: # If we found no evidence, e.g. because Google returned 0 pages, just output NEI.
208
+ pred_label = "Not Enough Evidence"
209
+ return pred_label
210
+
211
+ tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
212
+ example_support = torch.argmax(
213
+ trained_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
214
+
215
+ has_unanswerable = False
216
+ has_true = False
217
+ has_false = False
218
+
219
+ for v in example_support:
220
+ if v == 0:
221
+ has_true = True
222
+ if v == 1:
223
+ has_false = True
224
+ if v in (2, 3,): # TODO another hack -- we can't have different labels for train and test, so we do this
225
+ has_unanswerable = True
226
+
227
+ if has_unanswerable:
228
+ answer = 2
229
+ elif has_true and not has_false:
230
+ answer = 0
231
+ elif not has_true and has_false:
232
+ answer = 1
233
+ else:
234
+ answer = 3
235
+
236
+ pred_label = LABEL[answer]
237
+
238
+ return pred_label
239
+
240
+
241
+ def fever_veracity_prediction(claim, evidence):
242
+ tokenizer = RobertaTokenizer.from_pretrained('Dzeniks/roberta-fact-check')
243
+ model = RobertaForSequenceClassification.from_pretrained('Dzeniks/roberta-fact-check')
244
+
245
+ evidence_string = ""
246
+ for evi in evidence:
247
+ evidence_string += evi.metadata['title'] + evi.metadata['evidence'] + ' '
248
+
249
+ input_sequence = tokenizer.encode_plus(claim, evidence_string, return_tensors="pt")
250
+ with torch.no_grad():
251
+ prediction = model(**input_sequence)
252
+
253
+ label = torch.argmax(prediction[0]).item()
254
+ pred_label = LABEL[label]
255
+
256
+ return pred_label
257
+
258
+
259
+ def veracity_prediction(claim, qa_evidence):
260
+ # bert_model_name = "bert-base-uncased"
261
+ # tokenizer = BertTokenizer.from_pretrained(bert_model_name)
262
+ # bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4,
263
+ # problem_type="single_label_classification")
264
+ #
265
+ # device = "cuda:0" if torch.cuda.is_available() else "cpu"
266
+ # trained_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
267
+ # tokenizer=tokenizer, model=bert_model).to(device)
268
+
269
+ dataLoader = SequenceClassificationDataLoader(
270
+ tokenizer=veracity_tokenizer,
271
+ data_file="this_is_discontinued",
272
+ batch_size=32,
273
+ add_extra_nee=False,
274
+ )
275
+
276
+ evidence_strings = []
277
+ for evidence in qa_evidence:
278
+ evidence_strings.append(
279
+ dataLoader.quadruple_to_string(claim, evidence.metadata["query"], evidence.metadata["answer"], ""))
280
+
281
+ if len(evidence_strings) == 0: # If we found no evidence, e.g. because Google returned 0 pages, just output NEI.
282
+ pred_label = "Not Enough Evidence"
283
+ return pred_label
284
+
285
+ tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
286
+ example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
287
+
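+ # Aggregate the per-evidence predictions into a single verdict: any prediction of 2 or 3
+ # forces "Not Enough Evidence"; otherwise all-supporting -> "Supported", all-refuting -> "Refuted",
+ # and a mixture of both -> "Conflicting Evidence/Cherrypicking".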
288
+ has_unanswerable = False
289
+ has_true = False
290
+ has_false = False
291
+
292
+ for v in example_support:
293
+ if v == 0:
294
+ has_true = True
295
+ if v == 1:
296
+ has_false = True
297
+ if v in (2, 3,): # TODO another hack -- we can't have different labels for train and test, so we do this
298
+ has_unanswerable = True
299
+
300
+ if has_unanswerable:
301
+ answer = 2
302
+ elif has_true and not has_false:
303
+ answer = 0
304
+ elif not has_true and has_false:
305
+ answer = 1
306
+ else:
307
+ answer = 3
308
+
309
+ pred_label = LABEL[answer]
310
+
311
+ return pred_label
312
+
313
+
314
+ def extract_claim_str(claim, qa_evidence, verdict_label):
315
+ claim_str = "[CLAIM] " + claim + " [EVIDENCE] "
316
+
317
+ for evidence in qa_evidence:
318
+ q_text = evidence.metadata['query'].strip()
319
+
320
+ if len(q_text) == 0:
321
+ continue
322
+
323
+ if not q_text[-1] == "?":
324
+ q_text += "?"
325
+
326
+ answer_strings = []
327
+ answer_strings.append(evidence.metadata['answer'])
328
+
329
+ claim_str += q_text
330
+ for a_text in answer_strings:
331
+ if a_text:
332
+ if not a_text[-1] == ".":
333
+ a_text += "."
334
+ claim_str += " " + a_text.strip()
335
+
336
+ claim_str += " "
337
+
338
+ claim_str += " [VERDICT] " + verdict_label
339
+
340
+ return claim_str
341
+
342
+
343
+ def averitec_justification_generation(claim, qa_evidence, verdict_label):
344
+ #
345
+ claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
346
+ claim_str = claim_str.strip()
347
+
348
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
349
+ tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
350
+ bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
351
+
352
+ best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
353
+ trained_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer,
354
+ model=bart_model).to(device)
355
+
356
+ pred_justification = trained_model.generate(claim_str, device=device)
357
+
358
+ return pred_justification.strip()
359
+
360
+
361
+ def justification_generation(claim, qa_evidence, verdict_label):
362
+ #
363
+ claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
364
+ claim_str = claim_str.strip()
365
+
366
+ # device = "cuda:0" if torch.cuda.is_available() else "cpu"
367
+ # tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
368
+ # bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
369
+ #
370
+ # best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
371
+ # trained_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer,
372
+ # model=bart_model).to(device)
373
+
374
+ pred_justification = justification_model.generate(claim_str, device=device)
375
+
376
+ return pred_justification.strip()
377
+
378
+
379
+ def QAprediction(claim, evidence, sources):
380
+ parts = []
381
+ #
382
+ evidence_title = f"""<h5>Retrieved Evidence:</h5>"""
383
+ for i, evi in enumerate(evidence, 1):
384
+ part = f"""<span>Doc {i}</span>"""
385
+ subpart = f"""<a href="#doc{i}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{i}</sup></span></a>"""
386
+ # subpart = f"""<span class='doc-ref'>{i}</sup></span>"""
387
+ subparts = "".join([part, subpart])
388
+ parts.append(subparts)
389
+
390
+ evidence_part = ", ".join(parts)
391
+
392
+ prediction_title = f"""<h5>Prediction:</h5>"""
393
+ # if 'Google' in sources or 'AVeriTeC' in sources:
394
+ # verdict_label = averitec_veracity_prediction(claim, evidence)
395
+ # justification_label = averitec_justification_generation(claim, evidence, verdict_label)
396
+ # # justification_label = "See retrieved docs."
397
+ # justification_part = f"""<span>Justification: {justification_label}</span>"""
398
+ # if 'WikiPedia' in sources:
399
+ # # verdict_label = fever_veracity_prediction(claim, evidence)
400
+ # justification_label = averitec_justification_generation(claim, evidence, verdict_label)
401
+ # # justification_label = "See retrieved docs."
402
+ # justification_part = f"""<span>Justification: {justification_label}</span>"""
403
+
404
+ verdict_label = veracity_prediction(claim, evidence)
405
+ justification_label = justification_generation(claim, evidence, verdict_label)
406
+ # justification_label = "See retrieved docs."
407
+ justification_part = f"""<span>Justification: {justification_label}</span>"""
408
+
409
+
410
+ verdict_part = f"""Verdict: <span>{verdict_label}.</span><br>"""
411
+
412
+ content_parts = "".join([evidence_title, evidence_part, prediction_title, verdict_part, justification_part])
413
+ # content_parts = "".join([evidence_title, evidence_part, verdict_title, verdict_part, justification_title, justification_part])
414
+
415
+ return content_parts, [verdict_label, justification_label]
416
+
417
+
418
+ # ----------GoogleAPIretriever---------
419
+ def generate_reference_corpus(reference_file):
420
+ with open(reference_file) as f:
421
+ j = json.load(f)
422
+ train_examples = j
423
+
424
+ all_data_corpus = []
425
+ tokenized_corpus = []
426
+
427
+ for train_example in train_examples:
428
+ train_claim = train_example["claim"]
429
+
430
+ speaker = train_example["speaker"].strip() if train_example["speaker"] is not None and len(
431
+ train_example["speaker"]) > 1 else "they"
432
+
433
+ questions = [q["question"] for q in train_example["questions"]]
434
+
435
+ claim_dict_builder = {}
436
+ claim_dict_builder["claim"] = train_claim
437
+ claim_dict_builder["speaker"] = speaker
438
+ claim_dict_builder["questions"] = questions
439
+
440
+ tokenized_corpus.append(nltk.word_tokenize(claim_dict_builder["claim"]))
441
+ all_data_corpus.append(claim_dict_builder)
442
+
443
+ return tokenized_corpus, all_data_corpus
444
+
445
+
446
+ def doc2prompt(doc):
447
+ prompt_parts = "Outrageously, " + doc["speaker"] + " claimed that \"" + doc[
448
+ "claim"].strip() + "\". Criticism includes questions like: "
449
+ questions = [q.strip() for q in doc["questions"]]
450
+ return prompt_parts + " ".join(questions)
451
+
452
+
453
+ def docs2prompt(top_docs):
454
+ return "\n\n".join([doc2prompt(d) for d in top_docs])
455
+
456
+
457
+ def prompt_question_generation(test_claim, speaker="they", topk=10):
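+ # Few-shot question generation: retrieve the top-k most similar training claims with BM25,
+ # build a prompt from their annotated questions, and let BLOOM continue it for the test claim.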
458
+ #
459
+ reference_file = "averitec_code/data/train.json"
460
+ tokenized_corpus, all_data_corpus = generate_reference_corpus(reference_file)
461
+ bm25 = BM25Okapi(tokenized_corpus)
462
+
463
+ # Define the bloom model:
464
+ accelerator = Accelerator()
465
+ accel_device = accelerator.device
466
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
467
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
468
+ model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)
469
+
470
+ # --------------------------------------------------
471
+ # test claim
472
+ s = bm25.get_scores(nltk.word_tokenize(test_claim))
473
+ top_n = np.argsort(s)[::-1][:topk]
474
+ docs = [all_data_corpus[i] for i in top_n]
475
+ # --------------------------------------------------
476
+
477
+ prompt = docs2prompt(docs) + "\n\n" + "Outrageously, " + speaker + " claimed that \"" + test_claim.strip() + \
478
+ "\". Criticism includes questions like: "
479
+ sentences = [prompt]
480
+
481
+ inputs = tokenizer(sentences, padding=True, return_tensors="pt").to(device)
482
+ outputs = model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2,
483
+ early_stopping=True)
484
+
485
+ tgt_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
486
+ in_len = len(sentences[0])
487
+ questions_str = tgt_text[in_len:].split("\n")[0]
488
+
489
+ qs = questions_str.split("?")
490
+ qs = [q.strip() + "?" for q in qs if q.strip() and len(q.strip()) < 300]
491
+
492
+ #
493
+ generate_question = [{"question": q, "answers": []} for q in qs]
494
+
495
+ return generate_question
496
+
497
+
498
+ def check_claim_date(check_date):
499
+ try:
500
+ year, month, date = check_date.split("-")
501
+ except:
502
+ month, date, year = "01", "01", "2022"
503
+
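+ # Normalise the claim date to a YYYYMMDD string (used to date-restrict Google results):
+ # two-digit years up to 30 are read as 20xx, other two-digit years as 19xx.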
504
+ if len(year) == 2 and int(year) <= 30:
505
+ year = "20" + year
506
+ elif len(year) == 2:
507
+ year = "19" + year
508
+ elif len(year) == 1:
509
+ year = "200" + year
510
+
511
+ if len(month) == 1:
512
+ month = "0" + month
513
+
514
+ if len(date) == 1:
515
+ date = "0" + date
516
+
517
+ sort_date = year + month + date
518
+
519
+ return sort_date
520
+
521
+
522
+ def string_to_search_query(text, author):
523
+ parts = word_tokenize(text.strip())
524
+ tags = pos_tag(parts)
525
+
526
+ keep_tags = ["CD", "JJ", "NN", "VB"]
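+ # Penn Treebank tag prefixes: CD = numbers, JJ = adjectives, NN = nouns, VB = verbs.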
527
+
528
+ if author is not None:
529
+ search_string = author.split()
530
+ else:
531
+ search_string = []
532
+
533
+ for token, tag in zip(parts, tags):
534
+ for keep_tag in keep_tags:
535
+ if tag[1].startswith(keep_tag):
536
+ search_string.append(token)
537
+
538
+ search_string = " ".join(search_string)
539
+ return search_string
540
+
541
+
542
+ def google_search(search_term, api_key, cse_id, **kwargs):
543
+ service = build("customsearch", "v1", developerKey=api_key)
544
+ res = service.cse().list(q=search_term, cx=cse_id, **kwargs).execute()
545
+
546
+ if "items" in res:
547
+ return res['items']
548
+ else:
549
+ return []
550
+
551
+
552
+ def get_domain_name(url):
553
+ if '://' not in url:
554
+ url = 'http://' + url
555
+
556
+ domain = urlparse(url).netloc
557
+
558
+ if domain.startswith("www."):
559
+ return domain[4:]
560
+ else:
561
+ return domain
562
+
563
+
564
+ def get_and_store(url_link, fp, worker, worker_stack):
565
+ page_lines = url2lines(url_link)
566
+
567
+ with open(fp, "w") as out_f:
568
+ print("\n".join([url_link] + page_lines), file=out_f)
569
+
570
+ worker_stack.append(worker)
571
+ gc.collect()
572
+
573
+
574
+ def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0):
575
+ search_results = []
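+ # Retry the Custom Search call up to three times, sleeping 3 seconds after each failure.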
576
+ for i in range(3):
577
+ try:
578
+ search_results += google_search(
579
+ search_string,
580
+ api_key,
581
+ search_engine_id,
582
+ num=10,
583
+ start=0 + 10 * page,
584
+ sort="date:r:19000101:" + sort_date,
585
+ dateRestrict=None,
586
+ gl="US"
587
+ )
588
+ break
589
+ except:
590
+ sleep(3)
591
+
592
+ return search_results
593
+
594
+
595
+ def averitec_search(claim, generate_question, speaker="they", check_date="2024-01-01", n_pages=1): # n_pages=3
596
+ # default config
597
+ api_key = os.environ["GOOGLE_API_KEY"]
598
+ search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]
599
+
600
+ blacklist = [
601
+ "jstor.org", # Blacklisted because their pdfs are not labelled as such, and clog up the download
602
+ "facebook.com", # Blacklisted because only post titles can be scraped, but the scraper doesn't know this,
603
+ "ftp.cs.princeton.edu", # Blacklisted because it hosts many large NLP corpora that keep showing up
604
+ "nlp.cs.princeton.edu",
605
+ "huggingface.co"
606
+ ]
607
+
608
+ blacklist_files = [ # Blacklisted some NLP nonsense that crashes my machine with OOM errors
609
+ "/glove.",
610
+ "ftp://ftp.cs.princeton.edu/pub/cs226/autocomplete/words-333333.txt",
611
+ "https://web.mit.edu/adamrose/Public/googlelist",
612
+ ]
613
+
614
+ # save to folder
615
+ store_folder = "averitec_code/store/retrieved_docs"
616
+ #
617
+ index = 0
618
+ questions = [q["question"] for q in generate_question]
619
+
620
+ # check the date of the claim
621
+ sort_date = check_claim_date(check_date) # check_date="2022-01-01"
622
+
623
+ #
624
+ search_strings = []
625
+ search_types = []
626
+
627
+ search_string_2 = string_to_search_query(claim, None)
628
+ search_strings += [search_string_2, claim, ]
629
+ search_types += ["claim", "claim-noformat", ]
630
+
631
+ search_strings += questions
632
+ search_types += ["question" for _ in questions]
633
+
634
+ # start to search
635
+ search_results = []
636
+ visited = {}
637
+ store_counter = 0
638
+ worker_stack = list(range(10))
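+ # Pool of 10 worker slots: each page download runs in its own thread (get_and_store)
+ # and pushes its slot back onto the stack when it finishes.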
639
+
640
+ retrieve_evidence = []
641
+
642
+ for this_search_string, this_search_type in zip(search_strings, search_types):
643
+ for page_num in range(n_pages):
644
+ search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
645
+ this_search_string, page=page_num)
646
+
647
+ for result in search_results:
648
+ link = str(result["link"])
649
+ domain = get_domain_name(link)
650
+
651
+ if domain in blacklist:
652
+ continue
653
+ broken = False
654
+ for b_file in blacklist_files:
655
+ if b_file in link:
656
+ broken = True
657
+ if broken:
658
+ continue
659
+ if link.endswith(".pdf") or link.endswith(".doc"):
660
+ continue
661
+
662
+ store_file_path = ""
663
+
664
+ if link in visited:
665
+ store_file_path = visited[link]
666
+ else:
667
+ store_counter += 1
668
+ store_file_path = store_folder + "/search_result_" + str(index) + "_" + str(
669
+ store_counter) + ".store"
670
+ visited[link] = store_file_path
671
+
672
+ while len(worker_stack) == 0: # Wait for a worker to become available. Check every second.
673
+ sleep(1)
674
+
675
+ worker = worker_stack.pop()
676
+
677
+ t = threading.Thread(target=get_and_store, args=(link, store_file_path, worker, worker_stack))
678
+ t.start()
679
+
680
+ line = [str(index), claim, link, str(page_num), this_search_string, this_search_type, store_file_path]
681
+ retrieve_evidence.append(line)
682
+
683
+ return retrieve_evidence
684
+
685
+
686
+ def claim2prompts(example):
687
+ claim = example["claim"]
688
+
689
+ # claim_str = "Claim: " + claim + "||Evidence: "
690
+ claim_str = "Evidence: "
691
+
692
+ for question in example["questions"]:
693
+ q_text = question["question"].strip()
694
+ if len(q_text) == 0:
695
+ continue
696
+
697
+ if not q_text[-1] == "?":
698
+ q_text += "?"
699
+
700
+ answer_strings = []
701
+
702
+ for a in question["answers"]:
703
+ if a["answer_type"] in ["Extractive", "Abstractive"]:
704
+ answer_strings.append(a["answer"])
705
+ if a["answer_type"] == "Boolean":
706
+ answer_strings.append(a["answer"] + ", because " + a["boolean_explanation"].lower().strip())
707
+
708
+ for a_text in answer_strings:
709
+ if not a_text[-1] in [".", "!", ":", "?"]:
710
+ a_text += "."
711
+
712
+ # prompt_lookup_str = claim + " " + a_text
713
+ prompt_lookup_str = a_text
714
+ this_q_claim_str = claim_str + " " + a_text.strip() + "||Question answered: " + q_text
715
+ yield (prompt_lookup_str, this_q_claim_str.replace("\n", " ").replace("||", "\n"))
716
+
717
+
718
+ def generate_step2_reference_corpus(reference_file):
719
+ with open(reference_file) as f:
720
+ train_examples = json.load(f)
721
+
722
+ prompt_corpus = []
723
+ tokenized_corpus = []
724
+
725
+ for example in train_examples:
726
+ for lookup_str, prompt in claim2prompts(example):
727
+ entry = nltk.word_tokenize(lookup_str)
728
+ tokenized_corpus.append(entry)
729
+ prompt_corpus.append(prompt)
730
+
731
+ return tokenized_corpus, prompt_corpus
732
+
733
+
734
+ def decorate_with_questions(claim, retrieve_evidence, top_k=10): # top_k=100
735
+ #
736
+ reference_file = "averitec_code/data/train.json"
737
+ tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
738
+ prompt_bm25 = BM25Okapi(tokenized_corpus)
739
+
740
+ # Define the bloom model:
741
+ accelerator = Accelerator()
742
+ accel_device = accelerator.device
743
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
744
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
745
+ model = BloomForCausalLM.from_pretrained(
746
+ "bigscience/bloom-7b1",
747
+ device_map="auto",
748
+ torch_dtype=torch.bfloat16,
749
+ offload_folder="./offload"
750
+ )
751
+
752
+ #
753
+ tokenized_corpus = []
754
+ all_data_corpus = []
755
+
756
+ for retri_evi in tqdm.tqdm(retrieve_evidence):
757
+ store_file = retri_evi[-1]
758
+
759
+ with open(store_file, 'r') as f:
760
+ first = True
761
+ for line in f:
762
+ line = line.strip()
763
+
764
+ if first:
765
+ first = False
766
+ location_url = line
767
+ continue
768
+
769
+ if len(line) > 3:
770
+ entry = nltk.word_tokenize(line)
771
+ if (location_url, line) not in all_data_corpus:
772
+ tokenized_corpus.append(entry)
773
+ all_data_corpus.append((location_url, line))
774
+
775
+ if len(tokenized_corpus) == 0:
776
+ print("")
777
+
778
+ bm25 = BM25Okapi(tokenized_corpus)
779
+ s = bm25.get_scores(nltk.word_tokenize(claim))
780
+ top_n = np.argsort(s)[::-1][:top_k]
781
+ docs = [all_data_corpus[i] for i in top_n]
782
+
783
+ generate_qa_pairs = []
784
+ # Then, generate a question for each of the top-k passages:
785
+ for doc in tqdm.tqdm(docs):
786
+ # prompt_lookup_str = example["claim"] + " " + doc[1]
787
+ prompt_lookup_str = doc[1]
788
+
789
+ prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str))
790
+ prompt_n = 10
791
+ prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
792
+ prompt_docs = [prompt_corpus[i] for i in prompt_top_n]
793
+
794
+ claim_prompt = "Evidence: " + doc[1].replace("\n", " ") + "\nQuestion answered: "
795
+ prompt = "\n\n".join(prompt_docs + [claim_prompt])
796
+ sentences = [prompt]
797
+
798
+ inputs = tokenizer(sentences, padding=True, return_tensors="pt").to(device)
799
+ outputs = model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2,
800
+ early_stopping=True)
801
+
802
+ tgt_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
803
+ # We are not allowed to generate more than 250 characters:
804
+ tgt_text = tgt_text[:250]
805
+
806
+ qa_pair = [tgt_text.strip().split("?")[0].replace("\n", " ") + "?", doc[1].replace("\n", " "), doc[0]]
807
+ generate_qa_pairs.append(qa_pair)
808
+
809
+ return generate_qa_pairs
810
+
811
+
812
+ def triple_to_string(x):
813
+ return " </s> ".join([item.strip() for item in x])
814
+
815
+
816
+ def rerank_questions(claim, bm25_qas, topk=3):
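+ # Score each (claim, question, answer) triple with the dual-encoder checkpoint and keep the top-k pairs.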
817
+ #
818
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
819
+ bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2,
820
+ problem_type="single_label_classification") # Must specify single_label for some reason
821
+ best_checkpoint = "averitec_code/pretrained_models/bert_dual_encoder.ckpt"
822
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
823
+ trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer, model=bert_model).to(
824
+ device)
825
+
826
+ #
827
+ strs_to_score = []
828
+ values = []
829
+
830
+ for question, answer, source in bm25_qas:
831
+ str_to_score = triple_to_string([claim, question, answer])
832
+
833
+ strs_to_score.append(str_to_score)
834
+ values.append([question, answer, source])
835
+
836
+ if len(bm25_qas) > 0:
837
+ encoded_dict = tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True,
838
+ return_tensors="pt").to(device)
839
+
840
+ input_ids = encoded_dict['input_ids']
841
+ attention_masks = encoded_dict['attention_mask']
842
+
843
+ scores = torch.softmax(trained_model(input_ids, attention_mask=attention_masks).logits, axis=-1)[:, 1]
844
+
845
+ top_n = torch.argsort(scores, descending=True)[:topk]
846
+ pass_through = [{"question": values[i][0], "answers": values[i][1], "source_url": values[i][2]} for i in top_n]
847
+ else:
848
+ pass_through = []
849
+
850
+ top3_qa_pairs = pass_through
851
+
852
+ return top3_qa_pairs
853
+
854
+
855
+ def GoogleAPIretriever(query):
856
+ # ----- Generate QA pairs using AVeriTeC
857
+ top3_qa_pairs_path = "averitec_code/top3_qa_pairs1.json"
858
+ if not os.path.exists(top3_qa_pairs_path):
859
+ # step 1: generate questions for the query/claim using Bloom
860
+ generate_question = prompt_question_generation(query)
861
+ # step 2: retrieve evidence for the generated questions using Google API
862
+ retrieve_evidence = averitec_search(query, generate_question)
863
+ # step 3: generate QA pairs for each retrieved document
864
+ bm25_qa_pairs = decorate_with_questions(query, retrieve_evidence)
865
+ # step 4: rerank QA pairs
866
+ top3_qa_pairs = rerank_questions(query, bm25_qa_pairs)
867
+ else:
868
+ top3_qa_pairs = json.load(open(top3_qa_pairs_path, 'r'))
869
+
870
+ # Add score to metadata
871
+ results = []
872
+ for i, qa in enumerate(top3_qa_pairs):
873
+ metadata = dict()
874
+
875
+ metadata['name'] = qa['question']
876
+ metadata['url'] = qa['source_url']
877
+ metadata['cached_source_url'] = qa['source_url']
878
+ metadata['short_name'] = "Evidence {}".format(i + 1)
879
+ metadata['page_number'] = ""
880
+ metadata['query'] = qa['question']
881
+ metadata['answer'] = qa['answers']
882
+ metadata['page_content'] = "<b>Question</b>: " + qa['question'] + "<br>" + "<b>Answer</b>: " + qa['answers']
883
+ page_content = f"""{metadata['page_content']}"""
884
+ results.append((metadata, page_content))
885
+
886
+ return results
887
+
888
+
889
+ # ----------GoogleAPIretriever---------
890
+
891
+ # ----------Wikipediaretriever---------
892
+ def bm25_retriever(query, corpus, topk=3):
893
+ bm25 = BM25Okapi(corpus)
894
+ #
895
+ query_tokens = word_tokenize(query)
896
+ scores = bm25.get_scores(query_tokens)
897
+ top_n = np.argsort(scores)[::-1][:topk]
898
+ top_n_scores = [scores[i] for i in top_n]
899
+
900
+ return top_n, top_n_scores
901
+
902
+
903
+ def bm25s_retriever(query, corpus, topk=3):
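+ # NOTE: requires the optional `bm25s` and `Stemmer` packages, whose imports are commented out above.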
904
+ # optional: create a stemmer
905
+ stemmer = Stemmer.Stemmer("english")
906
+ # Tokenize the corpus and only keep the ids (faster and saves memory)
907
+ corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
908
+ # Create the BM25 model and index the corpus
909
+ retriever = bm25s.BM25()
910
+ retriever.index(corpus_tokens)
911
+ # Query the corpus
912
+ query_tokens = bm25s.tokenize(query, stemmer=stemmer)
913
+ # Get top-k results as a tuple of (doc ids, scores). Both are arrays of shape (n_queries, k)
914
+ results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=topk)
915
+ top_n = [corpus.index(res) for res in results[0]]
916
+ return top_n, scores
917
+
918
+
919
+ def find_evidence_from_wikipedia_dumps(claim):
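+ # NOTE: relies on the DrQA `ranker` and `doc_db` objects whose initialisation is commented out above.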
920
+ #
921
+ doc = nlp(claim)
922
+ entities_in_claim = [str(ent).lower() for ent in doc.ents]
923
+ title2id = ranker.doc_dict[0]
924
+ wiki_intro, ent_list = [], []
925
+ for ent in entities_in_claim:
926
+ if ent in title2id.keys():
927
+ ids = title2id[ent]
928
+ introduction = doc_db.get_doc_intro(ids)
929
+ wiki_intro.append([ent, introduction])
930
+ # fulltext = doc_db.get_doc_text(ids)
931
+ # evidence.append([ent, fulltext])
932
+ ent_list.append(ent)
933
+
934
+ if len(wiki_intro) < 5:
935
+ evidence_tfidf = process_topk(claim, title2id, ent_list, k=5)
936
+ wiki_intro.extend(evidence_tfidf)
937
+
938
+ return wiki_intro, doc
939
+
940
+
941
+ def relevant_sentence_retrieval(query, wiki_intro, k):
942
+ # 1. Create corpus here
943
+ corpus, sentences = [], []
944
+ titles = []
945
+ for i, (title, intro) in enumerate(wiki_intro):
946
+ sents_in_intro = sent_tokenize(intro)
947
+
948
+ for sent in sents_in_intro:
949
+ corpus.append(word_tokenize(sent))
950
+ sentences.append(sent)
951
+ titles.append(title)
952
+ #
953
+ # ----- BM25
954
+ bm25_top_n, bm25_top_n_scores = bm25_retriever(query, corpus, topk=k)
955
+ bm25_top_n_sents = [sentences[i] for i in bm25_top_n]
956
+ bm25_top_n_titles = [titles[i] for i in bm25_top_n]
957
+
958
+ # ----- BM25s
959
+ # bm25s_top_n, bm25s_top_n_scores = bm25s_retriever(query, sentences, topk=k) # corpus->sentences
960
+ # bm25s_top_n_sents = [sentences[i] for i in bm25s_top_n]
961
+ # bm25s_top_n_titles = [titles[i] for i in bm25s_top_n]
962
+
963
+ return bm25_top_n_sents, bm25_top_n_titles
964
+
965
+
966
+ def process_topk(query, title2id, ent_list, k=1):
967
+ doc_names, doc_scores = ranker.closest_docs(query, k)
968
+ evidence_tfidf = []
969
+
970
+ for _name in doc_names:
971
+ if _name not in ent_list and len(ent_list) < 5:
972
+ ent_list.append(_name)
973
+ idx = title2id[_name]
974
+ introduction = doc_db.get_doc_intro(idx)
975
+ evidence_tfidf.append([_name, introduction])
976
+ # fulltext = doc_db.get_doc_text(idx)
977
+ # evidence_tfidf.append([_name,fulltext])
978
+
979
+ return evidence_tfidf
980
+
981
+
982
+ def WikipediaDumpsretriever(claim):
983
+ #
984
+ # 1. extract relevant wikipedia pages from wikipedia dumps
985
+ wiki_intro, doc = find_evidence_from_wikipedia_dumps(claim)
986
+ # wiki_intro = [['trump', "'''Trump''' most commonly refers to:\n* Donald Trump (born 1946), President of the United States from 2017 to 2021 \n* Trump (card games), any playing card given an ad-hoc high rank\n\n'''Trump''' may also refer to:"]]
987
+
988
+ # 2. extract relevant sentences from extracted wikipedia pages
989
+ sents, titles = relevant_sentence_retrieval(claim, wiki_intro, k=3)
990
+
991
+ #
992
+ results = []
993
+ for i, (sent, title) in enumerate(zip(sents, titles)):
994
+ metadata = dict()
995
+ metadata['name'] = claim
996
+ metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
997
+ metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
998
+ metadata['short_name'] = "Evidence {}".format(i + 1)
999
+ metadata['page_number'] = ""
1000
+ metadata['query'] = sent
1001
+ metadata['title'] = title
1002
+ metadata['evidence'] = sent
1003
+ metadata['answer'] = ""
1004
+ metadata['page_content'] = "<b>Title</b>: " + str(metadata['title']) + "<br>" + "<b>Evidence</b>: " + metadata[
1005
+ 'evidence']
1006
+ page_content = f"""{metadata['page_content']}"""
1007
+
1008
+ results.append(Docs(metadata, page_content))
1009
+
1010
+ return results
1011
+
1012
+ # ----------WikipediaAPIretriever---------
1013
+ def clean_str(p):
1014
+ return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")
1015
+
1016
+
1017
+ def get_page_obs(page):
1018
+ # find all paragraphs
1019
+ paragraphs = page.split("\n")
1020
+ paragraphs = [p.strip() for p in paragraphs if p.strip()]
1021
+
1022
+ # # find all sentence
1023
+ # sentences = []
1024
+ # for p in paragraphs:
1025
+ # sentences += p.split('. ')
1026
+ # sentences = [s.strip() + '.' for s in sentences if s.strip()]
1027
+ # # return ' '.join(sentences[:5])
1028
+ # return ' '.join(sentences)
1029
+
1030
+ return ' '.join(paragraphs[:5])
1031
+
1032
+
1033
+ def search_entity_wikipeida(entity):
1034
+ find_evidence = []
1035
+
1036
+ page_py = wiki_wiki.page(entity)
1037
+ if page_py.exists():
1038
+ introduction = page_py.summary
1039
+
1040
+ find_evidence.append([str(entity), introduction])
1041
+
1042
+ return find_evidence
1043
+
1044
+
1045
+ def search_step(entity):
1046
+ ent_ = entity.replace(" ", "+")
1047
+ search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}"
1048
+ response_text = requests.get(search_url).text
1049
+ soup = BeautifulSoup(response_text, features="html.parser")
1050
+ result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
1051
+
1052
+ find_evidence = []
1053
+
1054
+ if result_divs: # mismatch
1055
+ # If the Wikipedia page for the entity does not exist, find similar Wikipedia pages.
1056
+ result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
1057
+ similar_titles = result_titles[:5]
1058
+
1059
+ for _t in similar_titles:
1060
+ if len(find_evidence) < 5:
1061
+ _evi = search_step(_t)
1062
+ find_evidence.extend(_evi)
1063
+ else:
1064
+ page = [p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")]
1065
+ if any("may refer to:" in p for p in page):
1066
+ _evi = search_step("[" + entity + "]")
1067
+ find_evidence.extend(_evi)
1068
+ else:
1069
+ # page_py = wiki_wiki.page(entity)
1070
+ #
1071
+ # if page_py.exists():
1072
+ # introduction = page_py.summary
1073
+ # else:
1074
+ page_text = ""
1075
+ for p in page:
1076
+ if len(p.split(" ")) > 2:
1077
+ page_text += clean_str(p)
1078
+ if not p.endswith("\n"):
1079
+ page_text += "\n"
1080
+ introduction = get_page_obs(page_text)
1081
+
1082
+ find_evidence.append([entity, introduction])
1083
+
1084
+ return find_evidence
1085
+
1086
+
1087
+ def find_similar_wikipedia(entity, relevant_wikipages):
1088
+ # If fewer than 5 relevant Wikipedia pages were found for the entity, find similar Wikipedia pages.
1089
+ ent_ = entity.replace(" ", "+")
1090
+ search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}&title=Special:Search&profile=advanced&fulltext=1&ns0=1"
1091
+ response_text = requests.get(search_url).text
1092
+ soup = BeautifulSoup(response_text, features="html.parser")
1093
+ result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
1094
+
1095
+ if result_divs:
1096
+ result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
1097
+ similar_titles = result_titles[:5]
1098
+
1099
+ saved_titles = [ent[0] for ent in relevant_wikipages] if relevant_wikipages else relevant_wikipages
1100
+ for _t in similar_titles:
1101
+ if _t not in saved_titles and len(relevant_wikipages) < 5:
1102
+ _evi = search_entity_wikipeida(_t)
1103
+ # _evi = search_step(_t)
1104
+ relevant_wikipages.extend(_evi)
1105
+
1106
+ return relevant_wikipages
1107
+
1108
+
1109
+ def find_evidence_from_wikipedia(claim):
1110
+ #
1111
+ doc = nlp(claim)
1112
+ #
1113
+ wikipedia_page = []
1114
+ for ent in doc.ents:
1115
+ relevant_wikipages = search_entity_wikipeida(ent)
1116
+
1117
+ if len(relevant_wikipages) < 5:
1118
+ relevant_wikipages = find_similar_wikipedia(str(ent), relevant_wikipages)
1119
+
1120
+ wikipedia_page.extend(relevant_wikipages)
1121
+
1122
+ return wikipedia_page
1123
+
1124
+
1125
+ def relevant_wikipedia_API_retriever(claim):
1126
+ #
1127
+ doc = nlp(claim)
1128
+
1129
+ wiki_intro = []
1130
+ for ent in doc.ents:
1131
+ page_py = wiki_wiki.page(ent)
1132
+
1133
+ if page_py.exists():
1134
+ introduction = page_py.summary
1135
+ else:
1136
+ introduction = "No documents found."
1137
+
1138
+ wiki_intro.append([str(ent), introduction])
1139
+
1140
+ return wiki_intro, doc
1141
+
1142
+
1143
+ def Wikipediaretriever(claim, sources):
1144
+ #
1145
+ # 1. extract relevant wikipedia pages from wikipedia dumps
1146
+ if "Dump" in sources:
1147
+ wikipedia_page, _ = find_evidence_from_wikipedia_dumps(claim)
1148
+ else:
1149
+ wikipedia_page = find_evidence_from_wikipedia(claim)
1150
+ # wiki_intro, doc = relevant_wikipedia_API_retriever(claim)
1151
+
1152
+ # 2. extract relevant sentences from extracted wikipedia pages
1153
+ sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)
1154
+
1155
+ #
1156
+ results = []
1157
+ for i, (sent, title) in enumerate(zip(sents, titles)):
1158
+ metadata = dict()
1159
+ metadata['name'] = claim
1160
+ metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
1161
+ metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
1162
+ metadata['short_name'] = "Evidence {}".format(i + 1)
1163
+ metadata['page_number'] = ""
1164
+ metadata['query'] = sent
1165
+ metadata['title'] = title
1166
+ metadata['evidence'] = sent
1167
+ metadata['answer'] = ""
1168
+ metadata['page_content'] = "<b>Title</b>: " + str(metadata['title']) + "<br>" + "<b>Evidence</b>: " + metadata['evidence']
1169
+ page_content = f"""{metadata['page_content']}"""
1170
+
1171
+ results.append(Docs(metadata, page_content))
1172
+
1173
+ return results
1174
+
1175
+
1176
+ def log_on_azure(file, logs, azure_share_client):
1177
+ logs = json.dumps(logs)
1178
+ file_client = azure_share_client.get_file_client(file)
1179
+ file_client.upload_file(logs)
1180
+
1181
+
1182
+ def chat(claim, history, sources):
1183
+ evidence = []
1184
+ # if 'Google' in sources:
1185
+ # evidence = GoogleAPIretriever(query)
1186
+
1187
+ # if 'WikiPediaDumps' in sources:
1188
+ # evidence = WikipediaDumpsretriever(query)
1189
+
1190
+ if 'WikiPedia' in sources:
1191
+ evidence = Wikipediaretriever(claim, sources)
1192
+
1193
+ answer_set, answer_output = QAprediction(claim, evidence, sources)
1194
+
1195
+ docs_html = ""
1196
+ if len(evidence) > 0:
1197
+ docs_html = []
1198
+ for i, evi in enumerate(evidence, 1):
1199
+ docs_html.append(make_html_source(evi, i))
1200
+ docs_html = "".join(docs_html)
1201
+ else:
1202
+ print("No documents found")
1203
+
1204
+ url_of_evidence = ""
1205
+ output_language = "English"
1206
+ output_query = claim
1207
+ history[-1] = (claim, answer_set)
1208
+ history = [tuple(x) for x in history]
1209
+
1210
+ ############################################################
1211
+ evi_list = []
1212
+ for evi in evidence:
1213
+ title_str = evi.metadata['title']
1214
+ evi_str = evi.metadata['evidence']
1215
+ evi_list.append([title_str, evi_str])
1216
+
1217
+ try:
1218
+ # Log answer on Azure Blob Storage
1219
+ # IF AZURE_ISSAVE=TRUE, save the logs into the Azure share client.
1220
+ if os.environ["AZURE_ISSAVE"].upper() == "TRUE":
1221
+ timestamp = str(datetime.now().timestamp())
1222
+ # timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
1223
+ file = timestamp + ".json"
1224
+ logs = {
1225
+ "user_id": str(user_id),
1226
+ "claim": claim,
1227
+ "sources": sources,
1228
+ "evidence": evi_list,
1229
+ "url": url_of_evidence,
1230
+ "answer": answer_output,
1231
+ "time": timestamp,
1232
+ }
1233
+ log_on_azure(file, logs, azure_share_client)
1234
+ except Exception as e:
1235
+ print(f"Error logging on Azure Blob Storage: {e}")
1236
+ raise gr.Error(
1237
+ f"AVeriTeC Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
1238
+ ##########
1239
+
1240
+ return history, docs_html, output_query, output_language
1241
+
1242
+
1243
+ def main():
1244
+ init_prompt = """
1245
+ Hello, I am a fact-checking assistant designed to help you find appropriate evidence to predict the veracity of claims.
1246
+
1247
+ What do you want to fact-check?
1248
+ """
1249
+
1250
+ with gr.Blocks(title="AVeriTeC fact-checker", css="style.css", theme=theme, elem_id="main-component") as demo:
1251
+ with gr.Tab("AVeriTeC"):
1252
+ with gr.Row(elem_id="chatbot-row"):
1253
+ with gr.Column(scale=2):
1254
+ chatbot = gr.Chatbot(
1255
+ value=[(None, init_prompt)],
1256
+ show_copy_button=True, show_label=False, elem_id="chatbot", layout="panel",
1257
+ avatar_images=(None, "assets/averitec.png")
1258
+ ) # avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"),
1259
+
1260
+ with gr.Row(elem_id="input-message"):
1261
+ textbox = gr.Textbox(placeholder="Enter the claim you want to fact-check!", show_label=False,
1262
+ scale=7, lines=1, interactive=True, elem_id="input-textbox")
1263
+ # submit = gr.Button("",elem_id = "submit-button",scale = 1,interactive = True,icon = "https://static-00.iconduck.com/assets.00/settings-icon-2048x2046-cw28eevx.png")
1264
+
1265
+ with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
1266
+ with gr.Tabs() as tabs:
1267
+ with gr.TabItem("Examples", elem_id="tab-examples", id=0):
1268
+ examples_hidden = gr.Textbox(visible=False)
1269
+ first_key = list(CLAIMS_Type.keys())[0]
1270
+ dropdown_samples = gr.Dropdown(CLAIMS_Type.keys(), value=first_key, interactive=True,
1271
+ show_label=True,
1272
+ label="Select claim type",
1273
+ elem_id="dropdown-samples")
1274
+
1275
+ samples = []
1276
+ for i, key in enumerate(CLAIMS_Type.keys()):
1277
+ examples_visible = True if i == 0 else False
1278
+
1279
+ with gr.Row(visible=examples_visible) as group_examples:
1280
+ examples_questions = gr.Examples(
1281
+ CLAIMS_Type[key],
1282
+ [examples_hidden],
1283
+ examples_per_page=8,
1284
+ run_on_click=False,
1285
+ elem_id=f"examples{i}",
1286
+ api_name=f"examples{i}",
1287
+ # label = "Click on the example question or enter your own",
1288
+ # cache_examples=True,
1289
+ )
1290
+
1291
+ samples.append(group_examples)
1292
+
1293
+ with gr.Tab("Sources", elem_id="tab-citations", id=1):
1294
+ sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
1295
+ docs_textbox = gr.State("")
1296
+
1297
+ with gr.Tab("Configuration", elem_id="tab-config", id=2):
1298
+ gr.Markdown("Reminder: We currently only support fact-checking in English!")
1299
+
1300
+ # dropdown_sources = gr.Radio(
1301
+ # ["AVeriTeC", "WikiPediaDumps", "Google", "WikiPediaAPI"],
1302
+ # label="Select source",
1303
+ # value="WikiPediaAPI",
1304
+ # interactive=True,
1305
+ # )
1306
+
1307
+ dropdown_sources = gr.Radio(
1308
+ ["Google", "WikiPedia"],
1309
+ label="Select source",
1310
+ value="WikiPedia",
1311
+ interactive=True,
1312
+ )
1313
+
1314
+ dropdown_retriever = gr.Dropdown(
1315
+ ["BM25", "BM25s"],
1316
+ label="Select evidence retriever",
1317
+ multiselect=False,
1318
+ value="BM25",
1319
+ interactive=True,
1320
+ )
1321
+
1322
+ output_query = gr.Textbox(label="Query used for retrieval", show_label=True,
1323
+ elem_id="reformulated-query", lines=2, interactive=False)
1324
+ output_language = gr.Textbox(label="Language", show_label=True, elem_id="language", lines=1,
1325
+ interactive=False)
1326
+
1327
+ with gr.Tab("About", elem_classes="max-height other-tabs"):
1328
+ with gr.Row():
1329
+ with gr.Column(scale=1):
1330
+ gr.Markdown("See more info at [https://fever.ai/task.html](https://fever.ai/task.html)")
1331
+
1332
+ def start_chat(query, history):
1333
+ history = history + [(query, None)]
1334
+ history = [tuple(x) for x in history]
1335
+ return (gr.update(interactive=False), gr.update(selected=1), history)
1336
+
1337
+ def finish_chat():
1338
+ return (gr.update(interactive=True, value=""))
1339
+
1340
+ (textbox
1341
+ .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
1342
+ .then(chat, [textbox, chatbot, dropdown_sources],
1343
+ [chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_textbox")
1344
+ .then(finish_chat, None, [textbox], api_name="finish_chat_textbox")
1345
+ )
1346
+
1347
+ (examples_hidden
1348
+ .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False,
1349
+ api_name="start_chat_examples")
1350
+ .then(chat, [examples_hidden, chatbot, dropdown_sources],
1351
+ [chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_examples")
1352
+ .then(finish_chat, None, [textbox], api_name="finish_chat_examples")
1353
+ )
1354
+
1355
+ def change_sample_questions(key):
1356
+ index = list(CLAIMS_Type.keys()).index(key)
1357
+ visible_bools = [False] * len(samples)
1358
+ visible_bools[index] = True
1359
+ return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
1360
+
1361
+ dropdown_samples.change(change_sample_questions, dropdown_samples, samples)
1362
+ demo.queue()
1363
+
1364
+ demo.launch(share=True)
1365
+
1366
+
1367
+ if __name__ == "__main__":
1368
+ main()
drqa/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2017-present, Facebook, Inc.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import os
9
+ import sys
10
+ from pathlib import PosixPath
11
+
12
+ if sys.version_info < (3, 5):
13
+ raise RuntimeError('DrQA supports Python 3.5 or higher.')
14
+
15
+ DATA_DIR = (
16
+ os.getenv('DRQA_DATA') or
17
+ os.path.join(PosixPath(__file__).absolute().parents[1].as_posix(), 'data')
18
+ )
19
+
20
+ from . import tokenizers
21
+ from . import reader
22
+ from . import retriever
23
+ from . import pipeline
drqa/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (612 Bytes).
 
drqa/pipeline/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2017-present, Facebook, Inc.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import os
9
+ from ..tokenizers import CoreNLPTokenizer
10
+ from ..retriever import TfidfDocRanker
11
+ from ..retriever import DocDB
12
+ from .. import DATA_DIR
13
+
14
+ DEFAULTS = {
15
+ 'tokenizer': CoreNLPTokenizer,
16
+ 'ranker': TfidfDocRanker,
17
+ 'db': DocDB,
18
+ 'reader_model': os.path.join(DATA_DIR, 'reader/multitask.mdl'),
19
+ }
20
+
21
+
22
+ def set_default(key, value):
23
+ global DEFAULTS
24
+ DEFAULTS[key] = value
25
+
26
+
27
+ from .drqa import DrQA
drqa/pipeline/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (645 Bytes).
 
drqa/pipeline/__pycache__/drqa.cpython-38.pyc ADDED
Binary file (7.78 kB).
 
drqa/pipeline/drqa.py ADDED
@@ -0,0 +1,312 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2017-present, Facebook, Inc.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ """Full DrQA pipeline."""
8
+
9
+ import torch
10
+ import regex
11
+ import heapq
12
+ import math
13
+ import time
14
+ import logging
15
+
16
+ from multiprocessing import Pool as ProcessPool
17
+ from multiprocessing.util import Finalize
18
+
19
+ from ..reader.vector import batchify
20
+ from ..reader.data import ReaderDataset, SortedBatchSampler
21
+ from .. import reader
22
+ from .. import tokenizers
23
+ from . import DEFAULTS
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ # ------------------------------------------------------------------------------
29
+ # Multiprocessing functions to fetch and tokenize text
30
+ # ------------------------------------------------------------------------------
31
+
32
+ PROCESS_TOK = None
33
+ PROCESS_DB = None
34
+ PROCESS_CANDS = None
35
+
36
+
37
+ def init(tokenizer_class, tokenizer_opts, db_class, db_opts, candidates=None):
38
+ global PROCESS_TOK, PROCESS_DB, PROCESS_CANDS
39
+ PROCESS_TOK = tokenizer_class(**tokenizer_opts)
40
+ Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)
41
+ PROCESS_DB = db_class(**db_opts)
42
+ Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100)
43
+ PROCESS_CANDS = candidates
44
+
45
+
46
+ def fetch_text(doc_id):
47
+ global PROCESS_DB
48
+ return PROCESS_DB.get_doc_text(doc_id)
49
+
50
+
51
+ def tokenize_text(text):
52
+ global PROCESS_TOK
53
+ return PROCESS_TOK.tokenize(text)
54
+
55
+
56
+ # ------------------------------------------------------------------------------
57
+ # Main DrQA pipeline
58
+ # ------------------------------------------------------------------------------
59
+
60
+
61
+ class DrQA(object):
62
+ # Target size for squashing short paragraphs together.
63
+ # 0 = read every paragraph independently
64
+ # infty = read all paragraphs together
65
+ GROUP_LENGTH = 0
66
+
67
+ def __init__(
68
+ self,
69
+ reader_model=None,
70
+ embedding_file=None,
71
+ tokenizer=None,
72
+ fixed_candidates=None,
73
+ batch_size=128,
74
+ cuda=True,
75
+ data_parallel=False,
76
+ max_loaders=5,
77
+ num_workers=None,
78
+ db_config=None,
79
+ ranker_config=None
80
+ ):
81
+ """Initialize the pipeline.
82
+
83
+ Args:
84
+ reader_model: model file from which to load the DocReader.
85
+ embedding_file: if given, will expand DocReader dictionary to use
86
+ all available pretrained embeddings.
87
+ tokenizer: string option to specify tokenizer used on docs.
88
+ fixed_candidates: if given, all predictions will be constrated to
89
+ the set of candidates contained in the file. One entry per line.
90
+ batch_size: batch size when processing paragraphs.
91
+ cuda: whether to use the gpu.
92
+ data_parallel: whether to use multile gpus.
93
+ max_loaders: max number of async data loading workers when reading.
94
+ (default is fine).
95
+ num_workers: number of parallel CPU processes to use for tokenizing
96
+ and post processing resuls.
97
+ db_config: config for doc db.
98
+ ranker_config: config for ranker.
99
+ """
100
+ self.batch_size = batch_size
101
+ self.max_loaders = max_loaders
102
+ self.fixed_candidates = fixed_candidates is not None
103
+ self.cuda = cuda
104
+
105
+ logger.info('Initializing document ranker...')
106
+ ranker_config = ranker_config or {}
107
+ ranker_class = ranker_config.get('class', DEFAULTS['ranker'])
108
+ ranker_opts = ranker_config.get('options', {})
109
+ self.ranker = ranker_class(**ranker_opts)
110
+
111
+ logger.info('Initializing document reader...')
112
+ reader_model = reader_model or DEFAULTS['reader_model']
113
+ self.reader = reader.DocReader.load(reader_model, normalize=False)
114
+ if embedding_file:
115
+ logger.info('Expanding dictionary...')
116
+ words = reader.utils.index_embedding_words(embedding_file)
117
+ added = self.reader.expand_dictionary(words)
118
+ self.reader.load_embeddings(added, embedding_file)
119
+ if cuda:
120
+ self.reader.cuda()
121
+ if data_parallel:
122
+ self.reader.parallelize()
123
+
124
+ if not tokenizer:
125
+ tok_class = DEFAULTS['tokenizer']
126
+ else:
127
+ tok_class = tokenizers.get_class(tokenizer)
128
+ annotators = tokenizers.get_annotators_for_model(self.reader)
129
+ tok_opts = {'annotators': annotators}
130
+
131
+ # ElasticSearch is also used as backend if used as ranker
132
+ if hasattr(self.ranker, 'es'):
133
+ db_config = ranker_config
134
+ db_class = ranker_class
135
+ db_opts = ranker_opts
136
+ else:
137
+ db_config = db_config or {}
138
+ db_class = db_config.get('class', DEFAULTS['db'])
139
+ db_opts = db_config.get('options', {})
140
+
141
+ logger.info('Initializing tokenizers and document retrievers...')
142
+ self.num_workers = num_workers
143
+ self.processes = ProcessPool(
144
+ num_workers,
145
+ initializer=init,
146
+ initargs=(tok_class, tok_opts, db_class, db_opts, fixed_candidates)
147
+ )
148
+
149
+ def _split_doc(self, doc):
150
+ """Given a doc, split it into chunks (by paragraph)."""
151
+ curr = []
152
+ curr_len = 0
153
+ for split in regex.split(r'\n+', doc):
154
+ split = split.strip()
155
+ if len(split) == 0:
156
+ continue
157
+ # Maybe group paragraphs together until we hit a length limit
158
+ if len(curr) > 0 and curr_len + len(split) > self.GROUP_LENGTH:
159
+ yield ' '.join(curr)
160
+ curr = []
161
+ curr_len = 0
162
+ curr.append(split)
163
+ curr_len += len(split)
164
+ if len(curr) > 0:
165
+ yield ' '.join(curr)
166
+
167
+ def _get_loader(self, data, num_loaders):
168
+ """Return a pytorch data iterator for provided examples."""
169
+ dataset = ReaderDataset(data, self.reader)
170
+ sampler = SortedBatchSampler(
171
+ dataset.lengths(),
172
+ self.batch_size,
173
+ shuffle=False
174
+ )
175
+ loader = torch.utils.data.DataLoader(
176
+ dataset,
177
+ batch_size=self.batch_size,
178
+ sampler=sampler,
179
+ num_workers=num_loaders,
180
+ collate_fn=batchify,
181
+ pin_memory=self.cuda,
182
+ )
183
+ return loader
184
+
185
+ def process(self, query, candidates=None, top_n=1, n_docs=5,
186
+ return_context=False):
187
+ """Run a single query."""
188
+ predictions = self.process_batch(
189
+ [query], [candidates] if candidates else None,
190
+ top_n, n_docs, return_context
191
+ )
192
+ return predictions[0]
193
+
194
+ def process_batch(self, queries, candidates=None, top_n=1, n_docs=5,
195
+ return_context=False):
196
+ """Run a batch of queries (more efficient)."""
197
+ t0 = time.time()
198
+ logger.info('Processing %d queries...' % len(queries))
199
+ logger.info('Retrieving top %d docs...' % n_docs)
200
+
201
+ # Rank documents for queries.
202
+ if len(queries) == 1:
203
+ ranked = [self.ranker.closest_docs(queries[0], k=n_docs)]
204
+ else:
205
+ ranked = self.ranker.batch_closest_docs(
206
+ queries, k=n_docs, num_workers=self.num_workers
207
+ )
208
+ all_docids, all_doc_scores = zip(*ranked)
209
+
210
+ # Flatten document ids and retrieve text from database.
211
+ # We remove duplicates for processing efficiency.
212
+ flat_docids = list({d for docids in all_docids for d in docids})
213
+ did2didx = {did: didx for didx, did in enumerate(flat_docids)}
214
+ doc_texts = self.processes.map(fetch_text, flat_docids)
215
+
216
+ # Split and flatten documents. Maintain a mapping from doc (index in
217
+ # flat list) to split (index in flat list).
218
+ flat_splits = []
219
+ didx2sidx = []
220
+ for text in doc_texts:
221
+ splits = self._split_doc(text)
222
+ didx2sidx.append([len(flat_splits), -1])
223
+ for split in splits:
224
+ flat_splits.append(split)
225
+ didx2sidx[-1][1] = len(flat_splits)
226
+
227
+ # Push through the tokenizers as fast as possible.
228
+ q_tokens = self.processes.map_async(tokenize_text, queries)
229
+ s_tokens = self.processes.map_async(tokenize_text, flat_splits)
230
+ q_tokens = q_tokens.get()
231
+ s_tokens = s_tokens.get()
232
+
233
+ # Group into structured example inputs. Examples' ids represent
234
+ # mappings to their question, document, and split ids.
235
+ examples = []
236
+ for qidx in range(len(queries)):
237
+ for rel_didx, did in enumerate(all_docids[qidx]):
238
+ start, end = didx2sidx[did2didx[did]]
239
+ for sidx in range(start, end):
240
+ if (len(q_tokens[qidx].words()) > 0 and
241
+ len(s_tokens[sidx].words()) > 0):
242
+ examples.append({
243
+ 'id': (qidx, rel_didx, sidx),
244
+ 'question': q_tokens[qidx].words(),
245
+ 'qlemma': q_tokens[qidx].lemmas(),
246
+ 'document': s_tokens[sidx].words(),
247
+ 'lemma': s_tokens[sidx].lemmas(),
248
+ 'pos': s_tokens[sidx].pos(),
249
+ 'ner': s_tokens[sidx].entities(),
250
+ })
251
+
252
+ logger.info('Reading %d paragraphs...' % len(examples))
253
+
254
+ # Push all examples through the document reader.
255
+ # We decode argmax start/end indices asychronously on CPU.
256
+ result_handles = []
257
+ num_loaders = min(self.max_loaders, math.floor(len(examples) / 1e3))
258
+ for batch in self._get_loader(examples, num_loaders):
259
+ if candidates or self.fixed_candidates:
260
+ batch_cands = []
261
+ for ex_id in batch[-1]:
262
+ batch_cands.append({
263
+ 'input': s_tokens[ex_id[2]],
264
+ 'cands': candidates[ex_id[0]] if candidates else None
265
+ })
266
+ handle = self.reader.predict(
267
+ batch, batch_cands, async_pool=self.processes
268
+ )
269
+ else:
270
+ handle = self.reader.predict(batch, async_pool=self.processes)
271
+ result_handles.append((handle, batch[-1], batch[0].size(0)))
272
+
273
+ # Iterate through the predictions, and maintain priority queues for
274
+ # top scored answers for each question in the batch.
275
+ queues = [[] for _ in range(len(queries))]
276
+ for result, ex_ids, batch_size in result_handles:
277
+ s, e, score = result.get()
278
+ for i in range(batch_size):
279
+ # We take the top prediction per split.
280
+ if len(score[i]) > 0:
281
+ item = (score[i][0], ex_ids[i], s[i][0], e[i][0])
282
+ queue = queues[ex_ids[i][0]]
283
+ if len(queue) < top_n:
284
+ heapq.heappush(queue, item)
285
+ else:
286
+ heapq.heappushpop(queue, item)
287
+
288
+ # Arrange final top prediction data.
289
+ all_predictions = []
290
+ for queue in queues:
291
+ predictions = []
292
+ while len(queue) > 0:
293
+ score, (qidx, rel_didx, sidx), s, e = heapq.heappop(queue)
294
+ prediction = {
295
+ 'doc_id': all_docids[qidx][rel_didx],
296
+ 'span': s_tokens[sidx].slice(s, e + 1).untokenize(),
297
+ 'doc_score': float(all_doc_scores[qidx][rel_didx]),
298
+ 'span_score': float(score),
299
+ }
300
+ if return_context:
301
+ prediction['context'] = {
302
+ 'text': s_tokens[sidx].untokenize(),
303
+ 'start': s_tokens[sidx].offsets()[s][0],
304
+ 'end': s_tokens[sidx].offsets()[e][1],
305
+ }
306
+ predictions.append(prediction)
307
+ all_predictions.append(predictions[-1::-1])
308
+
309
+ logger.info('Processed %d queries in %.4f (s)' %
310
+ (len(queries), time.time() - t0))
311
+
312
+ return all_predictions
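
Note (not part of the diff above): the vendored pipeline is driven through `DrQA.process()`, shown in the file just added. The sketch below is an illustrative usage example under the assumption that the DrQA data files (TF-IDF ranker, document DB, reader model) and the default CoreNLP tokenizer are available locally; the query string is a placeholder.

```python
# Illustrative only. Assumes DrQA data files exist under drqa's DATA_DIR.
from drqa import pipeline

drqa = pipeline.DrQA(
    reader_model=None,   # None falls back to DEFAULTS['reader_model']
    cuda=False,
    num_workers=2,
)

predictions = drqa.process(
    "Who verified the claim?", top_n=1, n_docs=5, return_context=True
)
for p in predictions:
    # Each prediction carries the source document id, the answer span and scores.
    print(p['doc_id'], p['span'], p['doc_score'], p['span_score'])
```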
drqa/reader/__init__.py ADDED
@@ -0,0 +1,28 @@
+ #!/usr/bin/env python3
+ # Copyright 2017-present, Facebook, Inc.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+ from ..tokenizers import CoreNLPTokenizer
+ from .. import DATA_DIR
+
+
+ DEFAULTS = {
+     'tokenizer': CoreNLPTokenizer,
+     'model': os.path.join(DATA_DIR, 'reader/single.mdl'),
+ }
+
+
+ def set_default(key, value):
+     global DEFAULTS
+     DEFAULTS[key] = value
+
+ from .model import DocReader
+ from .predictor import Predictor
+ from . import config
+ from . import vector
+ from . import data
+ from . import utils
drqa/reader/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (697 Bytes).
 
drqa/reader/__pycache__/config.cpython-38.pyc ADDED
Binary file (4.52 kB).
 
drqa/reader/__pycache__/data.cpython-38.pyc ADDED
Binary file (4.99 kB).
 
drqa/reader/__pycache__/layers.cpython-38.pyc ADDED
Binary file (7.73 kB).
 
drqa/reader/__pycache__/model.cpython-38.pyc ADDED
Binary file (13.3 kB).
 
drqa/reader/__pycache__/predictor.cpython-38.pyc ADDED
Binary file (4.22 kB).
 
drqa/reader/__pycache__/rnn_reader.cpython-38.pyc ADDED
Binary file (2.9 kB).
 
drqa/reader/__pycache__/utils.cpython-38.pyc ADDED
Binary file (8.67 kB).
 
drqa/reader/__pycache__/vector.cpython-38.pyc ADDED
Binary file (4.71 kB).
 
drqa/reader/config.py ADDED
@@ -0,0 +1,128 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2017-present, Facebook, Inc.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ """Model architecture/optimization options for DrQA document reader."""
8
+
9
+ import argparse
10
+ import logging
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # Index of arguments concerning the core model architecture
15
+ MODEL_ARCHITECTURE = {
16
+ 'model_type', 'embedding_dim', 'hidden_size', 'doc_layers',
17
+ 'question_layers', 'rnn_type', 'concat_rnn_layers', 'question_merge',
18
+ 'use_qemb', 'use_in_question', 'use_pos', 'use_ner', 'use_lemma', 'use_tf'
19
+ }
20
+
21
+ # Index of arguments concerning the model optimizer/training
22
+ MODEL_OPTIMIZER = {
23
+ 'fix_embeddings', 'optimizer', 'learning_rate', 'momentum', 'weight_decay',
24
+ 'rnn_padding', 'dropout_rnn', 'dropout_rnn_output', 'dropout_emb',
25
+ 'max_len', 'grad_clipping', 'tune_partial'
26
+ }
27
+
28
+
29
+ def str2bool(v):
30
+ return v.lower() in ('yes', 'true', 't', '1', 'y')
31
+
32
+
33
+ def add_model_args(parser):
34
+ parser.register('type', 'bool', str2bool)
35
+
36
+ # Model architecture
37
+ model = parser.add_argument_group('DrQA Reader Model Architecture')
38
+ model.add_argument('--model-type', type=str, default='rnn',
39
+ help='Model architecture type')
40
+ model.add_argument('--embedding-dim', type=int, default=300,
41
+ help='Embedding size if embedding_file is not given')
42
+ model.add_argument('--hidden-size', type=int, default=128,
43
+ help='Hidden size of RNN units')
44
+ model.add_argument('--doc-layers', type=int, default=3,
45
+ help='Number of encoding layers for document')
46
+ model.add_argument('--question-layers', type=int, default=3,
47
+ help='Number of encoding layers for question')
48
+ model.add_argument('--rnn-type', type=str, default='lstm',
49
+ help='RNN type: LSTM, GRU, or RNN')
50
+
51
+ # Model specific details
52
+ detail = parser.add_argument_group('DrQA Reader Model Details')
53
+ detail.add_argument('--concat-rnn-layers', type='bool', default=True,
54
+ help='Combine hidden states from each encoding layer')
55
+ detail.add_argument('--question-merge', type=str, default='self_attn',
56
+ help='The way of computing the question representation')
57
+ detail.add_argument('--use-qemb', type='bool', default=True,
58
+ help='Whether to use weighted question embeddings')
59
+ detail.add_argument('--use-in-question', type='bool', default=True,
60
+ help='Whether to use in_question_* features')
61
+ detail.add_argument('--use-pos', type='bool', default=True,
62
+ help='Whether to use pos features')
63
+ detail.add_argument('--use-ner', type='bool', default=True,
64
+ help='Whether to use ner features')
65
+ detail.add_argument('--use-lemma', type='bool', default=True,
66
+ help='Whether to use lemma features')
67
+ detail.add_argument('--use-tf', type='bool', default=True,
68
+ help='Whether to use term frequency features')
69
+
70
+ # Optimization details
71
+ optim = parser.add_argument_group('DrQA Reader Optimization')
72
+ optim.add_argument('--dropout-emb', type=float, default=0.4,
73
+ help='Dropout rate for word embeddings')
74
+ optim.add_argument('--dropout-rnn', type=float, default=0.4,
75
+ help='Dropout rate for RNN states')
76
+ optim.add_argument('--dropout-rnn-output', type='bool', default=True,
77
+ help='Whether to dropout the RNN output')
78
+ optim.add_argument('--optimizer', type=str, default='adamax',
79
+ help='Optimizer: sgd or adamax')
80
+ optim.add_argument('--learning-rate', type=float, default=0.1,
81
+ help='Learning rate for SGD only')
82
+ optim.add_argument('--grad-clipping', type=float, default=10,
83
+ help='Gradient clipping')
84
+ optim.add_argument('--weight-decay', type=float, default=0,
85
+ help='Weight decay factor')
86
+ optim.add_argument('--momentum', type=float, default=0,
87
+ help='Momentum factor')
88
+ optim.add_argument('--fix-embeddings', type='bool', default=True,
89
+ help='Keep word embeddings fixed (use pretrained)')
90
+ optim.add_argument('--tune-partial', type=int, default=0,
91
+ help='Backprop through only the top N question words')
92
+ optim.add_argument('--rnn-padding', type='bool', default=False,
93
+ help='Explicitly account for padding in RNN encoding')
94
+ optim.add_argument('--max-len', type=int, default=15,
95
+ help='The max span allowed during decoding')
96
+
97
+
98
+ def get_model_args(args):
99
+ """Filter args for model ones.
100
+
101
+ From a args Namespace, return a new Namespace with *only* the args specific
102
+ to the model architecture or optimization. (i.e. the ones defined here.)
103
+ """
104
+ global MODEL_ARCHITECTURE, MODEL_OPTIMIZER
105
+ required_args = MODEL_ARCHITECTURE | MODEL_OPTIMIZER
106
+ arg_values = {k: v for k, v in vars(args).items() if k in required_args}
107
+ return argparse.Namespace(**arg_values)
108
+
109
+
110
+ def override_model_args(old_args, new_args):
111
+ """Set args to new parameters.
112
+
113
+ Decide which model args to keep and which to override when resolving a set
114
+ of saved args and new args.
115
+
116
+ We keep the new optimation, but leave the model architecture alone.
117
+ """
118
+ global MODEL_OPTIMIZER
119
+ old_args, new_args = vars(old_args), vars(new_args)
120
+ for k in old_args.keys():
121
+ if k in new_args and old_args[k] != new_args[k]:
122
+ if k in MODEL_OPTIMIZER:
123
+ logger.info('Overriding saved %s: %s --> %s' %
124
+ (k, old_args[k], new_args[k]))
125
+ old_args[k] = new_args[k]
126
+ else:
127
+ logger.info('Keeping saved %s: %s' % (k, old_args[k]))
128
+ return argparse.Namespace(**old_args)
drqa/reader/data.py ADDED
@@ -0,0 +1,131 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2017-present, Facebook, Inc.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ """Data processing/loading helpers."""
8
+
9
+ import numpy as np
10
+ import logging
11
+ import unicodedata
12
+
13
+ from torch.utils.data import Dataset
14
+ from torch.utils.data.sampler import Sampler
15
+ from .vector import vectorize
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ # ------------------------------------------------------------------------------
21
+ # Dictionary class for tokens.
22
+ # ------------------------------------------------------------------------------
23
+
24
+
25
+ class Dictionary(object):
26
+ NULL = '<NULL>'
27
+ UNK = '<UNK>'
28
+ START = 2
29
+
30
+ @staticmethod
31
+ def normalize(token):
32
+ return unicodedata.normalize('NFD', token)
33
+
34
+ def __init__(self):
35
+ self.tok2ind = {self.NULL: 0, self.UNK: 1}
36
+ self.ind2tok = {0: self.NULL, 1: self.UNK}
37
+
38
+ def __len__(self):
39
+ return len(self.tok2ind)
40
+
41
+ def __iter__(self):
42
+ return iter(self.tok2ind)
43
+
44
+ def __contains__(self, key):
45
+ if type(key) == int:
46
+ return key in self.ind2tok
47
+ elif type(key) == str:
48
+ return self.normalize(key) in self.tok2ind
49
+
50
+ def __getitem__(self, key):
51
+ if type(key) == int:
52
+ return self.ind2tok.get(key, self.UNK)
53
+ if type(key) == str:
54
+ return self.tok2ind.get(self.normalize(key),
55
+ self.tok2ind.get(self.UNK))
56
+
57
+ def __setitem__(self, key, item):
58
+ if type(key) == int and type(item) == str:
59
+ self.ind2tok[key] = item
60
+ elif type(key) == str and type(item) == int:
61
+ self.tok2ind[key] = item
62
+ else:
63
+ raise RuntimeError('Invalid (key, item) types.')
64
+
65
+ def add(self, token):
66
+ token = self.normalize(token)
67
+ if token not in self.tok2ind:
68
+ index = len(self.tok2ind)
69
+ self.tok2ind[token] = index
70
+ self.ind2tok[index] = token
71
+
72
+ def tokens(self):
73
+ """Get dictionary tokens.
74
+
75
+ Return all the words indexed by this dictionary, except for special
76
+ tokens.
77
+ """
78
+ tokens = [k for k in self.tok2ind.keys()
79
+ if k not in {'<NULL>', '<UNK>'}]
80
+ return tokens
81
+
82
+
83
+ # ------------------------------------------------------------------------------
84
+ # PyTorch dataset class for SQuAD (and SQuAD-like) data.
85
+ # ------------------------------------------------------------------------------
86
+
87
+
88
+ class ReaderDataset(Dataset):
89
+
90
+ def __init__(self, examples, model, single_answer=False):
91
+ self.model = model
92
+ self.examples = examples
93
+ self.single_answer = single_answer
94
+
95
+ def __len__(self):
96
+ return len(self.examples)
97
+
98
+ def __getitem__(self, index):
99
+ return vectorize(self.examples[index], self.model, self.single_answer)
100
+
101
+ def lengths(self):
102
+ return [(len(ex['document']), len(ex['question']))
103
+ for ex in self.examples]
104
+
105
+
106
+ # ------------------------------------------------------------------------------
107
+ # PyTorch sampler returning batched of sorted lengths (by doc and question).
108
+ # ------------------------------------------------------------------------------
109
+
110
+
111
+ class SortedBatchSampler(Sampler):
112
+
113
+ def __init__(self, lengths, batch_size, shuffle=True):
114
+ self.lengths = lengths
115
+ self.batch_size = batch_size
116
+ self.shuffle = shuffle
117
+
118
+ def __iter__(self):
119
+ lengths = np.array(
120
+ [(-l[0], -l[1], np.random.random()) for l in self.lengths],
121
+ dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float_)]
122
+ )
123
+ indices = np.argsort(lengths, order=('l1', 'l2', 'rand'))
124
+ batches = [indices[i:i + self.batch_size]
125
+ for i in range(0, len(indices), self.batch_size)]
126
+ if self.shuffle:
127
+ np.random.shuffle(batches)
128
+ return iter([i for batch in batches for i in batch])
129
+
130
+ def __len__(self):
131
+ return len(self.lengths)
drqa/reader/layers.py ADDED
@@ -0,0 +1,311 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2017-present, Facebook, Inc.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ """Definitions of model layers/NN modules"""
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ # ------------------------------------------------------------------------------
15
+ # Modules
16
+ # ------------------------------------------------------------------------------
17
+
18
+
19
+ class StackedBRNN(nn.Module):
20
+ """Stacked Bi-directional RNNs.
21
+
22
+ Differs from standard PyTorch library in that it has the option to save
23
+ and concat the hidden states between layers. (i.e. the output hidden size
24
+ for each sequence input is num_layers * hidden_size).
25
+ """
26
+
27
+ def __init__(self, input_size, hidden_size, num_layers,
28
+ dropout_rate=0, dropout_output=False, rnn_type=nn.LSTM,
29
+ concat_layers=False, padding=False):
30
+ super(StackedBRNN, self).__init__()
31
+ self.padding = padding
32
+ self.dropout_output = dropout_output
33
+ self.dropout_rate = dropout_rate
34
+ self.num_layers = num_layers
35
+ self.concat_layers = concat_layers
36
+ self.rnns = nn.ModuleList()
37
+ for i in range(num_layers):
38
+ input_size = input_size if i == 0 else 2 * hidden_size
39
+ self.rnns.append(rnn_type(input_size, hidden_size,
40
+ num_layers=1,
41
+ bidirectional=True))
42
+
43
+ def forward(self, x, x_mask):
44
+ """Encode either padded or non-padded sequences.
45
+
46
+ Can choose to either handle or ignore variable length sequences.
47
+ Always handle padding in eval.
48
+
49
+ Args:
50
+ x: batch * len * hdim
51
+ x_mask: batch * len (1 for padding, 0 for true)
52
+ Output:
53
+ x_encoded: batch * len * hdim_encoded
54
+ """
55
+ if x_mask.data.sum() == 0:
56
+ # No padding necessary.
57
+ output = self._forward_unpadded(x, x_mask)
58
+ elif self.padding or not self.training:
59
+ # Pad if we care or if its during eval.
60
+ output = self._forward_padded(x, x_mask)
61
+ else:
62
+ # We don't care.
63
+ output = self._forward_unpadded(x, x_mask)
64
+
65
+ return output.contiguous()
66
+
67
+ def _forward_unpadded(self, x, x_mask):
68
+ """Faster encoding that ignores any padding."""
69
+ # Transpose batch and sequence dims
70
+ x = x.transpose(0, 1)
71
+
72
+ # Encode all layers
73
+ outputs = [x]
74
+ for i in range(self.num_layers):
75
+ rnn_input = outputs[-1]
76
+
77
+ # Apply dropout to hidden input
78
+ if self.dropout_rate > 0:
79
+ rnn_input = F.dropout(rnn_input,
80
+ p=self.dropout_rate,
81
+ training=self.training)
82
+ # Forward
83
+ rnn_output = self.rnns[i](rnn_input)[0]
84
+ outputs.append(rnn_output)
85
+
86
+ # Concat hidden layers
87
+ if self.concat_layers:
88
+ output = torch.cat(outputs[1:], 2)
89
+ else:
90
+ output = outputs[-1]
91
+
92
+ # Transpose back
93
+ output = output.transpose(0, 1)
94
+
95
+ # Dropout on output layer
96
+ if self.dropout_output and self.dropout_rate > 0:
97
+ output = F.dropout(output,
98
+ p=self.dropout_rate,
99
+ training=self.training)
100
+ return output
101
+
102
+ def _forward_padded(self, x, x_mask):
103
+ """Slower (significantly), but more precise, encoding that handles
104
+ padding.
105
+ """
106
+ # Compute sorted sequence lengths
107
+ lengths = x_mask.data.eq(0).long().sum(1).squeeze()
108
+ _, idx_sort = torch.sort(lengths, dim=0, descending=True)
109
+ _, idx_unsort = torch.sort(idx_sort, dim=0)
110
+ lengths = list(lengths[idx_sort])
111
+
112
+ # Sort x
113
+ x = x.index_select(0, idx_sort)
114
+
115
+ # Transpose batch and sequence dims
116
+ x = x.transpose(0, 1)
117
+
118
+ # Pack it up
119
+ rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths)
120
+
121
+ # Encode all layers
122
+ outputs = [rnn_input]
123
+ for i in range(self.num_layers):
124
+ rnn_input = outputs[-1]
125
+
126
+ # Apply dropout to input
127
+ if self.dropout_rate > 0:
128
+ dropout_input = F.dropout(rnn_input.data,
129
+ p=self.dropout_rate,
130
+ training=self.training)
131
+ rnn_input = nn.utils.rnn.PackedSequence(dropout_input,
132
+ rnn_input.batch_sizes)
133
+ outputs.append(self.rnns[i](rnn_input)[0])
134
+
135
+ # Unpack everything
136
+ for i, o in enumerate(outputs[1:], 1):
137
+ outputs[i] = nn.utils.rnn.pad_packed_sequence(o)[0]
138
+
139
+ # Concat hidden layers or take final
140
+ if self.concat_layers:
141
+ output = torch.cat(outputs[1:], 2)
142
+ else:
143
+ output = outputs[-1]
144
+
145
+ # Transpose and unsort
146
+ output = output.transpose(0, 1)
147
+ output = output.index_select(0, idx_unsort)
148
+
149
+ # Pad up to original batch sequence length
150
+ if output.size(1) != x_mask.size(1):
151
+ padding = torch.zeros(output.size(0),
152
+ x_mask.size(1) - output.size(1),
153
+ output.size(2)).type(output.data.type())
154
+ output = torch.cat([output, padding], 1)
155
+
156
+ # Dropout on output layer
157
+ if self.dropout_output and self.dropout_rate > 0:
158
+ output = F.dropout(output,
159
+ p=self.dropout_rate,
160
+ training=self.training)
161
+ return output
162
+
163
+
164
+ class SeqAttnMatch(nn.Module):
165
+ """Given sequences X and Y, match sequence Y to each element in X.
166
+
167
+ * o_i = sum(alpha_j * y_j) for i in X
168
+ * alpha_j = softmax(y_j * x_i)
169
+ """
170
+
171
+ def __init__(self, input_size, identity=False):
172
+ super(SeqAttnMatch, self).__init__()
173
+ if not identity:
174
+ self.linear = nn.Linear(input_size, input_size)
175
+ else:
176
+ self.linear = None
177
+
178
+ def forward(self, x, y, y_mask):
179
+ """
180
+ Args:
181
+ x: batch * len1 * hdim
182
+ y: batch * len2 * hdim
183
+ y_mask: batch * len2 (1 for padding, 0 for true)
184
+ Output:
185
+ matched_seq: batch * len1 * hdim
186
+ """
187
+ # Project vectors
188
+ if self.linear:
189
+ x_proj = self.linear(x.view(-1, x.size(2))).view(x.size())
190
+ x_proj = F.relu(x_proj)
191
+ y_proj = self.linear(y.view(-1, y.size(2))).view(y.size())
192
+ y_proj = F.relu(y_proj)
193
+ else:
194
+ x_proj = x
195
+ y_proj = y
196
+
197
+ # Compute scores
198
+ scores = x_proj.bmm(y_proj.transpose(2, 1))
199
+
200
+ # Mask padding
201
+ y_mask = y_mask.unsqueeze(1).expand(scores.size())
202
+ scores.data.masked_fill_(y_mask.data, -float('inf'))
203
+
204
+ # Normalize with softmax
205
+ alpha_flat = F.softmax(scores.view(-1, y.size(1)), dim=-1)
206
+ alpha = alpha_flat.view(-1, x.size(1), y.size(1))
207
+
208
+ # Take weighted average
209
+ matched_seq = alpha.bmm(y)
210
+ return matched_seq
211
+
212
+
213
+ class BilinearSeqAttn(nn.Module):
214
+ """A bilinear attention layer over a sequence X w.r.t y:
215
+
216
+ * o_i = softmax(x_i'Wy) for x_i in X.
217
+
218
+ Optionally don't normalize output weights.
219
+ """
220
+
221
+ def __init__(self, x_size, y_size, identity=False, normalize=True):
222
+ super(BilinearSeqAttn, self).__init__()
223
+ self.normalize = normalize
224
+
225
+ # If identity is true, we just use a dot product without transformation.
226
+ if not identity:
227
+ self.linear = nn.Linear(y_size, x_size)
228
+ else:
229
+ self.linear = None
230
+
231
+ def forward(self, x, y, x_mask):
232
+ """
233
+ Args:
234
+ x: batch * len * hdim1
235
+ y: batch * hdim2
236
+ x_mask: batch * len (1 for padding, 0 for true)
237
+ Output:
238
+ alpha = batch * len
239
+ """
240
+ Wy = self.linear(y) if self.linear is not None else y
241
+ xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2)
242
+ xWy.data.masked_fill_(x_mask.data, -float('inf'))
243
+ if self.normalize:
244
+ if self.training:
245
+ # In training we output log-softmax for NLL
246
+ alpha = F.log_softmax(xWy, dim=-1)
247
+ else:
248
+ # ...Otherwise 0-1 probabilities
249
+ alpha = F.softmax(xWy, dim=-1)
250
+ else:
251
+ alpha = xWy.exp()
252
+ return alpha
253
+
254
+
255
+ class LinearSeqAttn(nn.Module):
256
+ """Self attention over a sequence:
257
+
258
+ * o_i = softmax(Wx_i) for x_i in X.
259
+ """
260
+
261
+ def __init__(self, input_size):
262
+ super(LinearSeqAttn, self).__init__()
263
+ self.linear = nn.Linear(input_size, 1)
264
+
265
+ def forward(self, x, x_mask):
266
+ """
267
+ Args:
268
+ x: batch * len * hdim
269
+ x_mask: batch * len (1 for padding, 0 for true)
270
+ Output:
271
+ alpha: batch * len
272
+ """
273
+ x_flat = x.view(-1, x.size(-1))
274
+ scores = self.linear(x_flat).view(x.size(0), x.size(1))
275
+ scores.data.masked_fill_(x_mask.data, -float('inf'))
276
+ alpha = F.softmax(scores, dim=-1)
277
+ return alpha
278
+
279
+
280
+ # ------------------------------------------------------------------------------
281
+ # Functional
282
+ # ------------------------------------------------------------------------------
283
+
284
+
285
+ def uniform_weights(x, x_mask):
286
+ """Return uniform weights over non-masked x (a sequence of vectors).
287
+
288
+ Args:
289
+ x: batch * len * hdim
290
+ x_mask: batch * len (1 for padding, 0 for true)
291
+ Output:
292
+ x_avg: batch * hdim
293
+ """
294
+ alpha = torch.ones(x.size(0), x.size(1))
295
+ if x.data.is_cuda:
296
+ alpha = alpha.cuda()
297
+ alpha = alpha * x_mask.eq(0).float()
298
+ alpha = alpha / alpha.sum(1).expand(alpha.size())
299
+ return alpha
300
+
301
+
302
+ def weighted_avg(x, weights):
303
+ """Return a weighted average of x (a sequence of vectors).
304
+
305
+ Args:
306
+ x: batch * len * hdim
307
+ weights: batch * len, sum(dim = 1) = 1
308
+ Output:
309
+ x_avg: batch * hdim
310
+ """
311
+ return weights.unsqueeze(1).bmm(x).squeeze(1)
drqa/reader/model.py ADDED
@@ -0,0 +1,482 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2017-present, Facebook, Inc.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ """DrQA Document Reader model"""
8
+
9
+ import torch
10
+ import torch.optim as optim
11
+ import torch.nn.functional as F
12
+ import numpy as np
13
+ import logging
14
+ import copy
15
+
16
+ from .config import override_model_args
17
+ from .rnn_reader import RnnDocReader
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class DocReader(object):
23
+ """High level model that handles intializing the underlying network
24
+ architecture, saving, updating examples, and predicting examples.
25
+ """
26
+
27
+ # --------------------------------------------------------------------------
28
+ # Initialization
29
+ # --------------------------------------------------------------------------
30
+
31
+ def __init__(self, args, word_dict, feature_dict,
32
+ state_dict=None, normalize=True):
33
+ # Book-keeping.
34
+ self.args = args
35
+ self.word_dict = word_dict
36
+ self.args.vocab_size = len(word_dict)
37
+ self.feature_dict = feature_dict
38
+ self.args.num_features = len(feature_dict)
39
+ self.updates = 0
40
+ self.use_cuda = False
41
+ self.parallel = False
42
+
43
+ # Building network. If normalize if false, scores are not normalized
44
+ # 0-1 per paragraph (no softmax).
45
+ if args.model_type == 'rnn':
46
+ self.network = RnnDocReader(args, normalize)
47
+ else:
48
+ raise RuntimeError('Unsupported model: %s' % args.model_type)
49
+
50
+ # Load saved state
51
+ if state_dict:
52
+ # Load buffer separately
53
+ if 'fixed_embedding' in state_dict:
54
+ fixed_embedding = state_dict.pop('fixed_embedding')
55
+ self.network.load_state_dict(state_dict)
56
+ self.network.register_buffer('fixed_embedding', fixed_embedding)
57
+ else:
58
+ self.network.load_state_dict(state_dict)
59
+
60
+ def expand_dictionary(self, words):
61
+ """Add words to the DocReader dictionary if they do not exist. The
62
+ underlying embedding matrix is also expanded (with random embeddings).
63
+
64
+ Args:
65
+ words: iterable of tokens to add to the dictionary.
66
+ Output:
67
+ added: set of tokens that were added.
68
+ """
69
+ to_add = {self.word_dict.normalize(w) for w in words
70
+ if w not in self.word_dict}
71
+
72
+ # Add words to dictionary and expand embedding layer
73
+ if len(to_add) > 0:
74
+ logger.info('Adding %d new words to dictionary...' % len(to_add))
75
+ for w in to_add:
76
+ self.word_dict.add(w)
77
+ self.args.vocab_size = len(self.word_dict)
78
+ logger.info('New vocab size: %d' % len(self.word_dict))
79
+
80
+ old_embedding = self.network.embedding.weight.data
81
+ self.network.embedding = torch.nn.Embedding(self.args.vocab_size,
82
+ self.args.embedding_dim,
83
+ padding_idx=0)
84
+ new_embedding = self.network.embedding.weight.data
85
+ new_embedding[:old_embedding.size(0)] = old_embedding
86
+
87
+ # Return added words
88
+ return to_add
89
+
90
+ def load_embeddings(self, words, embedding_file):
91
+ """Load pretrained embeddings for a given list of words, if they exist.
92
+
93
+ Args:
94
+ words: iterable of tokens. Only those that are indexed in the
95
+ dictionary are kept.
96
+ embedding_file: path to text file of embeddings, space separated.
97
+ """
98
+ words = {w for w in words if w in self.word_dict}
99
+ logger.info('Loading pre-trained embeddings for %d words from %s' %
100
+ (len(words), embedding_file))
101
+ embedding = self.network.embedding.weight.data
102
+
103
+ # When normalized, some words are duplicated. (Average the embeddings).
104
+ vec_counts = {}
105
+ with open(embedding_file) as f:
106
+ # Skip first line if of form count/dim.
107
+ line = f.readline().rstrip().split(' ')
108
+ if len(line) != 2:
109
+ f.seek(0)
110
+ for line in f:
111
+ parsed = line.rstrip().split(' ')
112
+ assert(len(parsed) == embedding.size(1) + 1)
113
+ w = self.word_dict.normalize(parsed[0])
114
+ if w in words:
115
+ vec = torch.Tensor([float(i) for i in parsed[1:]])
116
+ if w not in vec_counts:
117
+ vec_counts[w] = 1
118
+ embedding[self.word_dict[w]].copy_(vec)
119
+ else:
120
+ logging.warning(
121
+ 'WARN: Duplicate embedding found for %s' % w
122
+ )
123
+ vec_counts[w] = vec_counts[w] + 1
124
+ embedding[self.word_dict[w]].add_(vec)
125
+
126
+ for w, c in vec_counts.items():
127
+ embedding[self.word_dict[w]].div_(c)
128
+
129
+ logger.info('Loaded %d embeddings (%.2f%%)' %
130
+ (len(vec_counts), 100 * len(vec_counts) / len(words)))
131
+
132
+ def tune_embeddings(self, words):
133
+ """Unfix the embeddings of a list of words. This is only relevant if
134
+ only some of the embeddings are being tuned (tune_partial = N).
135
+
136
+ Shuffles the N specified words to the front of the dictionary, and saves
137
+ the original vectors of the other N + 1:vocab words in a fixed buffer.
138
+
139
+ Args:
140
+ words: iterable of tokens contained in dictionary.
141
+ """
142
+ words = {w for w in words if w in self.word_dict}
143
+
144
+ if len(words) == 0:
145
+ logger.warning('Tried to tune embeddings, but no words given!')
146
+ return
147
+
148
+ if len(words) == len(self.word_dict):
149
+ logger.warning('Tuning ALL embeddings in dictionary')
150
+ return
151
+
152
+ # Shuffle words and vectors
153
+ embedding = self.network.embedding.weight.data
154
+ for idx, swap_word in enumerate(words, self.word_dict.START):
155
+ # Get current word + embedding for this index
156
+ curr_word = self.word_dict[idx]
157
+ curr_emb = embedding[idx].clone()
158
+ old_idx = self.word_dict[swap_word]
159
+
160
+ # Swap embeddings + dictionary indices
161
+ embedding[idx].copy_(embedding[old_idx])
162
+ embedding[old_idx].copy_(curr_emb)
163
+ self.word_dict[swap_word] = idx
164
+ self.word_dict[idx] = swap_word
165
+ self.word_dict[curr_word] = old_idx
166
+ self.word_dict[old_idx] = curr_word
167
+
168
+ # Save the original, fixed embeddings
169
+ self.network.register_buffer(
170
+ 'fixed_embedding', embedding[idx + 1:].clone()
171
+ )
172
+
173
+ def init_optimizer(self, state_dict=None):
174
+ """Initialize an optimizer for the free parameters of the network.
175
+
176
+ Args:
177
+ state_dict: network parameters
178
+ """
179
+ if self.args.fix_embeddings:
180
+ for p in self.network.embedding.parameters():
181
+ p.requires_grad = False
182
+ parameters = [p for p in self.network.parameters() if p.requires_grad]
183
+ if self.args.optimizer == 'sgd':
184
+ self.optimizer = optim.SGD(parameters, self.args.learning_rate,
185
+ momentum=self.args.momentum,
186
+ weight_decay=self.args.weight_decay)
187
+ elif self.args.optimizer == 'adamax':
188
+ self.optimizer = optim.Adamax(parameters,
189
+ weight_decay=self.args.weight_decay)
190
+ else:
191
+ raise RuntimeError('Unsupported optimizer: %s' %
192
+ self.args.optimizer)
193
+
194
+ # --------------------------------------------------------------------------
195
+ # Learning
196
+ # --------------------------------------------------------------------------
197
+
198
+ def update(self, ex):
199
+ """Forward a batch of examples; step the optimizer to update weights."""
200
+ if not self.optimizer:
201
+ raise RuntimeError('No optimizer set.')
202
+
203
+ # Train mode
204
+ self.network.train()
205
+
206
+ # Transfer to GPU
207
+ if self.use_cuda:
208
+ inputs = [e if e is None else e.cuda(non_blocking=True)
209
+ for e in ex[:5]]
210
+ target_s = ex[5].cuda(non_blocking=True)
211
+ target_e = ex[6].cuda(non_blocking=True)
212
+ else:
213
+ inputs = [e if e is None else e for e in ex[:5]]
214
+ target_s = ex[5]
215
+ target_e = ex[6]
216
+
217
+ # Run forward
218
+ score_s, score_e = self.network(*inputs)
219
+
220
+ # Compute loss and accuracies
221
+ loss = F.nll_loss(score_s, target_s) + F.nll_loss(score_e, target_e)
222
+
223
+ # Clear gradients and run backward
224
+ self.optimizer.zero_grad()
225
+ loss.backward()
226
+
227
+ # Clip gradients
228
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(),
229
+ self.args.grad_clipping)
230
+
231
+ # Update parameters
232
+ self.optimizer.step()
233
+ self.updates += 1
234
+
235
+ # Reset any partially fixed parameters (e.g. rare words)
236
+ self.reset_parameters()
237
+
238
+ return loss.item(), ex[0].size(0)
239
+
240
+ def reset_parameters(self):
241
+ """Reset any partially fixed parameters to original states."""
242
+
243
+ # Reset fixed embeddings to original value
244
+ if self.args.tune_partial > 0:
245
+ if self.parallel:
246
+ embedding = self.network.module.embedding.weight.data
247
+ fixed_embedding = self.network.module.fixed_embedding
248
+ else:
249
+ embedding = self.network.embedding.weight.data
250
+ fixed_embedding = self.network.fixed_embedding
251
+
252
+ # Embeddings to fix are the last indices
253
+ offset = embedding.size(0) - fixed_embedding.size(0)
254
+ if offset >= 0:
255
+ embedding[offset:] = fixed_embedding
256
+
257
+ # --------------------------------------------------------------------------
258
+ # Prediction
259
+ # --------------------------------------------------------------------------
260
+
261
+ def predict(self, ex, candidates=None, top_n=1, async_pool=None):
262
+ """Forward a batch of examples only to get predictions.
263
+
264
+ Args:
265
+ ex: the batch
266
+ candidates: batch * variable length list of string answer options.
267
+ The model will only consider exact spans contained in this list.
268
+ top_n: Number of predictions to return per batch element.
269
+ async_pool: If provided, non-gpu post-processing will be offloaded
270
+ to this CPU process pool.
271
+ Output:
272
+ pred_s: batch * top_n predicted start indices
273
+ pred_e: batch * top_n predicted end indices
274
+ pred_score: batch * top_n prediction scores
275
+
276
+ If async_pool is given, these will be AsyncResult handles.
277
+ """
278
+ # Eval mode
279
+ self.network.eval()
280
+
281
+ # Transfer to GPU
282
+ if self.use_cuda:
283
+ inputs = [e if e is None else e.cuda(non_blocking=True)
284
+ for e in ex[:5]]
285
+ else:
286
+ inputs = [e for e in ex[:5]]
287
+
288
+ # Run forward
289
+ with torch.no_grad():
290
+ score_s, score_e = self.network(*inputs)
291
+
292
+ # Decode predictions
293
+ score_s = score_s.data.cpu()
294
+ score_e = score_e.data.cpu()
295
+ if candidates:
296
+ args = (score_s, score_e, candidates, top_n, self.args.max_len)
297
+ if async_pool:
298
+ return async_pool.apply_async(self.decode_candidates, args)
299
+ else:
300
+ return self.decode_candidates(*args)
301
+ else:
302
+ args = (score_s, score_e, top_n, self.args.max_len)
303
+ if async_pool:
304
+ return async_pool.apply_async(self.decode, args)
305
+ else:
306
+ return self.decode(*args)
307
+
308
+ @staticmethod
309
+ def decode(score_s, score_e, top_n=1, max_len=None):
310
+ """Take argmax of constrained score_s * score_e.
311
+
312
+ Args:
313
+ score_s: independent start predictions
314
+ score_e: independent end predictions
315
+ top_n: number of top scored pairs to take
316
+ max_len: max span length to consider
317
+ """
318
+ pred_s = []
319
+ pred_e = []
320
+ pred_score = []
321
+ max_len = max_len or score_s.size(1)
322
+ for i in range(score_s.size(0)):
323
+ # Outer product of scores to get full p_s * p_e matrix
324
+ scores = torch.ger(score_s[i], score_e[i])
325
+
326
+ # Zero out negative length and over-length span scores
327
+ scores.triu_().tril_(max_len - 1)
328
+
329
+ # Take argmax or top n
330
+ scores = scores.numpy()
331
+ scores_flat = scores.flatten()
332
+ if top_n == 1:
333
+ idx_sort = [np.argmax(scores_flat)]
334
+ elif len(scores_flat) < top_n:
335
+ idx_sort = np.argsort(-scores_flat)
336
+ else:
337
+ idx = np.argpartition(-scores_flat, top_n)[0:top_n]
338
+ idx_sort = idx[np.argsort(-scores_flat[idx])]
339
+ s_idx, e_idx = np.unravel_index(idx_sort, scores.shape)
340
+ pred_s.append(s_idx)
341
+ pred_e.append(e_idx)
342
+ pred_score.append(scores_flat[idx_sort])
343
+ return pred_s, pred_e, pred_score
344
+
345
+ @staticmethod
346
+ def decode_candidates(score_s, score_e, candidates, top_n=1, max_len=None):
347
+ """Take argmax of constrained score_s * score_e. Except only consider
348
+ spans that are in the candidates list.
349
+ """
350
+ pred_s = []
351
+ pred_e = []
352
+ pred_score = []
353
+ for i in range(score_s.size(0)):
354
+ # Extract original tokens stored with candidates
355
+ tokens = candidates[i]['input']
356
+ cands = candidates[i]['cands']
357
+
358
+ if not cands:
359
+ # try getting from globals? (multiprocessing in pipeline mode)
360
+ from ..pipeline.drqa import PROCESS_CANDS
361
+ cands = PROCESS_CANDS
362
+ if not cands:
363
+ raise RuntimeError('No candidates given.')
364
+
365
+ # Score all valid candidates found in text.
366
+ # Brute force get all ngrams and compare against the candidate list.
367
+ max_len = max_len or len(tokens)
368
+ scores, s_idx, e_idx = [], [], []
369
+ for s, e in tokens.ngrams(n=max_len, as_strings=False):
370
+ span = tokens.slice(s, e).untokenize()
371
+ if span in cands or span.lower() in cands:
372
+ # Match! Record its score.
373
+ scores.append(score_s[i][s] * score_e[i][e - 1])
374
+ s_idx.append(s)
375
+ e_idx.append(e - 1)
376
+
377
+ if len(scores) == 0:
378
+ # No candidates present
379
+ pred_s.append([])
380
+ pred_e.append([])
381
+ pred_score.append([])
382
+ else:
383
+ # Rank found candidates
384
+ scores = np.array(scores)
385
+ s_idx = np.array(s_idx)
386
+ e_idx = np.array(e_idx)
387
+
388
+ idx_sort = np.argsort(-scores)[0:top_n]
389
+ pred_s.append(s_idx[idx_sort])
390
+ pred_e.append(e_idx[idx_sort])
391
+ pred_score.append(scores[idx_sort])
392
+ return pred_s, pred_e, pred_score
393
+
394
+ # --------------------------------------------------------------------------
395
+ # Saving and loading
396
+ # --------------------------------------------------------------------------
397
+
398
+ def save(self, filename):
399
+ if self.parallel:
400
+ network = self.network.module
401
+ else:
402
+ network = self.network
403
+ state_dict = copy.copy(network.state_dict())
404
+ if 'fixed_embedding' in state_dict:
405
+ state_dict.pop('fixed_embedding')
406
+ params = {
407
+ 'state_dict': state_dict,
408
+ 'word_dict': self.word_dict,
409
+ 'feature_dict': self.feature_dict,
410
+ 'args': self.args,
411
+ }
412
+ try:
413
+ torch.save(params, filename)
414
+ except BaseException:
415
+ logger.warning('WARN: Saving failed... continuing anyway.')
416
+
417
+ def checkpoint(self, filename, epoch):
418
+ if self.parallel:
419
+ network = self.network.module
420
+ else:
421
+ network = self.network
422
+ params = {
423
+ 'state_dict': network.state_dict(),
424
+ 'word_dict': self.word_dict,
425
+ 'feature_dict': self.feature_dict,
426
+ 'args': self.args,
427
+ 'epoch': epoch,
428
+ 'optimizer': self.optimizer.state_dict(),
429
+ }
430
+ try:
431
+ torch.save(params, filename)
432
+ except BaseException:
433
+ logger.warning('WARN: Saving failed... continuing anyway.')
434
+
435
+ @staticmethod
436
+ def load(filename, new_args=None, normalize=True):
437
+ logger.info('Loading model %s' % filename)
438
+ saved_params = torch.load(
439
+ filename, map_location=lambda storage, loc: storage
440
+ )
441
+ word_dict = saved_params['word_dict']
442
+ feature_dict = saved_params['feature_dict']
443
+ state_dict = saved_params['state_dict']
444
+ args = saved_params['args']
445
+ if new_args:
446
+ args = override_model_args(args, new_args)
447
+ return DocReader(args, word_dict, feature_dict, state_dict, normalize)
448
+
449
+ @staticmethod
450
+ def load_checkpoint(filename, normalize=True):
451
+ logger.info('Loading model %s' % filename)
452
+ saved_params = torch.load(
453
+ filename, map_location=lambda storage, loc: storage
454
+ )
455
+ word_dict = saved_params['word_dict']
456
+ feature_dict = saved_params['feature_dict']
457
+ state_dict = saved_params['state_dict']
458
+ epoch = saved_params['epoch']
459
+ optimizer = saved_params['optimizer']
460
+ args = saved_params['args']
461
+ model = DocReader(args, word_dict, feature_dict, state_dict, normalize)
462
+ model.init_optimizer(optimizer)
463
+ return model, epoch
464
+
465
+ # --------------------------------------------------------------------------
466
+ # Runtime
467
+ # --------------------------------------------------------------------------
468
+
469
+ def cuda(self):
470
+ self.use_cuda = True
471
+ self.network = self.network.cuda()
472
+
473
+ def cpu(self):
474
+ self.use_cuda = False
475
+ self.network = self.network.cpu()
476
+
477
+ def parallelize(self):
478
+ """Use data parallel to copy the model across several gpus.
479
+ This will take all gpus visible with CUDA_VISIBLE_DEVICES.
480
+ """
481
+ self.parallel = True
482
+ self.network = torch.nn.DataParallel(self.network)
drqa/reader/predictor.py ADDED
@@ -0,0 +1,145 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2017-present, Facebook, Inc.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ """DrQA Document Reader predictor"""
8
+
9
+ import logging
10
+
11
+ from multiprocessing import Pool as ProcessPool
12
+ from multiprocessing.util import Finalize
13
+
14
+ from .vector import vectorize, batchify
15
+ from .model import DocReader
16
+ from . import DEFAULTS, utils
17
+ from .. import tokenizers
18
+
19
+logger = logging.getLogger(__name__)
+
+
+# ------------------------------------------------------------------------------
+# Tokenize + annotate
+# ------------------------------------------------------------------------------
+
+PROCESS_TOK = None
+
+
+def init(tokenizer_class, annotators):
+    global PROCESS_TOK
+    PROCESS_TOK = tokenizer_class(annotators=annotators)
+    Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)
+
+
+def tokenize(text):
+    global PROCESS_TOK
+    return PROCESS_TOK.tokenize(text)
+
+
+# ------------------------------------------------------------------------------
+# Predictor class.
+# ------------------------------------------------------------------------------
+
+
+class Predictor(object):
+    """Load a pretrained DocReader model and predict inputs on the fly."""
+
+    def __init__(self, model=None, tokenizer=None, normalize=True,
+                 embedding_file=None, num_workers=None):
+        """
+        Args:
+            model: path to saved model file.
+            tokenizer: option string to select tokenizer class.
+            normalize: squash output score to 0-1 probabilities with a softmax.
+            embedding_file: if provided, will expand dictionary to use all
+                available pretrained vectors in this file.
+            num_workers: number of CPU processes to use to preprocess batches.
+        """
+        logger.info('Initializing model...')
+        self.model = DocReader.load(model or DEFAULTS['model'],
+                                    normalize=normalize)
+
+        if embedding_file:
+            logger.info('Expanding dictionary...')
+            words = utils.index_embedding_words(embedding_file)
+            added = self.model.expand_dictionary(words)
+            self.model.load_embeddings(added, embedding_file)
+
+        logger.info('Initializing tokenizer...')
+        annotators = tokenizers.get_annotators_for_model(self.model)
+        if not tokenizer:
+            tokenizer_class = DEFAULTS['tokenizer']
+        else:
+            tokenizer_class = tokenizers.get_class(tokenizer)
+
+        if num_workers is None or num_workers > 0:
+            self.workers = ProcessPool(
+                num_workers,
+                initializer=init,
+                initargs=(tokenizer_class, annotators),
+            )
+        else:
+            self.workers = None
+            self.tokenizer = tokenizer_class(annotators=annotators)
+
+    def predict(self, document, question, candidates=None, top_n=1):
+        """Predict a single document - question pair."""
+        results = self.predict_batch([(document, question, candidates,)], top_n)
+        return results[0]
+
+    def predict_batch(self, batch, top_n=1):
+        """Predict a batch of document - question pairs."""
+        documents, questions, candidates = [], [], []
+        for b in batch:
+            documents.append(b[0])
+            questions.append(b[1])
+            candidates.append(b[2] if len(b) == 3 else None)
+        candidates = candidates if any(candidates) else None
+
+        # Tokenize the inputs, perhaps multi-processed.
+        if self.workers:
+            q_tokens = self.workers.map_async(tokenize, questions)
+            d_tokens = self.workers.map_async(tokenize, documents)
+            q_tokens = list(q_tokens.get())
+            d_tokens = list(d_tokens.get())
+        else:
+            q_tokens = list(map(self.tokenizer.tokenize, questions))
+            d_tokens = list(map(self.tokenizer.tokenize, documents))
+
+        examples = []
+        for i in range(len(questions)):
+            examples.append({
+                'id': i,
+                'question': q_tokens[i].words(),
+                'qlemma': q_tokens[i].lemmas(),
+                'document': d_tokens[i].words(),
+                'lemma': d_tokens[i].lemmas(),
+                'pos': d_tokens[i].pos(),
+                'ner': d_tokens[i].entities(),
+            })
+
+        # Stick document tokens in candidates for decoding
+        if candidates:
+            candidates = [{'input': d_tokens[i], 'cands': candidates[i]}
+                          for i in range(len(candidates))]
+
+        # Build the batch and run it through the model
+        batch_exs = batchify([vectorize(e, self.model) for e in examples])
+        s, e, score = self.model.predict(batch_exs, candidates, top_n)
+
+        # Retrieve the predicted spans
+        results = []
+        for i in range(len(s)):
+            predictions = []
+            for j in range(len(s[i])):
+                span = d_tokens[i].slice(s[i][j], e[i][j] + 1).untokenize()
+                predictions.append((span, score[i][j].item()))
+            results.append(predictions)
+        return results
+
+    def cuda(self):
+        self.model.cuda()
+
+    def cpu(self):
+        self.model.cpu()
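
A minimal usage sketch of the Predictor above (not part of the commit): it assumes `Predictor` is re-exported from `drqa.reader`, as in upstream DrQA, and that a trained DocReader checkpoint exists at the hypothetical path `data/reader/single.mdl`.

    from drqa.reader import Predictor

    # num_workers=0 keeps tokenization in-process instead of using a worker pool.
    predictor = Predictor(model='data/reader/single.mdl', num_workers=0)
    document = 'The AVeriTeC shared task evaluates automated fact-checking systems.'
    question = 'What does the AVeriTeC shared task evaluate?'
    for span, score in predictor.predict(document, question, top_n=3):
        print(span, score)
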
drqa/reader/rnn_reader.py ADDED
@@ -0,0 +1,135 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Implementation of the RNN based DrQA reader."""
+
+import torch
+import torch.nn as nn
+from . import layers
+
+
+# ------------------------------------------------------------------------------
+# Network
+# ------------------------------------------------------------------------------
+
+
+class RnnDocReader(nn.Module):
+    RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}
+
+    def __init__(self, args, normalize=True):
+        super(RnnDocReader, self).__init__()
+        # Store config
+        self.args = args
+
+        # Word embeddings (+1 for padding)
+        self.embedding = nn.Embedding(args.vocab_size,
+                                      args.embedding_dim,
+                                      padding_idx=0)
+
+        # Projection for attention weighted question
+        if args.use_qemb:
+            self.qemb_match = layers.SeqAttnMatch(args.embedding_dim)
+
+        # Input size to RNN: word emb + question emb + manual features
+        doc_input_size = args.embedding_dim + args.num_features
+        if args.use_qemb:
+            doc_input_size += args.embedding_dim
+
+        # RNN document encoder
+        self.doc_rnn = layers.StackedBRNN(
+            input_size=doc_input_size,
+            hidden_size=args.hidden_size,
+            num_layers=args.doc_layers,
+            dropout_rate=args.dropout_rnn,
+            dropout_output=args.dropout_rnn_output,
+            concat_layers=args.concat_rnn_layers,
+            rnn_type=self.RNN_TYPES[args.rnn_type],
+            padding=args.rnn_padding,
+        )
+
+        # RNN question encoder
+        self.question_rnn = layers.StackedBRNN(
+            input_size=args.embedding_dim,
+            hidden_size=args.hidden_size,
+            num_layers=args.question_layers,
+            dropout_rate=args.dropout_rnn,
+            dropout_output=args.dropout_rnn_output,
+            concat_layers=args.concat_rnn_layers,
+            rnn_type=self.RNN_TYPES[args.rnn_type],
+            padding=args.rnn_padding,
+        )
+
+        # Output sizes of rnn encoders
+        doc_hidden_size = 2 * args.hidden_size
+        question_hidden_size = 2 * args.hidden_size
+        if args.concat_rnn_layers:
+            doc_hidden_size *= args.doc_layers
+            question_hidden_size *= args.question_layers
+
+        # Question merging
+        if args.question_merge not in ['avg', 'self_attn']:
+            raise NotImplementedError('question_merge = %s' % args.question_merge)
+        if args.question_merge == 'self_attn':
+            self.self_attn = layers.LinearSeqAttn(question_hidden_size)
+
+        # Bilinear attention for span start/end
+        self.start_attn = layers.BilinearSeqAttn(
+            doc_hidden_size,
+            question_hidden_size,
+            normalize=normalize,
+        )
+        self.end_attn = layers.BilinearSeqAttn(
+            doc_hidden_size,
+            question_hidden_size,
+            normalize=normalize,
+        )
+
+    def forward(self, x1, x1_f, x1_mask, x2, x2_mask):
+        """Inputs:
+        x1 = document word indices [batch * len_d]
+        x1_f = document word features indices [batch * len_d * nfeat]
+        x1_mask = document padding mask [batch * len_d]
+        x2 = question word indices [batch * len_q]
+        x2_mask = question padding mask [batch * len_q]
+        """
+        # Embed both document and question
+        x1_emb = self.embedding(x1)
+        x2_emb = self.embedding(x2)
+
+        # Dropout on embeddings
+        if self.args.dropout_emb > 0:
+            x1_emb = nn.functional.dropout(x1_emb, p=self.args.dropout_emb,
+                                           training=self.training)
+            x2_emb = nn.functional.dropout(x2_emb, p=self.args.dropout_emb,
+                                           training=self.training)
+
+        # Form document encoding inputs
+        drnn_input = [x1_emb]
+
+        # Add attention-weighted question representation
+        if self.args.use_qemb:
+            x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask)
+            drnn_input.append(x2_weighted_emb)
+
+        # Add manual features
+        if self.args.num_features > 0:
+            drnn_input.append(x1_f)
+
+        # Encode document with RNN
+        doc_hiddens = self.doc_rnn(torch.cat(drnn_input, 2), x1_mask)
+
+        # Encode question with RNN + merge hiddens
+        question_hiddens = self.question_rnn(x2_emb, x2_mask)
+        if self.args.question_merge == 'avg':
+            q_merge_weights = layers.uniform_weights(question_hiddens, x2_mask)
+        elif self.args.question_merge == 'self_attn':
+            q_merge_weights = self.self_attn(question_hiddens, x2_mask)
+        question_hidden = layers.weighted_avg(question_hiddens, q_merge_weights)
+
+        # Predict start and end positions
+        start_scores = self.start_attn(doc_hiddens, question_hidden, x1_mask)
+        end_scores = self.end_attn(doc_hiddens, question_hidden, x1_mask)
+        return start_scores, end_scores
drqa/reader/utils.py ADDED
@@ -0,0 +1,288 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""DrQA reader utilities."""
+
+import json
+import time
+import logging
+import string
+import regex as re
+
+from collections import Counter
+from .data import Dictionary
+
+logger = logging.getLogger(__name__)
+
+
+# ------------------------------------------------------------------------------
+# Data loading
+# ------------------------------------------------------------------------------
+
+
+def load_data(args, filename, skip_no_answer=False):
+    """Load examples from preprocessed file.
+    One example per line, JSON encoded.
+    """
+    # Load JSON lines
+    with open(filename) as f:
+        examples = [json.loads(line) for line in f]
+
+    # Make case insensitive?
+    if args.uncased_question or args.uncased_doc:
+        for ex in examples:
+            if args.uncased_question:
+                ex['question'] = [w.lower() for w in ex['question']]
+            if args.uncased_doc:
+                ex['document'] = [w.lower() for w in ex['document']]
+
+    # Skip unparsed (start/end) examples
+    if skip_no_answer:
+        examples = [ex for ex in examples if len(ex['answers']) > 0]
+
+    return examples
+
+
+def load_text(filename):
+    """Load the paragraphs only of a SQuAD dataset. Store as qid -> text."""
+    # Load JSON file
+    with open(filename) as f:
+        examples = json.load(f)['data']
+
+    texts = {}
+    for article in examples:
+        for paragraph in article['paragraphs']:
+            for qa in paragraph['qas']:
+                texts[qa['id']] = paragraph['context']
+    return texts
+
+
+def load_answers(filename):
+    """Load the answers only of a SQuAD dataset. Store as qid -> [answers]."""
+    # Load JSON file
+    with open(filename) as f:
+        examples = json.load(f)['data']
+
+    ans = {}
+    for article in examples:
+        for paragraph in article['paragraphs']:
+            for qa in paragraph['qas']:
+                ans[qa['id']] = list(map(lambda x: x['text'], qa['answers']))
+    return ans
+
+
+# ------------------------------------------------------------------------------
+# Dictionary building
+# ------------------------------------------------------------------------------
+
+
+def index_embedding_words(embedding_file):
+    """Put all the words in embedding_file into a set."""
+    words = set()
+    with open(embedding_file) as f:
+        for line in f:
+            w = Dictionary.normalize(line.rstrip().split(' ')[0])
+            words.add(w)
+    return words
+
+
+def load_words(args, examples):
+    """Iterate and index all the words in examples (documents + questions)."""
+    def _insert(iterable):
+        for w in iterable:
+            w = Dictionary.normalize(w)
+            if valid_words and w not in valid_words:
+                continue
+            words.add(w)
+
+    if args.restrict_vocab and args.embedding_file:
+        logger.info('Restricting to words in %s' % args.embedding_file)
+        valid_words = index_embedding_words(args.embedding_file)
+        logger.info('Num words in set = %d' % len(valid_words))
+    else:
+        valid_words = None
+
+    words = set()
+    for ex in examples:
+        _insert(ex['question'])
+        _insert(ex['document'])
+    return words
+
+
+def build_word_dict(args, examples):
+    """Return a dictionary from question and document words in
+    provided examples.
+    """
+    word_dict = Dictionary()
+    for w in load_words(args, examples):
+        word_dict.add(w)
+    return word_dict
+
+
+def top_question_words(args, examples, word_dict):
+    """Count and return the most common question words in provided examples."""
+    word_count = Counter()
+    for ex in examples:
+        for w in ex['question']:
+            w = Dictionary.normalize(w)
+            if w in word_dict:
+                word_count.update([w])
+    return word_count.most_common(args.tune_partial)
+
+
+def build_feature_dict(args, examples):
+    """Index features (one hot) from fields in examples and options."""
+    def _insert(feature):
+        if feature not in feature_dict:
+            feature_dict[feature] = len(feature_dict)
+
+    feature_dict = {}
+
+    # Exact match features
+    if args.use_in_question:
+        _insert('in_question')
+        _insert('in_question_uncased')
+        if args.use_lemma:
+            _insert('in_question_lemma')
+
+    # Part of speech tag features
+    if args.use_pos:
+        for ex in examples:
+            for w in ex['pos']:
+                _insert('pos=%s' % w)
+
+    # Named entity tag features
+    if args.use_ner:
+        for ex in examples:
+            for w in ex['ner']:
+                _insert('ner=%s' % w)
+
+    # Term frequency feature
+    if args.use_tf:
+        _insert('tf')
+    return feature_dict
+
+
+# ------------------------------------------------------------------------------
+# Evaluation. Follows official evaluation script for v1.1 of the SQuAD dataset.
+# ------------------------------------------------------------------------------
+
+
+def normalize_answer(s):
+    """Lower text and remove punctuation, articles and extra whitespace."""
+    def remove_articles(text):
+        return re.sub(r'\b(a|an|the)\b', ' ', text)
+
+    def white_space_fix(text):
+        return ' '.join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return ''.join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def f1_score(prediction, ground_truth):
+    """Compute the harmonic mean (F1) of precision and recall for answer tokens."""
+    prediction_tokens = normalize_answer(prediction).split()
+    ground_truth_tokens = normalize_answer(ground_truth).split()
+    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+    num_same = sum(common.values())
+    if num_same == 0:
+        return 0
+    precision = 1.0 * num_same / len(prediction_tokens)
+    recall = 1.0 * num_same / len(ground_truth_tokens)
+    f1 = (2 * precision * recall) / (precision + recall)
+    return f1
+
+
+def exact_match_score(prediction, ground_truth):
+    """Check if the prediction is a (soft) exact match with the ground truth."""
+    return normalize_answer(prediction) == normalize_answer(ground_truth)
+
+
+def regex_match_score(prediction, pattern):
+    """Check if the prediction matches the given regular expression."""
+    try:
+        compiled = re.compile(
+            pattern,
+            flags=re.IGNORECASE + re.UNICODE + re.MULTILINE
+        )
+    except BaseException:
+        logger.warn('Regular expression failed to compile: %s' % pattern)
+        return False
+    return compiled.match(prediction) is not None
+
+
+def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
+    """Given a prediction and multiple valid answers, return the score of
+    the best prediction-answer_n pair given a metric function.
+    """
+    scores_for_ground_truths = []
+    for ground_truth in ground_truths:
+        score = metric_fn(prediction, ground_truth)
+        scores_for_ground_truths.append(score)
+    return max(scores_for_ground_truths)
+
+
+# ------------------------------------------------------------------------------
+# Utility classes
+# ------------------------------------------------------------------------------
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value."""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+class Timer(object):
+    """Computes elapsed time."""
+
+    def __init__(self):
+        self.running = True
+        self.total = 0
+        self.start = time.time()
+
+    def reset(self):
+        self.running = True
+        self.total = 0
+        self.start = time.time()
+        return self
+
+    def resume(self):
+        if not self.running:
+            self.running = True
+            self.start = time.time()
+        return self
+
+    def stop(self):
+        if self.running:
+            self.running = False
+            self.total += time.time() - self.start
+        return self
+
+    def time(self):
+        if self.running:
+            return self.total + time.time() - self.start
+        return self.total
drqa/reader/vector.py ADDED
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Functions for putting examples into torch format."""
+
+from collections import Counter
+import torch
+
+
+def vectorize(ex, model, single_answer=False):
+    """Torchify a single example."""
+    args = model.args
+    word_dict = model.word_dict
+    feature_dict = model.feature_dict
+
+    # Index words
+    document = torch.LongTensor([word_dict[w] for w in ex['document']])
+    question = torch.LongTensor([word_dict[w] for w in ex['question']])
+
+    # Create extra features vector
+    if len(feature_dict) > 0:
+        features = torch.zeros(len(ex['document']), len(feature_dict))
+    else:
+        features = None
+
+    # f_{exact_match}
+    if args.use_in_question:
+        q_words_cased = {w for w in ex['question']}
+        q_words_uncased = {w.lower() for w in ex['question']}
+        q_lemma = {w for w in ex['qlemma']} if args.use_lemma else None
+        for i in range(len(ex['document'])):
+            if ex['document'][i] in q_words_cased:
+                features[i][feature_dict['in_question']] = 1.0
+            if ex['document'][i].lower() in q_words_uncased:
+                features[i][feature_dict['in_question_uncased']] = 1.0
+            if q_lemma and ex['lemma'][i] in q_lemma:
+                features[i][feature_dict['in_question_lemma']] = 1.0
+
+    # f_{token} (POS)
+    if args.use_pos:
+        for i, w in enumerate(ex['pos']):
+            f = 'pos=%s' % w
+            if f in feature_dict:
+                features[i][feature_dict[f]] = 1.0
+
+    # f_{token} (NER)
+    if args.use_ner:
+        for i, w in enumerate(ex['ner']):
+            f = 'ner=%s' % w
+            if f in feature_dict:
+                features[i][feature_dict[f]] = 1.0
+
+    # f_{token} (TF)
+    if args.use_tf:
+        counter = Counter([w.lower() for w in ex['document']])
+        l = len(ex['document'])
+        for i, w in enumerate(ex['document']):
+            features[i][feature_dict['tf']] = counter[w.lower()] * 1.0 / l
+
+    # Maybe return without target
+    if 'answers' not in ex:
+        return document, features, question, ex['id']
+
+    # ...or with target(s) (might still be empty if answers is empty)
+    if single_answer:
+        assert(len(ex['answers']) > 0)
+        start = torch.LongTensor(1).fill_(ex['answers'][0][0])
+        end = torch.LongTensor(1).fill_(ex['answers'][0][1])
+    else:
+        start = [a[0] for a in ex['answers']]
+        end = [a[1] for a in ex['answers']]
+
+    return document, features, question, start, end, ex['id']
+
+
+def batchify(batch):
+    """Gather a batch of individual examples into one batch."""
+    NUM_INPUTS = 3
+    NUM_TARGETS = 2
+    NUM_EXTRA = 1
+
+    ids = [ex[-1] for ex in batch]
+    docs = [ex[0] for ex in batch]
+    features = [ex[1] for ex in batch]
+    questions = [ex[2] for ex in batch]
+
+    # Batch documents and features
+    max_length = max([d.size(0) for d in docs])
+    x1 = torch.LongTensor(len(docs), max_length).zero_()
+    x1_mask = torch.ByteTensor(len(docs), max_length).fill_(1)
+    if features[0] is None:
+        x1_f = None
+    else:
+        x1_f = torch.zeros(len(docs), max_length, features[0].size(1))
+    for i, d in enumerate(docs):
+        x1[i, :d.size(0)].copy_(d)
+        x1_mask[i, :d.size(0)].fill_(0)
+        if x1_f is not None:
+            x1_f[i, :d.size(0)].copy_(features[i])
+
+    # Batch questions
+    max_length = max([q.size(0) for q in questions])
+    x2 = torch.LongTensor(len(questions), max_length).zero_()
+    x2_mask = torch.ByteTensor(len(questions), max_length).fill_(1)
+    for i, q in enumerate(questions):
+        x2[i, :q.size(0)].copy_(q)
+        x2_mask[i, :q.size(0)].fill_(0)
+
+    # Maybe return without targets
+    if len(batch[0]) == NUM_INPUTS + NUM_EXTRA:
+        return x1, x1_f, x1_mask, x2, x2_mask, ids
+
+    elif len(batch[0]) == NUM_INPUTS + NUM_EXTRA + NUM_TARGETS:
+        # ...Otherwise add targets
+        if torch.is_tensor(batch[0][3]):
+            y_s = torch.cat([ex[3] for ex in batch])
+            y_e = torch.cat([ex[4] for ex in batch])
+        else:
+            y_s = [ex[3] for ex in batch]
+            y_e = [ex[4] for ex in batch]
+    else:
+        raise RuntimeError('Incorrect number of inputs per example.')
+
+    return x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids
drqa/retriever/__init__.py ADDED
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from .. import DATA_DIR
+
+DEFAULTS = {
+    'db_path': os.path.join(DATA_DIR, 'wikipedia/docs.db'),
+    'tfidf_path': os.path.join(
+        DATA_DIR,
+        'wikipedia/docs-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz'
+    ),
+    'elastic_url': 'localhost:9200'
+}
+
+
+def set_default(key, value):
+    global DEFAULTS
+    DEFAULTS[key] = value
+
+
+def get_class(name):
+    if name == 'tfidf':
+        return TfidfDocRanker
+    if name == 'sqlite':
+        return DocDB
+    if name == 'elasticsearch':
+        return ElasticDocRanker
+    raise RuntimeError('Invalid retriever class: %s' % name)
+
+
+from .doc_db import DocDB
+from .tfidf_doc_ranker import TfidfDocRanker
+from .elastic_doc_ranker import ElasticDocRanker
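
A hedged example of how this factory is typically wired up (the .npz path below is a placeholder, not a file shipped in this commit); 'tfidf' resolves to the TfidfDocRanker added later in this diff:

    from drqa import retriever

    ranker = retriever.get_class('tfidf')(tfidf_path='data/docs-tfidf.npz')
    doc_ids, doc_scores = ranker.closest_docs('who organises the AVeriTeC shared task', k=5)
    for doc_id, doc_score in zip(doc_ids, doc_scores):
        print(doc_id, doc_score)
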
drqa/retriever/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (975 Bytes).
drqa/retriever/__pycache__/doc_db.cpython-38.pyc ADDED
Binary file (2.67 kB).
drqa/retriever/__pycache__/elastic_doc_ranker.cpython-38.pyc ADDED
Binary file (4.64 kB).
drqa/retriever/__pycache__/tfidf_doc_ranker.cpython-38.pyc ADDED
Binary file (4.26 kB).
drqa/retriever/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.23 kB).
drqa/retriever/doc_db.py ADDED
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Documents, in a sqlite database."""
+
+import sqlite3
+from . import utils
+from . import DEFAULTS
+
+
+class DocDB(object):
+    """Sqlite backed document storage.
+
+    Implements get_doc_text(doc_id).
+    """
+
+    def __init__(self, db_path=None):
+        self.path = db_path or DEFAULTS['db_path']
+        self.connection = sqlite3.connect(self.path, check_same_thread=False)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
+
+    def path(self):
+        """Return the path to the file that backs this database."""
+        return self.path
+
+    def close(self):
+        """Close the connection to the database."""
+        self.connection.close()
+
+    def get_doc_ids(self):
+        """Fetch all ids of docs stored in the db."""
+        cursor = self.connection.cursor()
+        cursor.execute("SELECT id FROM documents")
+        results = [r[0] for r in cursor.fetchall()]
+        cursor.close()
+        return results
+
+    def get_doc_text(self, doc_id):
+        """Fetch the raw text of the doc for 'doc_id'."""
+        cursor = self.connection.cursor()
+        cursor.execute(
+            "SELECT text FROM documents WHERE id = ?",
+            (utils.normalize(doc_id), )
+            # (doc_id, )
+        )
+        result = cursor.fetchone()
+        cursor.close()
+        return result if result is None else result[0]
+
+    def get_doc_title(self, doc_id):
+        """Fetch the title of the doc for 'doc_id'."""
+        cursor = self.connection.cursor()
+        cursor.execute(
+            "SELECT title FROM documents WHERE id = ?",
+            (utils.normalize(doc_id),)
+            # (doc_id, )
+        )
+        result = cursor.fetchone()
+        cursor.close()
+        return result if result is None else result[0]
+
+    def get_doc_intro(self, doc_id):
+        """Fetch the introduction of the doc for 'doc_id'."""
+        cursor = self.connection.cursor()
+        cursor.execute(
+            "SELECT intro FROM documents WHERE id = ?",  # intro: the introduction of Wikipedia page
+            (utils.normalize(doc_id),)
+            # (doc_id, )
+        )
+        result = cursor.fetchone()
+        cursor.close()
+        return result if result is None else result[0]
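
A short usage sketch for DocDB (the sqlite path is a placeholder; any database with a `documents` table exposing `id`, `text`, `title` and `intro` columns would work):

    from drqa.retriever.doc_db import DocDB

    with DocDB(db_path='data/wikipedia/docs.db') as db:
        doc_ids = db.get_doc_ids()
        print(len(doc_ids), 'documents')
        print(db.get_doc_title(doc_ids[0]))
        print(db.get_doc_text(doc_ids[0])[:200])
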
drqa/retriever/elastic_doc_ranker.py ADDED
@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Rank documents with an ElasticSearch index"""
+
+import logging
+import scipy.sparse as sp
+
+from multiprocessing.pool import ThreadPool
+from functools import partial
+from elasticsearch import Elasticsearch
+
+from . import utils
+from . import DEFAULTS
+from .. import tokenizers
+
+logger = logging.getLogger(__name__)
+
+
+class ElasticDocRanker(object):
+    """Connect to an ElasticSearch index.
+    Score pairs based on Elasticsearch.
+    """
+
+    def __init__(self, elastic_url=None, elastic_index=None, elastic_fields=None,
+                 elastic_field_doc_name=None, strict=True, elastic_field_content=None):
+        """
+        Args:
+            elastic_url: URL of the ElasticSearch server containing port
+            elastic_index: Index name of ElasticSearch
+            elastic_fields: Fields of the Elasticsearch index to search in
+            elastic_field_doc_name: Field containing the name of the document (index)
+            strict: fail on empty queries or continue (and return empty result)
+            elastic_field_content: Field containing the content of document in plain text
+        """
+        # Load from disk
+        elastic_url = elastic_url or DEFAULTS['elastic_url']
+        logger.info('Connecting to %s' % elastic_url)
+        self.es = Elasticsearch(hosts=elastic_url)
+        self.elastic_index = elastic_index
+        self.elastic_fields = elastic_fields
+        self.elastic_field_doc_name = elastic_field_doc_name
+        self.elastic_field_content = elastic_field_content
+        self.strict = strict
+
+    # Elastic Ranker
+
+    def get_doc_index(self, doc_id):
+        """Convert doc_id --> doc_index"""
+        field_index = self.elastic_field_doc_name
+        if isinstance(field_index, list):
+            field_index = '.'.join(field_index)
+        result = self.es.search(index=self.elastic_index,
+                                body={'query': {'match': {field_index: doc_id}}})
+        return result['hits']['hits'][0]['_id']
+
+    def get_doc_id(self, doc_index):
+        """Convert doc_index --> doc_id"""
+        result = self.es.search(index=self.elastic_index,
+                                body={'query': {'match': {"_id": doc_index}}})
+        source = result['hits']['hits'][0]['_source']
+        return utils.get_field(source, self.elastic_field_doc_name)
+
+    def closest_docs(self, query, k=1):
+        """Closest docs by using ElasticSearch"""
+        results = self.es.search(index=self.elastic_index,
+                                 body={'size': k,
+                                       'query': {'multi_match': {
+                                           'query': query,
+                                           'type': 'most_fields',
+                                           'fields': self.elastic_fields}}})
+        hits = results['hits']['hits']
+        doc_ids = [utils.get_field(row['_source'], self.elastic_field_doc_name) for row in hits]
+        doc_scores = [row['_score'] for row in hits]
+        return doc_ids, doc_scores
+
+    def batch_closest_docs(self, queries, k=1, num_workers=None):
+        """Process a batch of closest_docs requests multithreaded.
+        Note: we can use plain threads here as scipy is outside of the GIL.
+        """
+        with ThreadPool(num_workers) as threads:
+            closest_docs = partial(self.closest_docs, k=k)
+            results = threads.map(closest_docs, queries)
+        return results
+
+    # Elastic DB
+
+    def __enter__(self):
+        return self
+
+    def close(self):
+        """Close the connection to the database."""
+        self.es = None
+
+    def get_doc_ids(self):
+        """Fetch all ids of docs stored in the db."""
+        results = self.es.search(index=self.elastic_index,
+                                 body={"query": {"match_all": {}}})
+        doc_ids = [utils.get_field(result['_source'], self.elastic_field_doc_name)
+                   for result in results['hits']['hits']]
+        return doc_ids
+
+    def get_doc_text(self, doc_id):
+        """Fetch the raw text of the doc for 'doc_id'."""
+        idx = self.get_doc_index(doc_id)
+        result = self.es.get(index=self.elastic_index, doc_type='_doc', id=idx)
+        return result if result is None else result['_source'][self.elastic_field_content]
drqa/retriever/tfidf_doc_ranker.py ADDED
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Rank documents with TF-IDF scores"""
+
+import logging
+import numpy as np
+import scipy.sparse as sp
+
+from multiprocessing.pool import ThreadPool
+from functools import partial
+
+from . import utils
+from . import DEFAULTS
+from .. import tokenizers
+
+logger = logging.getLogger(__name__)
+
+
+class TfidfDocRanker(object):
+    """Loads a pre-weighted inverted index of token/document terms.
+    Scores new queries by taking sparse dot products.
+    """
+
+    def __init__(self, tfidf_path=None, strict=True):
+        """
+        Args:
+            tfidf_path: path to saved model file
+            strict: fail on empty queries or continue (and return empty result)
+        """
+        # Load from disk
+        tfidf_path = tfidf_path or DEFAULTS['tfidf_path']
+        logger.info('Loading %s' % tfidf_path)
+        matrix, metadata = utils.load_sparse_csr(tfidf_path)
+        self.doc_mat = matrix
+        self.ngrams = metadata['ngram']
+        self.hash_size = metadata['hash_size']
+        self.tokenizer = tokenizers.get_class(metadata['tokenizer'])()
+        self.doc_freqs = metadata['doc_freqs'].squeeze()
+        self.doc_dict = metadata['doc_dict']
+        self.num_docs = len(self.doc_dict[0])
+        self.strict = strict
+
+    def get_doc_index(self, doc_id):
+        """Convert doc_id --> doc_index"""
+        return self.doc_dict[0][doc_id]
+
+    def get_doc_id(self, doc_index):
+        """Convert doc_index --> doc_id"""
+        return self.doc_dict[1][doc_index]
+
+    def closest_docs(self, query, k=1):
+        """Closest docs by dot product between query and documents
+        in tfidf weighted word vector space.
+        """
+        spvec = self.text2spvec(query)
+        res = spvec * self.doc_mat
+
+        if len(res.data) <= k:
+            o_sort = np.argsort(-res.data)
+        else:
+            o = np.argpartition(-res.data, k)[0:k]
+            o_sort = o[np.argsort(-res.data[o])]
+
+        doc_scores = res.data[o_sort]
+        doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]]
+        return doc_ids, doc_scores
+
+    def batch_closest_docs(self, queries, k=1, num_workers=None):
+        """Process a batch of closest_docs requests multithreaded.
+        Note: we can use plain threads here as scipy is outside of the GIL.
+        """
+        with ThreadPool(num_workers) as threads:
+            closest_docs = partial(self.closest_docs, k=k)
+            results = threads.map(closest_docs, queries)
+        return results
+
+    def parse(self, query):
+        """Parse the query into tokens (either ngrams or tokens)."""
+        tokens = self.tokenizer.tokenize(query)
+        return tokens.ngrams(n=self.ngrams, uncased=True,
+                             filter_fn=utils.filter_ngram)
+
+    def text2spvec(self, query):
+        """Create a sparse tfidf-weighted word vector from query.
+
+        tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5))
+        """
+        # Get hashed ngrams
+        words = self.parse(utils.normalize(query))
+        wids = [utils.hash(w, self.hash_size) for w in words]
+
+        if len(wids) == 0:
+            if self.strict:
+                raise RuntimeError('No valid word in: %s' % query)
+            else:
+                logger.warning('No valid word in: %s' % query)
+                return sp.csr_matrix((1, self.hash_size))
+
+        # Count TF
+        wids_unique, wids_counts = np.unique(wids, return_counts=True)
+        tfs = np.log1p(wids_counts)
+
+        # Count IDF
+        Ns = self.doc_freqs[wids_unique]
+        idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5))
+        idfs[idfs < 0] = 0
+
+        # TF-IDF
+        data = np.multiply(tfs, idfs)
+
+        # One row, sparse csr matrix
+        indptr = np.array([0, len(wids_unique)])
+        spvec = sp.csr_matrix(
+            (data, wids_unique, indptr), shape=(1, self.hash_size)
+        )
+
+        return spvec
drqa/retriever/utils.py ADDED
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Various retriever utilities."""
+
+import regex
+import unicodedata
+import numpy as np
+import scipy.sparse as sp
+from sklearn.utils import murmurhash3_32
+
+
+# ------------------------------------------------------------------------------
+# Sparse matrix saving/loading helpers.
+# ------------------------------------------------------------------------------
+
+
+def save_sparse_csr(filename, matrix, metadata=None):
+    data = {
+        'data': matrix.data,
+        'indices': matrix.indices,
+        'indptr': matrix.indptr,
+        'shape': matrix.shape,
+        'metadata': metadata,
+    }
+    np.savez(filename, **data)
+
+
+def load_sparse_csr(filename):
+    loader = np.load(filename, allow_pickle=True)
+    matrix = sp.csr_matrix((loader['data'], loader['indices'],
+                            loader['indptr']), shape=loader['shape'])
+    return matrix, loader['metadata'].item(0) if 'metadata' in loader else None
+
+
+# ------------------------------------------------------------------------------
+# Token hashing.
+# ------------------------------------------------------------------------------
+
+
+def hash(token, num_buckets):
+    """Unsigned 32 bit murmurhash for feature hashing."""
+    return murmurhash3_32(token, positive=True) % num_buckets
+
+
+# ------------------------------------------------------------------------------
+# Text cleaning.
+# ------------------------------------------------------------------------------
+
+
+STOPWORDS = {
+    'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your',
+    'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she',
+    'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their',
+    'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that',
+    'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
+    'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an',
+    'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of',
+    'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through',
+    'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down',
+    'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then',
+    'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any',
+    'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor',
+    'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can',
+    'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've',
+    'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven',
+    'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren',
+    'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``"
+}
+
+
+def normalize(text):
+    """Resolve different type of unicode encodings."""
+    return unicodedata.normalize('NFD', text)
+
+
+def filter_word(text):
+    """Take out english stopwords, punctuation, and compound endings."""
+    text = normalize(text)
+    if regex.match(r'^\p{P}+$', text):
+        return True
+    if text.lower() in STOPWORDS:
+        return True
+    return False
+
+
+def filter_ngram(gram, mode='any'):
+    """Decide whether to keep or discard an n-gram.
+
+    Args:
+        gram: list of tokens (length N)
+        mode: Option to throw out ngram if
+          'any': any single token passes filter_word
+          'all': all tokens pass filter_word
+          'ends': book-ended by filterable tokens
+    """
+    filtered = [filter_word(w) for w in gram]
+    if mode == 'any':
+        return any(filtered)
+    elif mode == 'all':
+        return all(filtered)
+    elif mode == 'ends':
+        return filtered[0] or filtered[-1]
+    else:
+        raise ValueError('Invalid mode: %s' % mode)
+
+
+def get_field(d, field_list):
+    """get the subfield associated to a list of elastic fields
+    E.g. ['file', 'filename'] to d['file']['filename']
+    """
+    if isinstance(field_list, str):
+        return d[field_list]
+    else:
+        idx = d.copy()
+        for field in field_list:
+            idx = idx[field]
+        return idx
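
A quick, self-contained illustration of the n-gram filtering and feature hashing above (assuming the package imports cleanly):

    from drqa.retriever.utils import filter_ngram, hash

    print(filter_ngram(['the', 'eiffel', 'tower'], mode='any'))   # True: contains a stopword
    print(filter_ngram(['eiffel', 'tower'], mode='ends'))         # False: no filterable token at either end
    print(hash('eiffel tower', 16777216))                         # stable bucket id in [0, 2**24)
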
drqa/tokenizers/__init__.py ADDED
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+DEFAULTS = {
+    'corenlp_classpath': os.getenv('CLASSPATH')
+}
+
+
+def set_default(key, value):
+    global DEFAULTS
+    DEFAULTS[key] = value
+
+
+from .corenlp_tokenizer import CoreNLPTokenizer
+from .regexp_tokenizer import RegexpTokenizer
+from .simple_tokenizer import SimpleTokenizer
+
+# Spacy is optional
+try:
+    from .spacy_tokenizer import SpacyTokenizer
+except ImportError:
+    pass
+
+
+def get_class(name):
+    if name == 'spacy':
+        return SpacyTokenizer
+    if name == 'corenlp':
+        return CoreNLPTokenizer
+    if name == 'regexp':
+        return RegexpTokenizer
+    if name == 'simple':
+        return SimpleTokenizer
+
+    raise RuntimeError('Invalid tokenizer: %s' % name)
+
+
+def get_annotators_for_args(args):
+    annotators = set()
+    if args.use_pos:
+        annotators.add('pos')
+    if args.use_lemma:
+        annotators.add('lemma')
+    if args.use_ner:
+        annotators.add('ner')
+    return annotators
+
+
+def get_annotators_for_model(model):
+    return get_annotators_for_args(model.args)
drqa/tokenizers/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.33 kB).
drqa/tokenizers/__pycache__/corenlp_tokenizer.cpython-38.pyc ADDED
Binary file (3.49 kB).
drqa/tokenizers/__pycache__/regexp_tokenizer.cpython-38.pyc ADDED
Binary file (3.31 kB).
drqa/tokenizers/__pycache__/simple_tokenizer.cpython-38.pyc ADDED
Binary file (1.77 kB).
drqa/tokenizers/__pycache__/spacy_tokenizer.cpython-38.pyc ADDED
Binary file (2.05 kB).
drqa/tokenizers/__pycache__/tokenizer.cpython-38.pyc ADDED
Binary file (5.83 kB).
drqa/tokenizers/corenlp_tokenizer.py ADDED
@@ -0,0 +1,122 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Simple wrapper around the Stanford CoreNLP pipeline.
+
+Serves commands to a java subprocess running the jar. Requires java 8.
+"""
+
+import copy
+import json
+import pexpect
+
+from .tokenizer import Tokens, Tokenizer
+from . import DEFAULTS
+
+
+class CoreNLPTokenizer(Tokenizer):
+
+    def __init__(self, **kwargs):
+        """
+        Args:
+            annotators: set that can include pos, lemma, and ner.
+            classpath: Path to the corenlp directory of jars
+            mem: Java heap memory
+        """
+        self.classpath = (kwargs.get('classpath') or
+                          DEFAULTS['corenlp_classpath'])
+        self.annotators = copy.deepcopy(kwargs.get('annotators', set()))
+        self.mem = kwargs.get('mem', '2g')
+        self._launch()
+
+    def _launch(self):
+        """Start the CoreNLP jar with pexpect."""
+        annotators = ['tokenize', 'ssplit']
+        if 'ner' in self.annotators:
+            annotators.extend(['pos', 'lemma', 'ner'])
+        elif 'lemma' in self.annotators:
+            annotators.extend(['pos', 'lemma'])
+        elif 'pos' in self.annotators:
+            annotators.extend(['pos'])
+        annotators = ','.join(annotators)
+        options = ','.join(['untokenizable=noneDelete',
+                            'invertible=true'])
+        cmd = ['java', '-mx' + self.mem, '-cp', '"%s"' % self.classpath,
+               'edu.stanford.nlp.pipeline.StanfordCoreNLP', '-annotators',
+               annotators, '-tokenize.options', options,
+               '-outputFormat', 'json', '-prettyPrint', 'false']
+
+        # We use pexpect to keep the subprocess alive and feed it commands.
+        # Because we don't want to get hit by the max terminal buffer size,
+        # we turn off canonical input processing to have unlimited bytes.
+        self.corenlp = pexpect.spawn('/bin/bash', maxread=100000, timeout=60)
+        self.corenlp.setecho(False)
+        self.corenlp.sendline('stty -icanon')
+        self.corenlp.sendline(' '.join(cmd))
+        self.corenlp.delaybeforesend = 0
+        self.corenlp.delayafterread = 0
+        self.corenlp.expect_exact('NLP>', searchwindowsize=100)
+
+    @staticmethod
+    def _convert(token):
+        if token == '-LRB-':
+            return '('
+        if token == '-RRB-':
+            return ')'
+        if token == '-LSB-':
+            return '['
+        if token == '-RSB-':
+            return ']'
+        if token == '-LCB-':
+            return '{'
+        if token == '-RCB-':
+            return '}'
+        return token
+
+    def tokenize(self, text):
+        # Since we're feeding text to the commandline, we're waiting on seeing
+        # the NLP> prompt. Hacky!
+        if 'NLP>' in text:
+            raise RuntimeError('Bad token (NLP>) in text!')
+
+        # Sending q will cause the process to quit -- manually override
+        if text.lower().strip() == 'q':
+            token = text.strip()
+            index = text.index(token)
+            data = [(token, text[index:], (index, index + 1), 'NN', 'q', 'O')]
+            return Tokens(data, self.annotators)
+
+        # Minor cleanup before tokenizing.
+        clean_text = text.replace('\n', ' ')
+
+        self.corenlp.sendline(clean_text.encode('utf-8'))
+        self.corenlp.expect_exact('NLP>', searchwindowsize=100)
+
+        # Skip to start of output (may have been stderr logging messages)
+        output = self.corenlp.before
+        start = output.find(b'{"sentences":')
+        output = json.loads(output[start:].decode('utf-8'))
+
+        data = []
+        tokens = [t for s in output['sentences'] for t in s['tokens']]
+        for i in range(len(tokens)):
+            # Get whitespace
+            start_ws = tokens[i]['characterOffsetBegin']
+            if i + 1 < len(tokens):
+                end_ws = tokens[i + 1]['characterOffsetBegin']
+            else:
+                end_ws = tokens[i]['characterOffsetEnd']
+
+            data.append((
+                self._convert(tokens[i]['word']),
+                text[start_ws: end_ws],
+                (tokens[i]['characterOffsetBegin'],
+                 tokens[i]['characterOffsetEnd']),
+                tokens[i].get('pos', None),
+                tokens[i].get('lemma', None),
+                tokens[i].get('ner', None)
+            ))
+        return Tokens(data, self.annotators)
drqa/tokenizers/regexp_tokenizer.py ADDED
@@ -0,0 +1,100 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Regex based tokenizer that emulates the Stanford/NLTK PTB tokenizers.
+
+However it is purely in Python, supports robust untokenization, unicode,
+and requires minimal dependencies.
+"""
+
+import regex
+import logging
+from .tokenizer import Tokens, Tokenizer
+
+logger = logging.getLogger(__name__)
+
+
+class RegexpTokenizer(Tokenizer):
+    DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*'
+    TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)'
+             r'\.(?=\p{Z})')
+    ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)'
+    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++'
+    HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM)
+    NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't"
+    CONTRACTION1 = r"can(?=not\b)"
+    CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b"
+    START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})'
+    START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})'
+    END_DQUOTE = r'(?<!\p{Z})(\'\'|["\u0094\u201D\u00BB])'
+    END_SQUOTE = r'(?<!\p{Z})[\'\u0092\u2019\u203A]'
+    DASH = r'--|[\u0096\u0097\u2013\u2014\u2015]'
+    ELLIPSES = r'\.\.\.|\u2026'
+    PUNCT = r'\p{P}'
+    NON_WS = r'[^\p{Z}\p{C}]'
+
+    def __init__(self, **kwargs):
+        """
+        Args:
+            annotators: None or empty set (only tokenizes).
+            substitutions: if true, normalizes some token types (e.g. quotes).
+        """
+        self._regexp = regex.compile(
+            '(?P<digit>%s)|(?P<title>%s)|(?P<abbr>%s)|(?P<neg>%s)|(?P<hyph>%s)|'
+            '(?P<contr1>%s)|(?P<alphanum>%s)|(?P<contr2>%s)|(?P<sdquote>%s)|'
+            '(?P<edquote>%s)|(?P<ssquote>%s)|(?P<esquote>%s)|(?P<dash>%s)|'
+            '(?<ellipses>%s)|(?P<punct>%s)|(?P<nonws>%s)' %
+            (self.DIGIT, self.TITLE, self.ABBRV, self.NEGATION, self.HYPHEN,
+             self.CONTRACTION1, self.ALPHA_NUM, self.CONTRACTION2,
+             self.START_DQUOTE, self.END_DQUOTE, self.START_SQUOTE,
+             self.END_SQUOTE, self.DASH, self.ELLIPSES, self.PUNCT,
+             self.NON_WS),
+            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
+        )
+        if len(kwargs.get('annotators', {})) > 0:
+            logger.warning('%s only tokenizes! Skipping annotators: %s' %
+                           (type(self).__name__, kwargs.get('annotators')))
+        self.annotators = set()
+        self.substitutions = kwargs.get('substitutions', True)
+
+    def tokenize(self, text):
+        data = []
+        matches = [m for m in self._regexp.finditer(text)]
+        for i in range(len(matches)):
+            # Get text
+            token = matches[i].group()
+
+            # Make normalizations for special token types
+            if self.substitutions:
+                groups = matches[i].groupdict()
+                if groups['sdquote']:
+                    token = "``"
+                elif groups['edquote']:
+                    token = "''"
+                elif groups['ssquote']:
+                    token = "`"
+                elif groups['esquote']:
+                    token = "'"
+                elif groups['dash']:
+                    token = '--'
+                elif groups['ellipses']:
+                    token = '...'
+
+            # Get whitespace
+            span = matches[i].span()
+            start_ws = span[0]
+            if i + 1 < len(matches):
+                end_ws = matches[i + 1].span()[0]
+            else:
+                end_ws = span[1]
+
+            # Format data
+            data.append((
+                token,
+                text[start_ws: end_ws],
+                span,
+            ))
+        return Tokens(data, self.annotators)
drqa/tokenizers/simple_tokenizer.py ADDED
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Basic tokenizer that splits text into alpha-numeric tokens and
+non-whitespace tokens.
+"""
+
+import regex
+import logging
+from .tokenizer import Tokens, Tokenizer
+
+logger = logging.getLogger(__name__)
+
+
+class SimpleTokenizer(Tokenizer):
+    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
+    NON_WS = r'[^\p{Z}\p{C}]'
+
+    def __init__(self, **kwargs):
+        """
+        Args:
+            annotators: None or empty set (only tokenizes).
+        """
+        self._regexp = regex.compile(
+            '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
+            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
+        )
+        if len(kwargs.get('annotators', {})) > 0:
+            logger.warning('%s only tokenizes! Skipping annotators: %s' %
+                           (type(self).__name__, kwargs.get('annotators')))
+        self.annotators = set()
+
+    def tokenize(self, text):
+        data = []
+        matches = [m for m in self._regexp.finditer(text)]
+        for i in range(len(matches)):
+            # Get text
+            token = matches[i].group()
+
+            # Get whitespace
+            span = matches[i].span()
+            start_ws = span[0]
+            if i + 1 < len(matches):
+                end_ws = matches[i + 1].span()[0]
+            else:
+                end_ws = span[1]
+
+            # Format data
+            data.append((
+                token,
+                text[start_ws: end_ws],
+                span,
+            ))
+        return Tokens(data, self.annotators)
drqa/tokenizers/spacy_tokenizer.py ADDED
@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Tokenizer that is backed by spaCy (spacy.io).
+
+Requires spaCy package and the spaCy english model.
+"""
+
+import spacy
+import copy
+from .tokenizer import Tokens, Tokenizer
+
+
+class SpacyTokenizer(Tokenizer):
+
+    def __init__(self, **kwargs):
+        """
+        Args:
+            annotators: set that can include pos, lemma, and ner.
+            model: spaCy model to use (either path, or keyword like 'en').
+        """
+        model = kwargs.get('model', 'en')
+        self.annotators = copy.deepcopy(kwargs.get('annotators', set()))
+        nlp_kwargs = {'parser': False}
+        if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
+            nlp_kwargs['tagger'] = False
+        if 'ner' not in self.annotators:
+            nlp_kwargs['entity'] = False
+        self.nlp = spacy.load(model, **nlp_kwargs)
+
+    def tokenize(self, text):
+        # We don't treat new lines as tokens.
+        clean_text = text.replace('\n', ' ')
+        tokens = self.nlp.tokenizer(clean_text)
+        if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
+            self.nlp.tagger(tokens)
+        if 'ner' in self.annotators:
+            self.nlp.entity(tokens)
+
+        data = []
+        for i in range(len(tokens)):
+            # Get whitespace
+            start_ws = tokens[i].idx
+            if i + 1 < len(tokens):
+                end_ws = tokens[i + 1].idx
+            else:
+                end_ws = tokens[i].idx + len(tokens[i].text)
+
+            data.append((
+                tokens[i].text,
+                text[start_ws: end_ws],
+                (tokens[i].idx, tokens[i].idx + len(tokens[i].text)),
+                tokens[i].tag_,
+                tokens[i].lemma_,
+                tokens[i].ent_type_,
+            ))
+
+        # Set special option for non-entity tag: '' vs 'O' in spaCy
+        return Tokens(data, self.annotators, opts={'non_ent': ''})
drqa/tokenizers/tokenizer.py ADDED
@@ -0,0 +1,139 @@
+#!/usr/bin/env python3
+# Copyright 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Base tokenizer/tokens classes and utilities."""
+
+import copy
+
+
+class Tokens(object):
+    """A class to represent a list of tokenized text."""
+    TEXT = 0
+    TEXT_WS = 1
+    SPAN = 2
+    POS = 3
+    LEMMA = 4
+    NER = 5
+
+    def __init__(self, data, annotators, opts=None):
+        self.data = data
+        self.annotators = annotators
+        self.opts = opts or {}
+
+    def __len__(self):
+        """The number of tokens."""
+        return len(self.data)
+
+    def slice(self, i=None, j=None):
+        """Return a view of the list of tokens from [i, j)."""
+        new_tokens = copy.copy(self)
+        new_tokens.data = self.data[i: j]
+        return new_tokens
+
+    def untokenize(self):
+        """Returns the original text (with whitespace reinserted)."""
+        return ''.join([t[self.TEXT_WS] for t in self.data]).strip()
+
+    def words(self, uncased=False):
+        """Returns a list of the text of each token
+
+        Args:
+            uncased: lower cases text
+        """
+        if uncased:
+            return [t[self.TEXT].lower() for t in self.data]
+        else:
+            return [t[self.TEXT] for t in self.data]
+
+    def offsets(self):
+        """Returns a list of [start, end) character offsets of each token."""
+        return [t[self.SPAN] for t in self.data]
+
+    def pos(self):
+        """Returns a list of part-of-speech tags of each token.
+        Returns None if this annotation was not included.
+        """
+        if 'pos' not in self.annotators:
+            return None
+        return [t[self.POS] for t in self.data]
+
+    def lemmas(self):
+        """Returns a list of the lemmatized text of each token.
+        Returns None if this annotation was not included.
+        """
+        if 'lemma' not in self.annotators:
+            return None
+        return [t[self.LEMMA] for t in self.data]
+
+    def entities(self):
+        """Returns a list of named-entity-recognition tags of each token.
+        Returns None if this annotation was not included.
+        """
+        if 'ner' not in self.annotators:
+            return None
+        return [t[self.NER] for t in self.data]
+
+    def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
+        """Returns a list of all ngrams from length 1 to n.
+
+        Args:
+            n: upper limit of ngram length
+            uncased: lower cases text
+            filter_fn: user function that takes in an ngram list and returns
+              True or False to keep or not keep the ngram
+            as_string: return the ngram as a string vs list
+        """
+        def _skip(gram):
+            if not filter_fn:
+                return False
+            return filter_fn(gram)
+
+        words = self.words(uncased)
+        ngrams = [(s, e + 1)
+                  for s in range(len(words))
+                  for e in range(s, min(s + n, len(words)))
+                  if not _skip(words[s:e + 1])]
+
+        # Concatenate into strings
+        if as_strings:
+            ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]
+
+        return ngrams
+
+    def entity_groups(self):
+        """Group consecutive entity tokens with the same NER tag."""
+        entities = self.entities()
+        if not entities:
+            return None
+        non_ent = self.opts.get('non_ent', 'O')
+        groups = []
+        idx = 0
+        while idx < len(entities):
+            ner_tag = entities[idx]
+            # Check for entity tag
+            if ner_tag != non_ent:
+                # Chomp the sequence
+                start = idx
+                while (idx < len(entities) and entities[idx] == ner_tag):
+                    idx += 1
+                groups.append((self.slice(start, idx).untokenize(), ner_tag))
+            else:
+                idx += 1
+        return groups
+
+
+class Tokenizer(object):
+    """Base tokenizer class.
+    Tokenizers implement tokenize, which should return a Tokens class.
+    """
+    def tokenize(self, text):
+        raise NotImplementedError
+
+    def shutdown(self):
+        pass
+
+    def __del__(self):
+        self.shutdown()
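
A small sketch of the Tokens/Tokenizer API above, using the SimpleTokenizer added earlier in this commit (no annotators are requested, so pos/lemmas/entities would return None):

    from drqa.tokenizers import SimpleTokenizer

    tok = SimpleTokenizer()
    tokens = tok.tokenize('The AVeriTeC claim was published on 25 August 2020.')
    print(tokens.words(uncased=True))
    print(tokens.ngrams(n=2, uncased=True))
    print(tokens.untokenize())
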
html2lines.py ADDED
@@ -0,0 +1,72 @@
+from distutils.command.config import config
+import requests
+from time import sleep
+import trafilatura
+from trafilatura.meta import reset_caches
+from trafilatura.settings import DEFAULT_CONFIG
+import spacy
+import os
+os.system("python -m spacy download en_core_web_sm")
+nlp = spacy.load('en_core_web_sm')
+import sys
+
+DEFAULT_CONFIG.MAX_FILE_SIZE = 50000
+
+
+def get_page(url):
+    page = None
+    for i in range(3):
+        try:
+            page = trafilatura.fetch_url(url, config=DEFAULT_CONFIG)
+            assert page is not None
+            print("Fetched " + url, file=sys.stderr)
+            break
+        except:
+            sleep(3)
+    return page
+
+
+def url2lines(url):
+    page = get_page(url)
+
+    if page is None:
+        return []
+
+    lines = html2lines(page)
+    return lines
+
+
+def line_correction(lines, max_size=100):
+    out_lines = []
+    for line in lines:
+        if len(line) < 4:
+            continue
+
+        if len(line) > max_size:
+            # We split lines into sentences, but for performance we take only
+            # the first 5k characters per line.
+            doc = nlp(line[:5000])
+            stack = ""
+            for sent in doc.sents:
+                if len(stack) > 0:
+                    stack += " "
+                stack += str(sent).strip()
+                if len(stack) > max_size:
+                    out_lines.append(stack)
+                    stack = ""
+
+            if len(stack) > 0:
+                out_lines.append(stack)
+        else:
+            out_lines.append(line)
+
+    return out_lines
+
+
+def html2lines(page):
+    out_lines = []
+
+    if page is None or len(page.strip()) == 0:
+        return out_lines
+
+    text = trafilatura.extract(page, config=DEFAULT_CONFIG)
+    reset_caches()
+
+    if text is None:
+        return out_lines
+
+    return text.split("\n")  # We just spit out the entire page, so need to reformat later.
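
A hedged example of the scraping helpers above (the URL is illustrative; network access and the en_core_web_sm download triggered at import time are required):

    from html2lines import url2lines, line_correction

    raw_lines = url2lines('https://en.wikipedia.org/wiki/Fact-checking')
    for chunk in line_correction(raw_lines, max_size=100):
        print(chunk)
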
requirements.txt ADDED
@@ -0,0 +1,22 @@
+gradio
+nltk
+rank_bm25
+accelerate
+trafilatura
+spacy
+pytorch_lightning
+transformers==4.29.2
+datasets
+leven
+scikit-learn
+pexpect
+elasticsearch
+torch
+huggingface_hub
+google-api-python-client
+wikipedia-api
+beautifulsoup4
+azure-storage-file-share
+azure-storage-blob
+bm25s
+PyStemmer