Spaces: Build error
Commit e62781a · zhenyundeng committed · 1 Parent(s): e5c50a7
add files

This view is limited to 50 files because it contains too many changes. See raw diff.
- README.md +6 -5
- app.py +1368 -0
- drqa/__init__.py +23 -0
- drqa/__pycache__/__init__.cpython-38.pyc +0 -0
- drqa/pipeline/__init__.py +27 -0
- drqa/pipeline/__pycache__/__init__.cpython-38.pyc +0 -0
- drqa/pipeline/__pycache__/drqa.cpython-38.pyc +0 -0
- drqa/pipeline/drqa.py +312 -0
- drqa/reader/__init__.py +28 -0
- drqa/reader/__pycache__/__init__.cpython-38.pyc +0 -0
- drqa/reader/__pycache__/config.cpython-38.pyc +0 -0
- drqa/reader/__pycache__/data.cpython-38.pyc +0 -0
- drqa/reader/__pycache__/layers.cpython-38.pyc +0 -0
- drqa/reader/__pycache__/model.cpython-38.pyc +0 -0
- drqa/reader/__pycache__/predictor.cpython-38.pyc +0 -0
- drqa/reader/__pycache__/rnn_reader.cpython-38.pyc +0 -0
- drqa/reader/__pycache__/utils.cpython-38.pyc +0 -0
- drqa/reader/__pycache__/vector.cpython-38.pyc +0 -0
- drqa/reader/config.py +128 -0
- drqa/reader/data.py +131 -0
- drqa/reader/layers.py +311 -0
- drqa/reader/model.py +482 -0
- drqa/reader/predictor.py +145 -0
- drqa/reader/rnn_reader.py +135 -0
- drqa/reader/utils.py +288 -0
- drqa/reader/vector.py +127 -0
- drqa/retriever/__init__.py +38 -0
- drqa/retriever/__pycache__/__init__.cpython-38.pyc +0 -0
- drqa/retriever/__pycache__/doc_db.cpython-38.pyc +0 -0
- drqa/retriever/__pycache__/elastic_doc_ranker.cpython-38.pyc +0 -0
- drqa/retriever/__pycache__/tfidf_doc_ranker.cpython-38.pyc +0 -0
- drqa/retriever/__pycache__/utils.cpython-38.pyc +0 -0
- drqa/retriever/doc_db.py +81 -0
- drqa/retriever/elastic_doc_ranker.py +109 -0
- drqa/retriever/tfidf_doc_ranker.py +121 -0
- drqa/retriever/utils.py +120 -0
- drqa/tokenizers/__init__.py +56 -0
- drqa/tokenizers/__pycache__/__init__.cpython-38.pyc +0 -0
- drqa/tokenizers/__pycache__/corenlp_tokenizer.cpython-38.pyc +0 -0
- drqa/tokenizers/__pycache__/regexp_tokenizer.cpython-38.pyc +0 -0
- drqa/tokenizers/__pycache__/simple_tokenizer.cpython-38.pyc +0 -0
- drqa/tokenizers/__pycache__/spacy_tokenizer.cpython-38.pyc +0 -0
- drqa/tokenizers/__pycache__/tokenizer.cpython-38.pyc +0 -0
- drqa/tokenizers/corenlp_tokenizer.py +122 -0
- drqa/tokenizers/regexp_tokenizer.py +100 -0
- drqa/tokenizers/simple_tokenizer.py +57 -0
- drqa/tokenizers/spacy_tokenizer.py +62 -0
- drqa/tokenizers/tokenizer.py +139 -0
- html2lines.py +72 -0
- requirements.txt +22 -0
README.md
CHANGED
@@ -1,12 +1,13 @@
 ---
-title: AVeriTeC
-emoji:
-colorFrom:
-colorTo:
+title: AVeriTeC
+emoji: 🏆
+colorFrom: purple
+colorTo: red
 sdk: gradio
-sdk_version: 4.
+sdk_version: 4.37.2
 app_file: app.py
 pinned: false
+license: apache-2.0
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,1368 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Created by zd302 at 08/07/2024

import gradio as gr
import tqdm
import torch
import numpy as np
from time import sleep
import threading
import gc
import os
import json
import pytorch_lightning as pl
from urllib.parse import urlparse
from accelerate import Accelerator

from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
from transformers import RobertaTokenizer, RobertaForSequenceClassification

from rank_bm25 import BM25Okapi
# import bm25s
# import Stemmer  # optional: for stemming
from html2lines import url2lines
from googleapiclient.discovery import build
from averitec.models.DualEncoderModule import DualEncoderModule
from averitec.models.SequenceClassificationModule import SequenceClassificationModule
from averitec.models.JustificationGenerationModule import JustificationGenerationModule
from averitec.data.sample_claims import CLAIMS_Type

# ---------------------------------------------------------------------------
# load .env
from utils import create_user_id
user_id = create_user_id()

from datetime import datetime
from azure.storage.fileshare import ShareServiceClient
try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception as e:
    pass

account_url = os.environ["AZURE_ACCOUNT_URL"]
credential = {
    "account_key": os.environ['AZURE_ACCOUNT_KEY'],
    "account_name": os.environ['AZURE_ACCOUNT_NAME']
}

file_share_name = "averitec"
azure_service = ShareServiceClient(account_url=account_url, credential=credential)
azure_share_client = azure_service.get_share_client(file_share_name)

# ---------- Setting ----------
import requests
from bs4 import BeautifulSoup
import wikipediaapi
wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC ([email protected])', 'en')

import nltk
nltk.download('punkt')
from nltk import pos_tag, word_tokenize, sent_tokenize

import spacy
os.system("python -m spacy download en_core_web_sm")
nlp = spacy.load("en_core_web_sm")

# ---------------------------------------------------------------------------
# Load sample dict for AVeriTeC search
# all_samples_dict = json.load(open('averitec/data/all_samples.json', 'r'))

# ---------------------------------------------------------------------------
# ---------- Load pretrained models ----------
# ---------- load Evidence retrieval model ----------
# from drqa import retriever
# db_class = retriever.get_class('sqlite')
# doc_db = db_class("averitec/data/wikipedia_dumps/enwiki.db")
# ranker = retriever.get_class('tfidf')(tfidf_path="averitec/data/wikipedia_dumps/enwiki-tfidf-with-id-title.npz")

# ---------- Load Veracity and Justification prediction model ----------
print("Loading models ...")
LABEL = [
    "Supported",
    "Refuted",
    "Not Enough Evidence",
    "Conflicting Evidence/Cherrypicking",
]
# Veracity
device = "cuda:0" if torch.cuda.is_available() else "cpu"
veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
                                                                   tokenizer=veracity_tokenizer, model=bert_model).to(device)
# Justification
justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
# ---------------------------------------------------------------------------


# Set up Gradio Theme
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)

# ---------- Setting ----------

class Docs:
    def __init__(self, metadata=dict(), page_content=""):
        self.metadata = metadata
        self.page_content = page_content


def make_html_source(source, i):
    meta = source.metadata
    content = source.page_content.strip()

    card = f"""
    <div class="card" id="doc{i}">
        <div class="card-content">
            <h2>Doc {i} - URL: <a href="{meta['url']}" target="_blank" class="pdf-link">{meta['url']}</a></h2>
            <p>{content}</p>
        </div>
        <div class="card-footer">
            <span>CACHED SOURCE URL:</span>
            <a href="{meta['cached_source_url']}" target="_blank" class="pdf-link">
                <span role="img" aria-label="Open PDF">🔗</span>
            </a>
        </div>
    </div>
    """

    return card


# ----- veracity_prediction -----
class SequenceClassificationDataLoader(pl.LightningDataModule):
    def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
        super().__init__()
        self.tokenizer = tokenizer
        self.data_file = data_file
        self.batch_size = batch_size
        self.add_extra_nee = add_extra_nee

    def tokenize_strings(
        self,
        source_sentences,
        max_length=400,
        pad_to_max_length=False,
        return_tensors="pt",
    ):
        encoded_dict = self.tokenizer(
            source_sentences,
            max_length=max_length,
            padding="max_length" if pad_to_max_length else "longest",
            truncation=True,
            return_tensors=return_tensors,
        )

        input_ids = encoded_dict["input_ids"]
        attention_masks = encoded_dict["attention_mask"]

        return input_ids, attention_masks

    def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
        if bool_explanation is not None and len(bool_explanation) > 0:
            bool_explanation = ", because " + bool_explanation.lower().strip()
        else:
            bool_explanation = ""
        return (
            "[CLAIM] "
            + claim.strip()
            + " [QUESTION] "
            + question.strip()
            + " "
            + answer.strip()
            + bool_explanation
        )


def averitec_veracity_prediction(claim, qa_evidence):
    bert_model_name = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4,
                                                               problem_type="single_label_classification")

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    trained_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
                                                                      tokenizer=tokenizer, model=bert_model).to(device)

    dataLoader = SequenceClassificationDataLoader(
        tokenizer=tokenizer,
        data_file="this_is_discontinued",
        batch_size=32,
        add_extra_nee=False,
    )

    evidence_strings = []
    for evidence in qa_evidence:
        evidence_strings.append(
            dataLoader.quadruple_to_string(claim, evidence.metadata["query"], evidence.metadata["answer"], ""))

    if len(evidence_strings) == 0:  # If we found no evidence e.g. because google returned 0 pages, just output NEI.
        pred_label = "Not Enough Evidence"
        return pred_label

    tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
    example_support = torch.argmax(
        trained_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)

    has_unanswerable = False
    has_true = False
    has_false = False

    for v in example_support:
        if v == 0:
            has_true = True
        if v == 1:
            has_false = True
        if v in (2, 3,):  # TODO another hack -- we cant have different labels for train and test so we do this
            has_unanswerable = True

    if has_unanswerable:
        answer = 2
    elif has_true and not has_false:
        answer = 0
    elif not has_true and has_false:
        answer = 1
    else:
        answer = 3

    pred_label = LABEL[answer]

    return pred_label


def fever_veracity_prediction(claim, evidence):
    tokenizer = RobertaTokenizer.from_pretrained('Dzeniks/roberta-fact-check')
    model = RobertaForSequenceClassification.from_pretrained('Dzeniks/roberta-fact-check')

    evidence_string = ""
    for evi in evidence:
        evidence_string += evi.metadata['title'] + evi.metadata['evidence'] + ' '

    input_sequence = tokenizer.encode_plus(claim, evidence_string, return_tensors="pt")
    with torch.no_grad():
        prediction = model(**input_sequence)

    label = torch.argmax(prediction[0]).item()
    pred_label = LABEL[label]

    return pred_label


def veracity_prediction(claim, qa_evidence):
    # bert_model_name = "bert-base-uncased"
    # tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    # bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4,
    #                                                            problem_type="single_label_classification")
    #
    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # trained_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
    #                                                                   tokenizer=tokenizer, model=bert_model).to(device)

    dataLoader = SequenceClassificationDataLoader(
        tokenizer=veracity_tokenizer,
        data_file="this_is_discontinued",
        batch_size=32,
        add_extra_nee=False,
    )

    evidence_strings = []
    for evidence in qa_evidence:
        evidence_strings.append(
            dataLoader.quadruple_to_string(claim, evidence.metadata["query"], evidence.metadata["answer"], ""))

    if len(evidence_strings) == 0:  # If we found no evidence e.g. because google returned 0 pages, just output NEI.
        pred_label = "Not Enough Evidence"
        return pred_label

    tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
    example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)

    has_unanswerable = False
    has_true = False
    has_false = False

    for v in example_support:
        if v == 0:
            has_true = True
        if v == 1:
            has_false = True
        if v in (2, 3,):  # TODO another hack -- we cant have different labels for train and test so we do this
            has_unanswerable = True

    if has_unanswerable:
        answer = 2
    elif has_true and not has_false:
        answer = 0
    elif not has_true and has_false:
        answer = 1
    else:
        answer = 3

    pred_label = LABEL[answer]

    return pred_label
def extract_claim_str(claim, qa_evidence, verdict_label):
    claim_str = "[CLAIM] " + claim + " [EVIDENCE] "

    for evidence in qa_evidence:
        q_text = evidence.metadata['query'].strip()

        if len(q_text) == 0:
            continue

        if not q_text[-1] == "?":
            q_text += "?"

        answer_strings = []
        answer_strings.append(evidence.metadata['answer'])

        claim_str += q_text
        for a_text in answer_strings:
            if a_text:
                if not a_text[-1] == ".":
                    a_text += "."
                claim_str += " " + a_text.strip()

        claim_str += " "

    claim_str += " [VERDICT] " + verdict_label

    return claim_str


def averitec_justification_generation(claim, qa_evidence, verdict_label):
    #
    claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
    claim_str.strip()

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
    bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

    best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
    trained_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer,
                                                                       model=bart_model).to(device)

    pred_justification = trained_model.generate(claim_str, device=device)

    return pred_justification.strip()


def justification_generation(claim, qa_evidence, verdict_label):
    #
    claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
    claim_str.strip()

    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
    # bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
    #
    # best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
    # trained_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer,
    #                                                                    model=bart_model).to(device)

    pred_justification = justification_model.generate(claim_str, device=device)

    return pred_justification.strip()


def QAprediction(claim, evidence, sources):
    parts = []
    #
    evidence_title = f"""<h5>Retrieved Evidence:</h5>"""
    for i, evi in enumerate(evidence, 1):
        part = f"""<span>Doc {i}</span>"""
        subpart = f"""<a href="#doc{i}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{i}</sup></span></a>"""
        # subpart = f"""<span class='doc-ref'>{i}</sup></span>"""
        subparts = "".join([part, subpart])
        parts.append(subparts)

    evidence_part = ", ".join(parts)

    prediction_title = f"""<h5>Prediction:</h5>"""
    # if 'Google' in sources or 'AVeriTeC' in sources:
    #     verdict_label = averitec_veracity_prediction(claim, evidence)
    #     justification_label = averitec_justification_generation(claim, evidence, verdict_label)
    #     # justification_label = "See retrieved docs."
    #     justification_part = f"""<span>Justification: {justification_label}</span>"""
    # if 'WikiPedia' in sources:
    #     # verdict_label = fever_veracity_prediction(claim, evidence)
    #     justification_label = averitec_justification_generation(claim, evidence, verdict_label)
    #     # justification_label = "See retrieved docs."
    #     justification_part = f"""<span>Justification: {justification_label}</span>"""

    verdict_label = veracity_prediction(claim, evidence)
    justification_label = justification_generation(claim, evidence, verdict_label)
    # justification_label = "See retrieved docs."
    justification_part = f"""<span>Justification: {justification_label}</span>"""


    verdict_part = f"""Verdict: <span>{verdict_label}.</span><br>"""

    content_parts = "".join([evidence_title, evidence_part, prediction_title, verdict_part, justification_part])
    # content_parts = "".join([evidence_title, evidence_part, verdict_title, verdict_part, justification_title, justification_part])

    return content_parts, [verdict_label, justification_label]


# ----------GoogleAPIretriever---------
def generate_reference_corpus(reference_file):
    with open(reference_file) as f:
        j = json.load(f)
        train_examples = j

    all_data_corpus = []
    tokenized_corpus = []

    for train_example in train_examples:
        train_claim = train_example["claim"]

        speaker = train_example["speaker"].strip() if train_example["speaker"] is not None and len(
            train_example["speaker"]) > 1 else "they"

        questions = [q["question"] for q in train_example["questions"]]

        claim_dict_builder = {}
        claim_dict_builder["claim"] = train_claim
        claim_dict_builder["speaker"] = speaker
        claim_dict_builder["questions"] = questions

        tokenized_corpus.append(nltk.word_tokenize(claim_dict_builder["claim"]))
        all_data_corpus.append(claim_dict_builder)

    return tokenized_corpus, all_data_corpus


def doc2prompt(doc):
    prompt_parts = "Outrageously, " + doc["speaker"] + " claimed that \"" + doc[
        "claim"].strip() + "\". Criticism includes questions like: "
    questions = [q.strip() for q in doc["questions"]]
    return prompt_parts + " ".join(questions)


def docs2prompt(top_docs):
    return "\n\n".join([doc2prompt(d) for d in top_docs])


def prompt_question_generation(test_claim, speaker="they", topk=10):
    #
    reference_file = "averitec_code/data/train.json"
    tokenized_corpus, all_data_corpus = generate_reference_corpus(reference_file)
    bm25 = BM25Okapi(tokenized_corpus)

    # Define the bloom model:
    accelerator = Accelerator()
    accel_device = accelerator.device
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
    model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)

    # --------------------------------------------------
    # test claim
    s = bm25.get_scores(nltk.word_tokenize(test_claim))
    top_n = np.argsort(s)[::-1][:topk]
    docs = [all_data_corpus[i] for i in top_n]
    # --------------------------------------------------

    prompt = docs2prompt(docs) + "\n\n" + "Outrageously, " + speaker + " claimed that \"" + test_claim.strip() + \
             "\". Criticism includes questions like: "
    sentences = [prompt]

    inputs = tokenizer(sentences, padding=True, return_tensors="pt").to(device)
    outputs = model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2,
                             early_stopping=True)

    tgt_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    in_len = len(sentences[0])
    questions_str = tgt_text[in_len:].split("\n")[0]

    qs = questions_str.split("?")
    qs = [q.strip() + "?" for q in qs if q.strip() and len(q.strip()) < 300]

    #
    generate_question = [{"question": q, "answers": []} for q in qs]

    return generate_question


def check_claim_date(check_date):
    try:
        year, month, date = check_date.split("-")
    except:
        month, date, year = "01", "01", "2022"

    if len(year) == 2 and int(year) <= 30:
        year = "20" + year
    elif len(year) == 2:
        year = "19" + year
    elif len(year) == 1:
        year = "200" + year

    if len(month) == 1:
        month = "0" + month

    if len(date) == 1:
        date = "0" + date

    sort_date = year + month + date

    return sort_date


def string_to_search_query(text, author):
    parts = word_tokenize(text.strip())
    tags = pos_tag(parts)

    keep_tags = ["CD", "JJ", "NN", "VB"]

    if author is not None:
        search_string = author.split()
    else:
        search_string = []

    for token, tag in zip(parts, tags):
        for keep_tag in keep_tags:
            if tag[1].startswith(keep_tag):
                search_string.append(token)

    search_string = " ".join(search_string)
    return search_string
def google_search(search_term, api_key, cse_id, **kwargs):
    service = build("customsearch", "v1", developerKey=api_key)
    res = service.cse().list(q=search_term, cx=cse_id, **kwargs).execute()

    if "items" in res:
        return res['items']
    else:
        return []


def get_domain_name(url):
    if '://' not in url:
        url = 'http://' + url

    domain = urlparse(url).netloc

    if domain.startswith("www."):
        return domain[4:]
    else:
        return domain


def get_and_store(url_link, fp, worker, worker_stack):
    page_lines = url2lines(url_link)

    with open(fp, "w") as out_f:
        print("\n".join([url_link] + page_lines), file=out_f)

    worker_stack.append(worker)
    gc.collect()


def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0):
    search_results = []
    for i in range(3):
        try:
            search_results += google_search(
                search_string,
                api_key,
                search_engine_id,
                num=10,
                start=0 + 10 * page,
                sort="date:r:19000101:" + sort_date,
                dateRestrict=None,
                gl="US"
            )
            break
        except:
            sleep(3)

    return search_results
def averitec_search(claim, generate_question, speaker="they", check_date="2024-01-01", n_pages=1):  # n_pages=3
    # default config
    api_key = os.environ["GOOGLE_API_KEY"]
    search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]

    blacklist = [
        "jstor.org",  # Blacklisted because their pdfs are not labelled as such, and clog up the download
        "facebook.com",  # Blacklisted because only post titles can be scraped, but the scraper doesn't know this,
        "ftp.cs.princeton.edu",  # Blacklisted because it hosts many large NLP corpora that keep showing up
        "nlp.cs.princeton.edu",
        "huggingface.co"
    ]

    blacklist_files = [  # Blacklisted some NLP nonsense that crashes my machine with OOM errors
        "/glove.",
        "ftp://ftp.cs.princeton.edu/pub/cs226/autocomplete/words-333333.txt",
        "https://web.mit.edu/adamrose/Public/googlelist",
    ]

    # save to folder
    store_folder = "averitec_code/store/retrieved_docs"
    #
    index = 0
    questions = [q["question"] for q in generate_question]

    # check the date of the claim
    sort_date = check_claim_date(check_date)  # check_date="2022-01-01"

    #
    search_strings = []
    search_types = []

    search_string_2 = string_to_search_query(claim, None)
    search_strings += [search_string_2, claim, ]
    search_types += ["claim", "claim-noformat", ]

    search_strings += questions
    search_types += ["question" for _ in questions]

    # start to search
    search_results = []
    visited = {}
    store_counter = 0
    worker_stack = list(range(10))

    retrieve_evidence = []

    for this_search_string, this_search_type in zip(search_strings, search_types):
        for page_num in range(n_pages):
            search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
                                                       this_search_string, page=page_num)

            for result in search_results:
                link = str(result["link"])
                domain = get_domain_name(link)

                if domain in blacklist:
                    continue
                broken = False
                for b_file in blacklist_files:
                    if b_file in link:
                        broken = True
                if broken:
                    continue
                if link.endswith(".pdf") or link.endswith(".doc"):
                    continue

                store_file_path = ""

                if link in visited:
                    store_file_path = visited[link]
                else:
                    store_counter += 1
                    store_file_path = store_folder + "/search_result_" + str(index) + "_" + str(
                        store_counter) + ".store"
                    visited[link] = store_file_path

                    while len(worker_stack) == 0:  # Wait for a wrrker to become available. Check every second.
                        sleep(1)

                    worker = worker_stack.pop()

                    t = threading.Thread(target=get_and_store, args=(link, store_file_path, worker, worker_stack))
                    t.start()

                line = [str(index), claim, link, str(page_num), this_search_string, this_search_type, store_file_path]
                retrieve_evidence.append(line)

    return retrieve_evidence
def claim2prompts(example):
    claim = example["claim"]

    # claim_str = "Claim: " + claim + "||Evidence: "
    claim_str = "Evidence: "

    for question in example["questions"]:
        q_text = question["question"].strip()
        if len(q_text) == 0:
            continue

        if not q_text[-1] == "?":
            q_text += "?"

        answer_strings = []

        for a in question["answers"]:
            if a["answer_type"] in ["Extractive", "Abstractive"]:
                answer_strings.append(a["answer"])
            if a["answer_type"] == "Boolean":
                answer_strings.append(a["answer"] + ", because " + a["boolean_explanation"].lower().strip())

        for a_text in answer_strings:
            if not a_text[-1] in [".", "!", ":", "?"]:
                a_text += "."

            # prompt_lookup_str = claim + " " + a_text
            prompt_lookup_str = a_text
            this_q_claim_str = claim_str + " " + a_text.strip() + "||Question answered: " + q_text
            yield (prompt_lookup_str, this_q_claim_str.replace("\n", " ").replace("||", "\n"))


def generate_step2_reference_corpus(reference_file):
    with open(reference_file) as f:
        train_examples = json.load(f)

    prompt_corpus = []
    tokenized_corpus = []

    for example in train_examples:
        for lookup_str, prompt in claim2prompts(example):
            entry = nltk.word_tokenize(lookup_str)
            tokenized_corpus.append(entry)
            prompt_corpus.append(prompt)

    return tokenized_corpus, prompt_corpus


def decorate_with_questions(claim, retrieve_evidence, top_k=10):  # top_k=100
    #
    reference_file = "averitec_code/data/train.json"
    tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
    prompt_bm25 = BM25Okapi(tokenized_corpus)

    # Define the bloom model:
    accelerator = Accelerator()
    accel_device = accelerator.device
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
    model = BloomForCausalLM.from_pretrained(
        "bigscience/bloom-7b1",
        device_map="auto",
        torch_dtype=torch.bfloat16,
        offload_folder="./offload"
    )

    #
    tokenized_corpus = []
    all_data_corpus = []

    for retri_evi in tqdm.tqdm(retrieve_evidence):
        store_file = retri_evi[-1]

        with open(store_file, 'r') as f:
            first = True
            for line in f:
                line = line.strip()

                if first:
                    first = False
                    location_url = line
                    continue

                if len(line) > 3:
                    entry = nltk.word_tokenize(line)
                    if (location_url, line) not in all_data_corpus:
                        tokenized_corpus.append(entry)
                        all_data_corpus.append((location_url, line))

    if len(tokenized_corpus) == 0:
        print("")

    bm25 = BM25Okapi(tokenized_corpus)
    s = bm25.get_scores(nltk.word_tokenize(claim))
    top_n = np.argsort(s)[::-1][:top_k]
    docs = [all_data_corpus[i] for i in top_n]

    generate_qa_pairs = []
    # Then, generate questions for those top 50:
    for doc in tqdm.tqdm(docs):
        # prompt_lookup_str = example["claim"] + " " + doc[1]
        prompt_lookup_str = doc[1]

        prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str))
        prompt_n = 10
        prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
        prompt_docs = [prompt_corpus[i] for i in prompt_top_n]

        claim_prompt = "Evidence: " + doc[1].replace("\n", " ") + "\nQuestion answered: "
        prompt = "\n\n".join(prompt_docs + [claim_prompt])
        sentences = [prompt]

        inputs = tokenizer(sentences, padding=True, return_tensors="pt").to(device)
        outputs = model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2,
                                 early_stopping=True)

        tgt_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
        # We are not allowed to generate more than 250 characters:
        tgt_text = tgt_text[:250]

        qa_pair = [tgt_text.strip().split("?")[0].replace("\n", " ") + "?", doc[1].replace("\n", " "), doc[0]]
        generate_qa_pairs.append(qa_pair)

    return generate_qa_pairs
def triple_to_string(x):
    return " </s> ".join([item.strip() for item in x])


def rerank_questions(claim, bm25_qas, topk=3):
    #
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2,
                                                               problem_type="single_label_classification")  # Must specify single_label for some reason
    best_checkpoint = "averitec_code/pretrained_models/bert_dual_encoder.ckpt"
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer, model=bert_model).to(
        device)

    #
    strs_to_score = []
    values = []

    for question, answer, source in bm25_qas:
        str_to_score = triple_to_string([claim, question, answer])

        strs_to_score.append(str_to_score)
        values.append([question, answer, source])

    if len(bm25_qas) > 0:
        encoded_dict = tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True,
                                 return_tensors="pt").to(device)

        input_ids = encoded_dict['input_ids']
        attention_masks = encoded_dict['attention_mask']

        scores = torch.softmax(trained_model(input_ids, attention_mask=attention_masks).logits, axis=-1)[:, 1]

        top_n = torch.argsort(scores, descending=True)[:topk]
        pass_through = [{"question": values[i][0], "answers": values[i][1], "source_url": values[i][2]} for i in top_n]
    else:
        pass_through = []

    top3_qa_pairs = pass_through

    return top3_qa_pairs
def GoogleAPIretriever(query):
    # ----- Generate QA pairs using AVeriTeC
    top3_qa_pairs_path = "averitec_code/top3_qa_pairs1.json"
    if not os.path.exists(top3_qa_pairs_path):
        # step 1: generate questions for the query/claim using Bloom
        generate_question = prompt_question_generation(query)
        # step 2: retrieve evidence for the generated questions using Google API
        retrieve_evidence = averitec_search(query, generate_question)
        # step 3: generate QA pairs for each retrieved document
        bm25_qa_pairs = decorate_with_questions(query, retrieve_evidence)
        # step 4: rerank QA pairs
        top3_qa_pairs = rerank_questions(query, bm25_qa_pairs)
    else:
        top3_qa_pairs = json.load(open(top3_qa_pairs_path, 'r'))

    # Add score to metadata
    results = []
    for i, qa in enumerate(top3_qa_pairs):
        metadata = dict()

        metadata['name'] = qa['question']
        metadata['url'] = qa['source_url']
        metadata['cached_source_url'] = qa['source_url']
        metadata['short_name'] = "Evidence {}".format(i + 1)
        metadata['page_number'] = ""
        metadata['query'] = qa['question']
        metadata['answer'] = qa['answers']
        metadata['page_content'] = "<b>Question</b>: " + qa['question'] + "<br>" + "<b>Answer</b>: " + qa['answers']
        page_content = f"""{metadata['page_content']}"""
        results.append((metadata, page_content))

    return results


# ----------GoogleAPIretriever---------

# ----------Wikipediaretriever---------
def bm25_retriever(query, corpus, topk=3):
    bm25 = BM25Okapi(corpus)
    #
    query_tokens = word_tokenize(query)
    scores = bm25.get_scores(query_tokens)
    top_n = np.argsort(scores)[::-1][:topk]
    top_n_scores = [scores[i] for i in top_n]

    return top_n, top_n_scores


def bm25s_retriever(query, corpus, topk=3):
    # optional: create a stemmer
    stemmer = Stemmer.Stemmer("english")
    # Tokenize the corpus and only keep the ids (faster and saves memory)
    corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
    # Create the BM25 model and index the corpus
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)
    # Query the corpus
    query_tokens = bm25s.tokenize(query, stemmer=stemmer)
    # Get top-k results as a tuple of (doc ids, scores). Both are arrays of shape (n_queries, k)
    results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=topk)
    top_n = [corpus.index(res) for res in results[0]]
    return top_n, scores


def find_evidence_from_wikipedia_dumps(claim):
    #
    doc = nlp(claim)
    entities_in_claim = [str(ent).lower() for ent in doc.ents]
    title2id = ranker.doc_dict[0]
    wiki_intro, ent_list = [], []
    for ent in entities_in_claim:
        if ent in title2id.keys():
            ids = title2id[ent]
            introduction = doc_db.get_doc_intro(ids)
            wiki_intro.append([ent, introduction])
            # fulltext = doc_db.get_doc_text(ids)
            # evidence.append([ent, fulltext])
            ent_list.append(ent)

    if len(wiki_intro) < 5:
        evidence_tfidf = process_topk(claim, title2id, ent_list, k=5)
        wiki_intro.extend(evidence_tfidf)

    return wiki_intro, doc


def relevant_sentence_retrieval(query, wiki_intro, k):
    # 1. Create corpus here
    corpus, sentences = [], []
    titles = []
    for i, (title, intro) in enumerate(wiki_intro):
        sents_in_intro = sent_tokenize(intro)

        for sent in sents_in_intro:
            corpus.append(word_tokenize(sent))
            sentences.append(sent)
            titles.append(title)
    #
    # ----- BM25
    bm25_top_n, bm25_top_n_scores = bm25_retriever(query, corpus, topk=k)
    bm25_top_n_sents = [sentences[i] for i in bm25_top_n]
    bm25_top_n_titles = [titles[i] for i in bm25_top_n]

    # ----- BM25s
    # bm25s_top_n, bm25s_top_n_scores = bm25s_retriever(query, sentences, topk=k) # corpus->sentences
    # bm25s_top_n_sents = [sentences[i] for i in bm25s_top_n]
    # bm25s_top_n_titles = [titles[i] for i in bm25s_top_n]

    return bm25_top_n_sents, bm25_top_n_titles


def process_topk(query, title2id, ent_list, k=1):
    doc_names, doc_scores = ranker.closest_docs(query, k)
    evidence_tfidf = []

    for _name in doc_names:
        if _name not in ent_list and len(ent_list) < 5:
            ent_list.append(_name)
            idx = title2id[_name]
            introduction = doc_db.get_doc_intro(idx)
            evidence_tfidf.append([_name, introduction])
            # fulltext = doc_db.get_doc_text(idx)
            # evidence_tfidf.append([_name,fulltext])

    return evidence_tfidf
def WikipediaDumpsretriever(claim):
    #
    # 1. extract relevant wikipedia pages from wikipedia dumps
    wiki_intro, doc = find_evidence_from_wikipedia_dumps(claim)
    # wiki_intro = [['trump', "'''Trump''' most commonly refers to:\n* Donald Trump (born 1946), President of the United States from 2017 to 2021 \n* Trump (card games), any playing card given an ad-hoc high rank\n\n'''Trump''' may also refer to:"]]

    # 2. extract relevant sentences from extracted wikipedia pages
    sents, titles = relevant_sentence_retrieval(claim, wiki_intro, k=3)

    #
    results = []
    for i, (sent, title) in enumerate(zip(sents, titles)):
        metadata = dict()
        metadata['name'] = claim
        metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata['short_name'] = "Evidence {}".format(i + 1)
        metadata['page_number'] = ""
        metadata['query'] = sent
        metadata['title'] = title
        metadata['evidence'] = sent
        metadata['answer'] = ""
        metadata['page_content'] = "<b>Title</b>: " + str(metadata['title']) + "<br>" + "<b>Evidence</b>: " + metadata[
            'evidence']
        page_content = f"""{metadata['page_content']}"""

        results.append(Docs(metadata, page_content))

    return results

# ----------WikipediaAPIretriever---------
def clean_str(p):
    return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")


def get_page_obs(page):
    # find all paragraphs
    paragraphs = page.split("\n")
    paragraphs = [p.strip() for p in paragraphs if p.strip()]

    # # find all sentence
    # sentences = []
    # for p in paragraphs:
    #     sentences += p.split('. ')
    # sentences = [s.strip() + '.' for s in sentences if s.strip()]
    # # return ' '.join(sentences[:5])
    # return ' '.join(sentences)

    return ' '.join(paragraphs[:5])


def search_entity_wikipeida(entity):
    find_evidence = []

    page_py = wiki_wiki.page(entity)
    if page_py.exists():
        introduction = page_py.summary

        find_evidence.append([str(entity), introduction])

    return find_evidence


def search_step(entity):
    ent_ = entity.replace(" ", "+")
    search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}"
    response_text = requests.get(search_url).text
    soup = BeautifulSoup(response_text, features="html.parser")
    result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})

    find_evidence = []

    if result_divs:  # mismatch
        # If the wikipeida page of the entity is not exist, find similar wikipedia pages.
        result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
        similar_titles = result_titles[:5]

        for _t in similar_titles:
            if len(find_evidence) < 5:
                _evi = search_step(_t)
                find_evidence.extend(_evi)
    else:
        page = [p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")]
        if any("may refer to:" in p for p in page):
            _evi = search_step("[" + entity + "]")
            find_evidence.extend(_evi)
        else:
            # page_py = wiki_wiki.page(entity)
            #
            # if page_py.exists():
            #     introduction = page_py.summary
            # else:
            page_text = ""
            for p in page:
                if len(p.split(" ")) > 2:
                    page_text += clean_str(p)
                    if not p.endswith("\n"):
                        page_text += "\n"
            introduction = get_page_obs(page_text)

            find_evidence.append([entity, introduction])

    return find_evidence


def find_similar_wikipedia(entity, relevant_wikipages):
    # If the relevant wikipeida page of the entity is less than 5, find similar wikipedia pages.
    ent_ = entity.replace(" ", "+")
    search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}&title=Special:Search&profile=advanced&fulltext=1&ns0=1"
    response_text = requests.get(search_url).text
    soup = BeautifulSoup(response_text, features="html.parser")
    result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})

    if result_divs:
        result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
        similar_titles = result_titles[:5]

        saved_titles = [ent[0] for ent in relevant_wikipages] if relevant_wikipages else relevant_wikipages
        for _t in similar_titles:
            if _t not in saved_titles and len(relevant_wikipages) < 5:
                _evi = search_entity_wikipeida(_t)
                # _evi = search_step(_t)
                relevant_wikipages.extend(_evi)

    return relevant_wikipages
def find_evidence_from_wikipedia(claim):
    #
    doc = nlp(claim)
    #
    wikipedia_page = []
    for ent in doc.ents:
        relevant_wikipages = search_entity_wikipeida(ent)

        if len(relevant_wikipages) < 5:
            relevant_wikipages = find_similar_wikipedia(str(ent), relevant_wikipages)

        wikipedia_page.extend(relevant_wikipages)

    return wikipedia_page


def relevant_wikipedia_API_retriever(claim):
    #
    doc = nlp(claim)

    wiki_intro = []
    for ent in doc.ents:
        page_py = wiki_wiki.page(ent)

        if page_py.exists():
            introduction = page_py.summary
        else:
            introduction = "No documents found."

        wiki_intro.append([str(ent), introduction])

    return wiki_intro, doc


def Wikipediaretriever(claim, sources):
    #
    # 1. extract relevant wikipedia pages from wikipedia dumps
    if "Dump" in sources:
        wikipedia_page = find_evidence_from_wikipedia_dumps(claim)
    else:
        wikipedia_page = find_evidence_from_wikipedia(claim)
        # wiki_intro, doc = relevant_wikipedia_API_retriever(claim)

    # 2. extract relevant sentences from extracted wikipedia pages
    sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)

    #
    results = []
    for i, (sent, title) in enumerate(zip(sents, titles)):
        metadata = dict()
        metadata['name'] = claim
        metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title)
        metadata['short_name'] = "Evidence {}".format(i + 1)
        metadata['page_number'] = ""
        metadata['query'] = sent
        metadata['title'] = title
        metadata['evidence'] = sent
        metadata['answer'] = ""
        metadata['page_content'] = "<b>Title</b>: " + str(metadata['title']) + "<br>" + "<b>Evidence</b>: " + metadata['evidence']
        page_content = f"""{metadata['page_content']}"""

        results.append(Docs(metadata, page_content))

    return results


def log_on_azure(file, logs, azure_share_client):
    logs = json.dumps(logs)
    file_client = azure_share_client.get_file_client(file)
    file_client.upload_file(logs)
def chat(claim, history, sources):
    evidence = []
    # if 'Google' in sources:
    #     evidence = GoogleAPIretriever(query)

    # if 'WikiPediaDumps' in sources:
    #     evidence = WikipediaDumpsretriever(query)

    if 'WikiPedia' in sources:
        evidence = Wikipediaretriever(claim, sources)

    answer_set, answer_output = QAprediction(claim, evidence, sources)

    docs_html = ""
    if len(evidence) > 0:
        docs_html = []
        for i, evi in enumerate(evidence, 1):
            docs_html.append(make_html_source(evi, i))
        docs_html = "".join(docs_html)
    else:
        print("No documents found")

    url_of_evidence = ""
    output_language = "English"
    output_query = claim
    history[-1] = (claim, answer_set)
    history = [tuple(x) for x in history]

    ############################################################
    evi_list = []
    for evi in evidence:
        title_str = evi.metadata['title']
        evi_str = evi.metadata['evidence']
        evi_list.append([title_str, evi_str])

    try:
        # Log answer on Azure Blob Storage
        # IF AZURE_ISSAVE=TRUE, save the logs into the Azure share client.
        if bool(os.environ["AZURE_ISSAVE"]):
            timestamp = str(datetime.now().timestamp())
            # timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            file = timestamp + ".json"
            logs = {
                "user_id": str(user_id),
                "claim": claim,
                "sources": sources,
                "evidence": evi_list,
                "url": url_of_evidence,
                "answer": answer_output,
                "time": timestamp,
            }
            log_on_azure(file, logs, azure_share_client)
    except Exception as e:
        print(f"Error logging on Azure Blob Storage: {e}")
        raise gr.Error(
            f"AVeriTeC Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
    ##########

    return history, docs_html, output_query, output_language
1243 |
+
def main():
|
1244 |
+
init_prompt = """
|
1245 |
+
Hello, I am a fact-checking assistant designed to help you find appropriate evidence to predict the veracity of claims.
|
1246 |
+
|
1247 |
+
What do you want to fact-check?
|
1248 |
+
"""
|
1249 |
+
|
1250 |
+
with gr.Blocks(title="AVeriTeC fact-checker", css="style.css", theme=theme, elem_id="main-component") as demo:
|
1251 |
+
with gr.Tab("AVeriTeC"):
|
1252 |
+
with gr.Row(elem_id="chatbot-row"):
|
1253 |
+
with gr.Column(scale=2):
|
1254 |
+
chatbot = gr.Chatbot(
|
1255 |
+
value=[(None, init_prompt)],
|
1256 |
+
show_copy_button=True, show_label=False, elem_id="chatbot", layout="panel",
|
1257 |
+
avatar_images=(None, "assets/averitec.png")
|
1258 |
+
) # avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"),
|
1259 |
+
|
1260 |
+
with gr.Row(elem_id="input-message"):
|
1261 |
+
textbox = gr.Textbox(placeholder="Ask me what claim do you want to check!", show_label=False,
|
1262 |
+
scale=7, lines=1, interactive=True, elem_id="input-textbox")
|
1263 |
+
# submit = gr.Button("",elem_id = "submit-button",scale = 1,interactive = True,icon = "https://static-00.iconduck.com/assets.00/settings-icon-2048x2046-cw28eevx.png")
|
1264 |
+
|
1265 |
+
with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
|
1266 |
+
with gr.Tabs() as tabs:
|
1267 |
+
with gr.TabItem("Examples", elem_id="tab-examples", id=0):
|
1268 |
+
examples_hidden = gr.Textbox(visible=False)
|
1269 |
+
first_key = list(CLAIMS_Type.keys())[0]
|
1270 |
+
dropdown_samples = gr.Dropdown(CLAIMS_Type.keys(), value=first_key, interactive=True,
|
1271 |
+
show_label=True,
|
1272 |
+
label="Select claim type",
|
1273 |
+
elem_id="dropdown-samples")
|
1274 |
+
|
1275 |
+
samples = []
|
1276 |
+
for i, key in enumerate(CLAIMS_Type.keys()):
|
1277 |
+
examples_visible = True if i == 0 else False
|
1278 |
+
|
1279 |
+
with gr.Row(visible=examples_visible) as group_examples:
|
1280 |
+
examples_questions = gr.Examples(
|
1281 |
+
CLAIMS_Type[key],
|
1282 |
+
[examples_hidden],
|
1283 |
+
examples_per_page=8,
|
1284 |
+
run_on_click=False,
|
1285 |
+
elem_id=f"examples{i}",
|
1286 |
+
api_name=f"examples{i}",
|
1287 |
+
# label = "Click on the example question or enter your own",
|
1288 |
+
# cache_examples=True,
|
1289 |
+
)
|
1290 |
+
|
1291 |
+
samples.append(group_examples)
|
1292 |
+
|
1293 |
+
with gr.Tab("Sources", elem_id="tab-citations", id=1):
|
1294 |
+
sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
|
1295 |
+
docs_textbox = gr.State("")
|
1296 |
+
|
1297 |
+
with gr.Tab("Configuration", elem_id="tab-config", id=2):
|
1298 |
+
gr.Markdown("Reminder: We currently only support fact-checking in English!")
|
1299 |
+
|
1300 |
+
# dropdown_sources = gr.Radio(
|
1301 |
+
# ["AVeriTeC", "WikiPediaDumps", "Google", "WikiPediaAPI"],
|
1302 |
+
# label="Select source",
|
1303 |
+
# value="WikiPediaAPI",
|
1304 |
+
# interactive=True,
|
1305 |
+
# )
|
1306 |
+
|
1307 |
+
dropdown_sources = gr.Radio(
|
1308 |
+
["Google", "WikiPedia"],
|
1309 |
+
label="Select source",
|
1310 |
+
value="WikiPedia",
|
1311 |
+
interactive=True,
|
1312 |
+
)
|
1313 |
+
|
1314 |
+
dropdown_retriever = gr.Dropdown(
|
1315 |
+
["BM25", "BM25s"],
|
1316 |
+
label="Select evidence retriever",
|
1317 |
+
multiselect=False,
|
1318 |
+
value="BM25",
|
1319 |
+
interactive=True,
|
1320 |
+
)
|
1321 |
+
|
1322 |
+
output_query = gr.Textbox(label="Query used for retrieval", show_label=True,
|
1323 |
+
elem_id="reformulated-query", lines=2, interactive=False)
|
1324 |
+
output_language = gr.Textbox(label="Language", show_label=True, elem_id="language", lines=1,
|
1325 |
+
interactive=False)
|
1326 |
+
|
1327 |
+
with gr.Tab("About", elem_classes="max-height other-tabs"):
|
1328 |
+
with gr.Row():
|
1329 |
+
with gr.Column(scale=1):
|
1330 |
+
gr.Markdown("See more info at [https://fever.ai/task.html](https://fever.ai/task.html)")
|
1331 |
+
|
1332 |
+
def start_chat(query, history):
|
1333 |
+
history = history + [(query, None)]
|
1334 |
+
history = [tuple(x) for x in history]
|
1335 |
+
return (gr.update(interactive=False), gr.update(selected=1), history)
|
1336 |
+
|
1337 |
+
def finish_chat():
|
1338 |
+
return (gr.update(interactive=True, value=""))
|
1339 |
+
|
1340 |
+
(textbox
|
1341 |
+
.submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
|
1342 |
+
.then(chat, [textbox, chatbot, dropdown_sources],
|
1343 |
+
[chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_textbox")
|
1344 |
+
.then(finish_chat, None, [textbox], api_name="finish_chat_textbox")
|
1345 |
+
)
|
1346 |
+
|
1347 |
+
(examples_hidden
|
1348 |
+
.change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False,
|
1349 |
+
api_name="start_chat_examples")
|
1350 |
+
.then(chat, [examples_hidden, chatbot, dropdown_sources],
|
1351 |
+
[chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_examples")
|
1352 |
+
.then(finish_chat, None, [textbox], api_name="finish_chat_examples")
|
1353 |
+
)
|
1354 |
+
|
1355 |
+
def change_sample_questions(key):
|
1356 |
+
index = list(CLAIMS_Type.keys()).index(key)
|
1357 |
+
visible_bools = [False] * len(samples)
|
1358 |
+
visible_bools[index] = True
|
1359 |
+
return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
|
1360 |
+
|
1361 |
+
dropdown_samples.change(change_sample_questions, dropdown_samples, samples)
|
1362 |
+
demo.queue()
|
1363 |
+
|
1364 |
+
demo.launch(share=True)
|
1365 |
+
|
1366 |
+
|
1367 |
+
if __name__ == "__main__":
|
1368 |
+
main()
|
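For reference, `log_on_azure` above expects an already-constructed Azure file-share client (`azure_share_client`). A minimal sketch of how such a client could be created with the `azure-storage-file-share` package follows; the environment-variable names and share name are illustrative assumptions, not taken from this repository.

# Sketch only: build the share client that log_on_azure() uploads to.
import os
from azure.storage.fileshare import ShareClient

azure_share_client = ShareClient.from_connection_string(
    conn_str=os.environ["AZURE_STORAGE_CONNECTION_STRING"],   # assumed variable name
    share_name=os.environ.get("AZURE_SHARE_NAME", "averitec-logs"),  # assumed share name
)
# log_on_azure(file, logs, azure_share_client) then serializes `logs` to JSON
# and uploads it as a new file on the share.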
drqa/__init__.py
ADDED
@@ -0,0 +1,23 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
from pathlib import PosixPath

if sys.version_info < (3, 5):
    raise RuntimeError('DrQA supports Python 3.5 or higher.')

DATA_DIR = (
    os.getenv('DRQA_DATA') or
    os.path.join(PosixPath(__file__).absolute().parents[1].as_posix(), 'data')
)

from . import tokenizers
from . import reader
from . import retriever
from . import pipeline
drqa/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (612 Bytes).
drqa/pipeline/__init__.py
ADDED
@@ -0,0 +1,27 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from ..tokenizers import CoreNLPTokenizer
from ..retriever import TfidfDocRanker
from ..retriever import DocDB
from .. import DATA_DIR

DEFAULTS = {
    'tokenizer': CoreNLPTokenizer,
    'ranker': TfidfDocRanker,
    'db': DocDB,
    'reader_model': os.path.join(DATA_DIR, 'reader/multitask.mdl'),
}


def set_default(key, value):
    global DEFAULTS
    DEFAULTS[key] = value


from .drqa import DrQA
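A short sketch of how `set_default` and `DEFAULTS` could be used by a caller to swap components before building the pipeline; the checkpoint path is an illustrative placeholder, not a file shipped with this Space.

# Sketch only: override the default reader checkpoint before constructing DrQA.
from drqa import pipeline

pipeline.set_default('reader_model', '/path/to/another/reader.mdl')  # illustrative path
print(pipeline.DEFAULTS['ranker'])  # still TfidfDocRanker unless also overridden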
drqa/pipeline/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (645 Bytes).

drqa/pipeline/__pycache__/drqa.cpython-38.pyc
ADDED
Binary file (7.78 kB).
drqa/pipeline/drqa.py
ADDED
@@ -0,0 +1,312 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Full DrQA pipeline."""

import torch
import regex
import heapq
import math
import time
import logging

from multiprocessing import Pool as ProcessPool
from multiprocessing.util import Finalize

from ..reader.vector import batchify
from ..reader.data import ReaderDataset, SortedBatchSampler
from .. import reader
from .. import tokenizers
from . import DEFAULTS

logger = logging.getLogger(__name__)


# ------------------------------------------------------------------------------
# Multiprocessing functions to fetch and tokenize text
# ------------------------------------------------------------------------------

PROCESS_TOK = None
PROCESS_DB = None
PROCESS_CANDS = None


def init(tokenizer_class, tokenizer_opts, db_class, db_opts, candidates=None):
    global PROCESS_TOK, PROCESS_DB, PROCESS_CANDS
    PROCESS_TOK = tokenizer_class(**tokenizer_opts)
    Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)
    PROCESS_DB = db_class(**db_opts)
    Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100)
    PROCESS_CANDS = candidates


def fetch_text(doc_id):
    global PROCESS_DB
    return PROCESS_DB.get_doc_text(doc_id)


def tokenize_text(text):
    global PROCESS_TOK
    return PROCESS_TOK.tokenize(text)


# ------------------------------------------------------------------------------
# Main DrQA pipeline
# ------------------------------------------------------------------------------


class DrQA(object):
    # Target size for squashing short paragraphs together.
    # 0 = read every paragraph independently
    # infty = read all paragraphs together
    GROUP_LENGTH = 0

    def __init__(
            self,
            reader_model=None,
            embedding_file=None,
            tokenizer=None,
            fixed_candidates=None,
            batch_size=128,
            cuda=True,
            data_parallel=False,
            max_loaders=5,
            num_workers=None,
            db_config=None,
            ranker_config=None
    ):
        """Initialize the pipeline.

        Args:
            reader_model: model file from which to load the DocReader.
            embedding_file: if given, will expand DocReader dictionary to use
                all available pretrained embeddings.
            tokenizer: string option to specify tokenizer used on docs.
            fixed_candidates: if given, all predictions will be constrained to
                the set of candidates contained in the file. One entry per line.
            batch_size: batch size when processing paragraphs.
            cuda: whether to use the gpu.
            data_parallel: whether to use multiple gpus.
            max_loaders: max number of async data loading workers when reading.
                (default is fine).
            num_workers: number of parallel CPU processes to use for tokenizing
                and post processing results.
            db_config: config for doc db.
            ranker_config: config for ranker.
        """
        self.batch_size = batch_size
        self.max_loaders = max_loaders
        self.fixed_candidates = fixed_candidates is not None
        self.cuda = cuda

        logger.info('Initializing document ranker...')
        ranker_config = ranker_config or {}
        ranker_class = ranker_config.get('class', DEFAULTS['ranker'])
        ranker_opts = ranker_config.get('options', {})
        self.ranker = ranker_class(**ranker_opts)

        logger.info('Initializing document reader...')
        reader_model = reader_model or DEFAULTS['reader_model']
        self.reader = reader.DocReader.load(reader_model, normalize=False)
        if embedding_file:
            logger.info('Expanding dictionary...')
            words = reader.utils.index_embedding_words(embedding_file)
            added = self.reader.expand_dictionary(words)
            self.reader.load_embeddings(added, embedding_file)
        if cuda:
            self.reader.cuda()
        if data_parallel:
            self.reader.parallelize()

        if not tokenizer:
            tok_class = DEFAULTS['tokenizer']
        else:
            tok_class = tokenizers.get_class(tokenizer)
        annotators = tokenizers.get_annotators_for_model(self.reader)
        tok_opts = {'annotators': annotators}

        # ElasticSearch is also used as backend if used as ranker
        if hasattr(self.ranker, 'es'):
            db_config = ranker_config
            db_class = ranker_class
            db_opts = ranker_opts
        else:
            db_config = db_config or {}
            db_class = db_config.get('class', DEFAULTS['db'])
            db_opts = db_config.get('options', {})

        logger.info('Initializing tokenizers and document retrievers...')
        self.num_workers = num_workers
        self.processes = ProcessPool(
            num_workers,
            initializer=init,
            initargs=(tok_class, tok_opts, db_class, db_opts, fixed_candidates)
        )

    def _split_doc(self, doc):
        """Given a doc, split it into chunks (by paragraph)."""
        curr = []
        curr_len = 0
        for split in regex.split(r'\n+', doc):
            split = split.strip()
            if len(split) == 0:
                continue
            # Maybe group paragraphs together until we hit a length limit
            if len(curr) > 0 and curr_len + len(split) > self.GROUP_LENGTH:
                yield ' '.join(curr)
                curr = []
                curr_len = 0
            curr.append(split)
            curr_len += len(split)
        if len(curr) > 0:
            yield ' '.join(curr)

    def _get_loader(self, data, num_loaders):
        """Return a pytorch data iterator for provided examples."""
        dataset = ReaderDataset(data, self.reader)
        sampler = SortedBatchSampler(
            dataset.lengths(),
            self.batch_size,
            shuffle=False
        )
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=num_loaders,
            collate_fn=batchify,
            pin_memory=self.cuda,
        )
        return loader

    def process(self, query, candidates=None, top_n=1, n_docs=5,
                return_context=False):
        """Run a single query."""
        predictions = self.process_batch(
            [query], [candidates] if candidates else None,
            top_n, n_docs, return_context
        )
        return predictions[0]

    def process_batch(self, queries, candidates=None, top_n=1, n_docs=5,
                      return_context=False):
        """Run a batch of queries (more efficient)."""
        t0 = time.time()
        logger.info('Processing %d queries...' % len(queries))
        logger.info('Retrieving top %d docs...' % n_docs)

        # Rank documents for queries.
        if len(queries) == 1:
            ranked = [self.ranker.closest_docs(queries[0], k=n_docs)]
        else:
            ranked = self.ranker.batch_closest_docs(
                queries, k=n_docs, num_workers=self.num_workers
            )
        all_docids, all_doc_scores = zip(*ranked)

        # Flatten document ids and retrieve text from database.
        # We remove duplicates for processing efficiency.
        flat_docids = list({d for docids in all_docids for d in docids})
        did2didx = {did: didx for didx, did in enumerate(flat_docids)}
        doc_texts = self.processes.map(fetch_text, flat_docids)

        # Split and flatten documents. Maintain a mapping from doc (index in
        # flat list) to split (index in flat list).
        flat_splits = []
        didx2sidx = []
        for text in doc_texts:
            splits = self._split_doc(text)
            didx2sidx.append([len(flat_splits), -1])
            for split in splits:
                flat_splits.append(split)
            didx2sidx[-1][1] = len(flat_splits)

        # Push through the tokenizers as fast as possible.
        q_tokens = self.processes.map_async(tokenize_text, queries)
        s_tokens = self.processes.map_async(tokenize_text, flat_splits)
        q_tokens = q_tokens.get()
        s_tokens = s_tokens.get()

        # Group into structured example inputs. Examples' ids represent
        # mappings to their question, document, and split ids.
        examples = []
        for qidx in range(len(queries)):
            for rel_didx, did in enumerate(all_docids[qidx]):
                start, end = didx2sidx[did2didx[did]]
                for sidx in range(start, end):
                    if (len(q_tokens[qidx].words()) > 0 and
                            len(s_tokens[sidx].words()) > 0):
                        examples.append({
                            'id': (qidx, rel_didx, sidx),
                            'question': q_tokens[qidx].words(),
                            'qlemma': q_tokens[qidx].lemmas(),
                            'document': s_tokens[sidx].words(),
                            'lemma': s_tokens[sidx].lemmas(),
                            'pos': s_tokens[sidx].pos(),
                            'ner': s_tokens[sidx].entities(),
                        })

        logger.info('Reading %d paragraphs...' % len(examples))

        # Push all examples through the document reader.
        # We decode argmax start/end indices asynchronously on CPU.
        result_handles = []
        num_loaders = min(self.max_loaders, math.floor(len(examples) / 1e3))
        for batch in self._get_loader(examples, num_loaders):
            if candidates or self.fixed_candidates:
                batch_cands = []
                for ex_id in batch[-1]:
                    batch_cands.append({
                        'input': s_tokens[ex_id[2]],
                        'cands': candidates[ex_id[0]] if candidates else None
                    })
                handle = self.reader.predict(
                    batch, batch_cands, async_pool=self.processes
                )
            else:
                handle = self.reader.predict(batch, async_pool=self.processes)
            result_handles.append((handle, batch[-1], batch[0].size(0)))

        # Iterate through the predictions, and maintain priority queues for
        # top scored answers for each question in the batch.
        queues = [[] for _ in range(len(queries))]
        for result, ex_ids, batch_size in result_handles:
            s, e, score = result.get()
            for i in range(batch_size):
                # We take the top prediction per split.
                if len(score[i]) > 0:
                    item = (score[i][0], ex_ids[i], s[i][0], e[i][0])
                    queue = queues[ex_ids[i][0]]
                    if len(queue) < top_n:
                        heapq.heappush(queue, item)
                    else:
                        heapq.heappushpop(queue, item)

        # Arrange final top prediction data.
        all_predictions = []
        for queue in queues:
            predictions = []
            while len(queue) > 0:
                score, (qidx, rel_didx, sidx), s, e = heapq.heappop(queue)
                prediction = {
                    'doc_id': all_docids[qidx][rel_didx],
                    'span': s_tokens[sidx].slice(s, e + 1).untokenize(),
                    'doc_score': float(all_doc_scores[qidx][rel_didx]),
                    'span_score': float(score),
                }
                if return_context:
                    prediction['context'] = {
                        'text': s_tokens[sidx].untokenize(),
                        'start': s_tokens[sidx].offsets()[s][0],
                        'end': s_tokens[sidx].offsets()[e][1],
                    }
                predictions.append(prediction)
            all_predictions.append(predictions[-1::-1])

        logger.info('Processed %d queries in %.4f (s)' %
                    (len(queries), time.time() - t0))

        return all_predictions
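A minimal usage sketch for the `DrQA` class defined above. The checkpoint path is an illustrative placeholder, and the `'simple'` tokenizer name assumes the tokenizer registry in `drqa/tokenizers`; the ranker falls back to the `DEFAULTS['ranker']` TF-IDF ranker.

# Sketch only: run one question end-to-end through the pipeline.
from drqa.pipeline import DrQA

drqa = DrQA(
    reader_model='data/reader/multitask.mdl',  # illustrative path
    tokenizer='simple',                        # assumed name resolved by tokenizers.get_class()
    cuda=False,
)
predictions = drqa.process('Who wrote Hamlet?', top_n=1, n_docs=5)
for p in predictions:
    print(p['doc_id'], p['span'], p['doc_score'], p['span_score'])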
drqa/reader/__init__.py
ADDED
@@ -0,0 +1,28 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from ..tokenizers import CoreNLPTokenizer
from .. import DATA_DIR


DEFAULTS = {
    'tokenizer': CoreNLPTokenizer,
    'model': os.path.join(DATA_DIR, 'reader/single.mdl'),
}


def set_default(key, value):
    global DEFAULTS
    DEFAULTS[key] = value

from .model import DocReader
from .predictor import Predictor
from . import config
from . import vector
from . import data
from . import utils
drqa/reader/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (697 Bytes).

drqa/reader/__pycache__/config.cpython-38.pyc
ADDED
Binary file (4.52 kB).

drqa/reader/__pycache__/data.cpython-38.pyc
ADDED
Binary file (4.99 kB).

drqa/reader/__pycache__/layers.cpython-38.pyc
ADDED
Binary file (7.73 kB).

drqa/reader/__pycache__/model.cpython-38.pyc
ADDED
Binary file (13.3 kB).

drqa/reader/__pycache__/predictor.cpython-38.pyc
ADDED
Binary file (4.22 kB).

drqa/reader/__pycache__/rnn_reader.cpython-38.pyc
ADDED
Binary file (2.9 kB).

drqa/reader/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (8.67 kB).

drqa/reader/__pycache__/vector.cpython-38.pyc
ADDED
Binary file (4.71 kB).
drqa/reader/config.py
ADDED
@@ -0,0 +1,128 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Model architecture/optimization options for DrQA document reader."""

import argparse
import logging

logger = logging.getLogger(__name__)

# Index of arguments concerning the core model architecture
MODEL_ARCHITECTURE = {
    'model_type', 'embedding_dim', 'hidden_size', 'doc_layers',
    'question_layers', 'rnn_type', 'concat_rnn_layers', 'question_merge',
    'use_qemb', 'use_in_question', 'use_pos', 'use_ner', 'use_lemma', 'use_tf'
}

# Index of arguments concerning the model optimizer/training
MODEL_OPTIMIZER = {
    'fix_embeddings', 'optimizer', 'learning_rate', 'momentum', 'weight_decay',
    'rnn_padding', 'dropout_rnn', 'dropout_rnn_output', 'dropout_emb',
    'max_len', 'grad_clipping', 'tune_partial'
}


def str2bool(v):
    return v.lower() in ('yes', 'true', 't', '1', 'y')


def add_model_args(parser):
    parser.register('type', 'bool', str2bool)

    # Model architecture
    model = parser.add_argument_group('DrQA Reader Model Architecture')
    model.add_argument('--model-type', type=str, default='rnn',
                       help='Model architecture type')
    model.add_argument('--embedding-dim', type=int, default=300,
                       help='Embedding size if embedding_file is not given')
    model.add_argument('--hidden-size', type=int, default=128,
                       help='Hidden size of RNN units')
    model.add_argument('--doc-layers', type=int, default=3,
                       help='Number of encoding layers for document')
    model.add_argument('--question-layers', type=int, default=3,
                       help='Number of encoding layers for question')
    model.add_argument('--rnn-type', type=str, default='lstm',
                       help='RNN type: LSTM, GRU, or RNN')

    # Model specific details
    detail = parser.add_argument_group('DrQA Reader Model Details')
    detail.add_argument('--concat-rnn-layers', type='bool', default=True,
                        help='Combine hidden states from each encoding layer')
    detail.add_argument('--question-merge', type=str, default='self_attn',
                        help='The way of computing the question representation')
    detail.add_argument('--use-qemb', type='bool', default=True,
                        help='Whether to use weighted question embeddings')
    detail.add_argument('--use-in-question', type='bool', default=True,
                        help='Whether to use in_question_* features')
    detail.add_argument('--use-pos', type='bool', default=True,
                        help='Whether to use pos features')
    detail.add_argument('--use-ner', type='bool', default=True,
                        help='Whether to use ner features')
    detail.add_argument('--use-lemma', type='bool', default=True,
                        help='Whether to use lemma features')
    detail.add_argument('--use-tf', type='bool', default=True,
                        help='Whether to use term frequency features')

    # Optimization details
    optim = parser.add_argument_group('DrQA Reader Optimization')
    optim.add_argument('--dropout-emb', type=float, default=0.4,
                       help='Dropout rate for word embeddings')
    optim.add_argument('--dropout-rnn', type=float, default=0.4,
                       help='Dropout rate for RNN states')
    optim.add_argument('--dropout-rnn-output', type='bool', default=True,
                       help='Whether to dropout the RNN output')
    optim.add_argument('--optimizer', type=str, default='adamax',
                       help='Optimizer: sgd or adamax')
    optim.add_argument('--learning-rate', type=float, default=0.1,
                       help='Learning rate for SGD only')
    optim.add_argument('--grad-clipping', type=float, default=10,
                       help='Gradient clipping')
    optim.add_argument('--weight-decay', type=float, default=0,
                       help='Weight decay factor')
    optim.add_argument('--momentum', type=float, default=0,
                       help='Momentum factor')
    optim.add_argument('--fix-embeddings', type='bool', default=True,
                       help='Keep word embeddings fixed (use pretrained)')
    optim.add_argument('--tune-partial', type=int, default=0,
                       help='Backprop through only the top N question words')
    optim.add_argument('--rnn-padding', type='bool', default=False,
                       help='Explicitly account for padding in RNN encoding')
    optim.add_argument('--max-len', type=int, default=15,
                       help='The max span allowed during decoding')


def get_model_args(args):
    """Filter args for model ones.

    From an args Namespace, return a new Namespace with *only* the args specific
    to the model architecture or optimization. (i.e. the ones defined here.)
    """
    global MODEL_ARCHITECTURE, MODEL_OPTIMIZER
    required_args = MODEL_ARCHITECTURE | MODEL_OPTIMIZER
    arg_values = {k: v for k, v in vars(args).items() if k in required_args}
    return argparse.Namespace(**arg_values)


def override_model_args(old_args, new_args):
    """Set args to new parameters.

    Decide which model args to keep and which to override when resolving a set
    of saved args and new args.

    We keep the new optimization settings, but leave the model architecture alone.
    """
    global MODEL_OPTIMIZER
    old_args, new_args = vars(old_args), vars(new_args)
    for k in old_args.keys():
        if k in new_args and old_args[k] != new_args[k]:
            if k in MODEL_OPTIMIZER:
                logger.info('Overriding saved %s: %s --> %s' %
                            (k, old_args[k], new_args[k]))
                old_args[k] = new_args[k]
            else:
                logger.info('Keeping saved %s: %s' % (k, old_args[k]))
    return argparse.Namespace(**old_args)
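A small sketch of how `add_model_args` and `get_model_args` above could be wired into a training script's argument parser (the argument values are arbitrary examples):

# Sketch only: register the reader options on a parser and filter them back out.
import argparse
from drqa.reader import config

parser = argparse.ArgumentParser('DrQA reader training (sketch)')
config.add_model_args(parser)
args = parser.parse_args(['--hidden-size', '96', '--optimizer', 'sgd'])

model_args = config.get_model_args(args)   # keeps only architecture/optimizer args
print(model_args.hidden_size, model_args.optimizer)  # 96 sgd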
drqa/reader/data.py
ADDED
@@ -0,0 +1,131 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Data processing/loading helpers."""

import numpy as np
import logging
import unicodedata

from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler
from .vector import vectorize

logger = logging.getLogger(__name__)


# ------------------------------------------------------------------------------
# Dictionary class for tokens.
# ------------------------------------------------------------------------------


class Dictionary(object):
    NULL = '<NULL>'
    UNK = '<UNK>'
    START = 2

    @staticmethod
    def normalize(token):
        return unicodedata.normalize('NFD', token)

    def __init__(self):
        self.tok2ind = {self.NULL: 0, self.UNK: 1}
        self.ind2tok = {0: self.NULL, 1: self.UNK}

    def __len__(self):
        return len(self.tok2ind)

    def __iter__(self):
        return iter(self.tok2ind)

    def __contains__(self, key):
        if type(key) == int:
            return key in self.ind2tok
        elif type(key) == str:
            return self.normalize(key) in self.tok2ind

    def __getitem__(self, key):
        if type(key) == int:
            return self.ind2tok.get(key, self.UNK)
        if type(key) == str:
            return self.tok2ind.get(self.normalize(key),
                                    self.tok2ind.get(self.UNK))

    def __setitem__(self, key, item):
        if type(key) == int and type(item) == str:
            self.ind2tok[key] = item
        elif type(key) == str and type(item) == int:
            self.tok2ind[key] = item
        else:
            raise RuntimeError('Invalid (key, item) types.')

    def add(self, token):
        token = self.normalize(token)
        if token not in self.tok2ind:
            index = len(self.tok2ind)
            self.tok2ind[token] = index
            self.ind2tok[index] = token

    def tokens(self):
        """Get dictionary tokens.

        Return all the words indexed by this dictionary, except for special
        tokens.
        """
        tokens = [k for k in self.tok2ind.keys()
                  if k not in {'<NULL>', '<UNK>'}]
        return tokens


# ------------------------------------------------------------------------------
# PyTorch dataset class for SQuAD (and SQuAD-like) data.
# ------------------------------------------------------------------------------


class ReaderDataset(Dataset):

    def __init__(self, examples, model, single_answer=False):
        self.model = model
        self.examples = examples
        self.single_answer = single_answer

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        return vectorize(self.examples[index], self.model, self.single_answer)

    def lengths(self):
        return [(len(ex['document']), len(ex['question']))
                for ex in self.examples]


# ------------------------------------------------------------------------------
# PyTorch sampler returning batches of sorted lengths (by doc and question).
# ------------------------------------------------------------------------------


class SortedBatchSampler(Sampler):

    def __init__(self, lengths, batch_size, shuffle=True):
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        lengths = np.array(
            [(-l[0], -l[1], np.random.random()) for l in self.lengths],
            dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float_)]
        )
        indices = np.argsort(lengths, order=('l1', 'l2', 'rand'))
        batches = [indices[i:i + self.batch_size]
                   for i in range(0, len(indices), self.batch_size)]
        if self.shuffle:
            np.random.shuffle(batches)
        return iter([i for batch in batches for i in batch])

    def __len__(self):
        return len(self.lengths)
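A quick sketch of the `Dictionary` behaviour defined above (assuming the package and its dependencies are importable):

# Sketch only: the two special tokens occupy indices 0 and 1; added tokens follow.
from drqa.reader.data import Dictionary

word_dict = Dictionary()
for tok in ['claim', 'evidence', 'claim']:
    word_dict.add(tok)           # duplicates are ignored

print(len(word_dict))            # 4: <NULL>, <UNK>, 'claim', 'evidence'
print(word_dict['claim'])        # 2, the index assigned to 'claim'
print(word_dict[999])            # '<UNK>' for an unknown index
print(word_dict.tokens())        # ['claim', 'evidence'] (special tokens excluded)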
drqa/reader/layers.py
ADDED
@@ -0,0 +1,311 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Definitions of model layers/NN modules"""

import torch
import torch.nn as nn
import torch.nn.functional as F


# ------------------------------------------------------------------------------
# Modules
# ------------------------------------------------------------------------------


class StackedBRNN(nn.Module):
    """Stacked Bi-directional RNNs.

    Differs from standard PyTorch library in that it has the option to save
    and concat the hidden states between layers. (i.e. the output hidden size
    for each sequence input is num_layers * hidden_size).
    """

    def __init__(self, input_size, hidden_size, num_layers,
                 dropout_rate=0, dropout_output=False, rnn_type=nn.LSTM,
                 concat_layers=False, padding=False):
        super(StackedBRNN, self).__init__()
        self.padding = padding
        self.dropout_output = dropout_output
        self.dropout_rate = dropout_rate
        self.num_layers = num_layers
        self.concat_layers = concat_layers
        self.rnns = nn.ModuleList()
        for i in range(num_layers):
            input_size = input_size if i == 0 else 2 * hidden_size
            self.rnns.append(rnn_type(input_size, hidden_size,
                                      num_layers=1,
                                      bidirectional=True))

    def forward(self, x, x_mask):
        """Encode either padded or non-padded sequences.

        Can choose to either handle or ignore variable length sequences.
        Always handle padding in eval.

        Args:
            x: batch * len * hdim
            x_mask: batch * len (1 for padding, 0 for true)
        Output:
            x_encoded: batch * len * hdim_encoded
        """
        if x_mask.data.sum() == 0:
            # No padding necessary.
            output = self._forward_unpadded(x, x_mask)
        elif self.padding or not self.training:
            # Pad if we care or if its during eval.
            output = self._forward_padded(x, x_mask)
        else:
            # We don't care.
            output = self._forward_unpadded(x, x_mask)

        return output.contiguous()

    def _forward_unpadded(self, x, x_mask):
        """Faster encoding that ignores any padding."""
        # Transpose batch and sequence dims
        x = x.transpose(0, 1)

        # Encode all layers
        outputs = [x]
        for i in range(self.num_layers):
            rnn_input = outputs[-1]

            # Apply dropout to hidden input
            if self.dropout_rate > 0:
                rnn_input = F.dropout(rnn_input,
                                      p=self.dropout_rate,
                                      training=self.training)
            # Forward
            rnn_output = self.rnns[i](rnn_input)[0]
            outputs.append(rnn_output)

        # Concat hidden layers
        if self.concat_layers:
            output = torch.cat(outputs[1:], 2)
        else:
            output = outputs[-1]

        # Transpose back
        output = output.transpose(0, 1)

        # Dropout on output layer
        if self.dropout_output and self.dropout_rate > 0:
            output = F.dropout(output,
                               p=self.dropout_rate,
                               training=self.training)
        return output

    def _forward_padded(self, x, x_mask):
        """Slower (significantly), but more precise, encoding that handles
        padding.
        """
        # Compute sorted sequence lengths
        lengths = x_mask.data.eq(0).long().sum(1).squeeze()
        _, idx_sort = torch.sort(lengths, dim=0, descending=True)
        _, idx_unsort = torch.sort(idx_sort, dim=0)
        lengths = list(lengths[idx_sort])

        # Sort x
        x = x.index_select(0, idx_sort)

        # Transpose batch and sequence dims
        x = x.transpose(0, 1)

        # Pack it up
        rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths)

        # Encode all layers
        outputs = [rnn_input]
        for i in range(self.num_layers):
            rnn_input = outputs[-1]

            # Apply dropout to input
            if self.dropout_rate > 0:
                dropout_input = F.dropout(rnn_input.data,
                                          p=self.dropout_rate,
                                          training=self.training)
                rnn_input = nn.utils.rnn.PackedSequence(dropout_input,
                                                        rnn_input.batch_sizes)
            outputs.append(self.rnns[i](rnn_input)[0])

        # Unpack everything
        for i, o in enumerate(outputs[1:], 1):
            outputs[i] = nn.utils.rnn.pad_packed_sequence(o)[0]

        # Concat hidden layers or take final
        if self.concat_layers:
            output = torch.cat(outputs[1:], 2)
        else:
            output = outputs[-1]

        # Transpose and unsort
        output = output.transpose(0, 1)
        output = output.index_select(0, idx_unsort)

        # Pad up to original batch sequence length
        if output.size(1) != x_mask.size(1):
            padding = torch.zeros(output.size(0),
                                  x_mask.size(1) - output.size(1),
                                  output.size(2)).type(output.data.type())
            output = torch.cat([output, padding], 1)

        # Dropout on output layer
        if self.dropout_output and self.dropout_rate > 0:
            output = F.dropout(output,
                               p=self.dropout_rate,
                               training=self.training)
        return output


class SeqAttnMatch(nn.Module):
    """Given sequences X and Y, match sequence Y to each element in X.

    * o_i = sum(alpha_j * y_j) for i in X
    * alpha_j = softmax(y_j * x_i)
    """

    def __init__(self, input_size, identity=False):
        super(SeqAttnMatch, self).__init__()
        if not identity:
            self.linear = nn.Linear(input_size, input_size)
        else:
            self.linear = None

    def forward(self, x, y, y_mask):
        """
        Args:
            x: batch * len1 * hdim
            y: batch * len2 * hdim
            y_mask: batch * len2 (1 for padding, 0 for true)
        Output:
            matched_seq: batch * len1 * hdim
        """
        # Project vectors
        if self.linear:
            x_proj = self.linear(x.view(-1, x.size(2))).view(x.size())
            x_proj = F.relu(x_proj)
            y_proj = self.linear(y.view(-1, y.size(2))).view(y.size())
            y_proj = F.relu(y_proj)
        else:
            x_proj = x
            y_proj = y

        # Compute scores
        scores = x_proj.bmm(y_proj.transpose(2, 1))

        # Mask padding
        y_mask = y_mask.unsqueeze(1).expand(scores.size())
        scores.data.masked_fill_(y_mask.data, -float('inf'))

        # Normalize with softmax
        alpha_flat = F.softmax(scores.view(-1, y.size(1)), dim=-1)
        alpha = alpha_flat.view(-1, x.size(1), y.size(1))

        # Take weighted average
        matched_seq = alpha.bmm(y)
        return matched_seq


class BilinearSeqAttn(nn.Module):
    """A bilinear attention layer over a sequence X w.r.t y:

    * o_i = softmax(x_i'Wy) for x_i in X.

    Optionally don't normalize output weights.
    """

    def __init__(self, x_size, y_size, identity=False, normalize=True):
        super(BilinearSeqAttn, self).__init__()
        self.normalize = normalize

        # If identity is true, we just use a dot product without transformation.
        if not identity:
            self.linear = nn.Linear(y_size, x_size)
        else:
            self.linear = None

    def forward(self, x, y, x_mask):
        """
        Args:
            x: batch * len * hdim1
            y: batch * hdim2
            x_mask: batch * len (1 for padding, 0 for true)
        Output:
            alpha = batch * len
        """
        Wy = self.linear(y) if self.linear is not None else y
        xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2)
        xWy.data.masked_fill_(x_mask.data, -float('inf'))
        if self.normalize:
            if self.training:
                # In training we output log-softmax for NLL
                alpha = F.log_softmax(xWy, dim=-1)
            else:
                # ...Otherwise 0-1 probabilities
                alpha = F.softmax(xWy, dim=-1)
        else:
            alpha = xWy.exp()
        return alpha


class LinearSeqAttn(nn.Module):
    """Self attention over a sequence:

    * o_i = softmax(Wx_i) for x_i in X.
    """

    def __init__(self, input_size):
        super(LinearSeqAttn, self).__init__()
        self.linear = nn.Linear(input_size, 1)

    def forward(self, x, x_mask):
        """
        Args:
            x: batch * len * hdim
            x_mask: batch * len (1 for padding, 0 for true)
        Output:
            alpha: batch * len
        """
        x_flat = x.view(-1, x.size(-1))
        scores = self.linear(x_flat).view(x.size(0), x.size(1))
        scores.data.masked_fill_(x_mask.data, -float('inf'))
        alpha = F.softmax(scores, dim=-1)
        return alpha


# ------------------------------------------------------------------------------
# Functional
# ------------------------------------------------------------------------------


def uniform_weights(x, x_mask):
    """Return uniform weights over non-masked x (a sequence of vectors).

    Args:
        x: batch * len * hdim
        x_mask: batch * len (1 for padding, 0 for true)
    Output:
        x_avg: batch * hdim
    """
    alpha = torch.ones(x.size(0), x.size(1))
    if x.data.is_cuda:
        alpha = alpha.cuda()
    alpha = alpha * x_mask.eq(0).float()
    alpha = alpha / alpha.sum(1).expand(alpha.size())
    return alpha


def weighted_avg(x, weights):
    """Return a weighted average of x (a sequence of vectors).

    Args:
        x: batch * len * hdim
        weights: batch * len, sum(dim = 1) = 1
    Output:
        x_avg: batch * hdim
    """
    return weights.unsqueeze(1).bmm(x).squeeze(1)
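A minimal shape check for `StackedBRNN` as defined above (all sizes are arbitrary, and no padding is used so the fast unpadded path is taken):

# Sketch only: with concat_layers=True the output width is num_layers * 2 * hidden_size.
import torch
from drqa.reader.layers import StackedBRNN

brnn = StackedBRNN(input_size=300, hidden_size=128, num_layers=3,
                   concat_layers=True)
x = torch.randn(4, 20, 300)                      # batch * len * hdim
x_mask = torch.zeros(4, 20, dtype=torch.bool)    # 0 = real token, 1 = padding
out = brnn(x, x_mask)
print(out.shape)                                 # torch.Size([4, 20, 768]) = 3 layers * 2 dirs * 128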
drqa/reader/model.py
ADDED
@@ -0,0 +1,482 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""DrQA Document Reader model"""

import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import logging
import copy

from .config import override_model_args
from .rnn_reader import RnnDocReader

logger = logging.getLogger(__name__)


class DocReader(object):
    """High level model that handles initializing the underlying network
    architecture, saving, updating examples, and predicting examples.
    """

    # --------------------------------------------------------------------------
    # Initialization
    # --------------------------------------------------------------------------

    def __init__(self, args, word_dict, feature_dict,
                 state_dict=None, normalize=True):
        # Book-keeping.
        self.args = args
        self.word_dict = word_dict
        self.args.vocab_size = len(word_dict)
        self.feature_dict = feature_dict
        self.args.num_features = len(feature_dict)
        self.updates = 0
        self.use_cuda = False
        self.parallel = False

        # Building network. If normalize is false, scores are not normalized
        # 0-1 per paragraph (no softmax).
        if args.model_type == 'rnn':
            self.network = RnnDocReader(args, normalize)
        else:
            raise RuntimeError('Unsupported model: %s' % args.model_type)

        # Load saved state
        if state_dict:
            # Load buffer separately
            if 'fixed_embedding' in state_dict:
                fixed_embedding = state_dict.pop('fixed_embedding')
                self.network.load_state_dict(state_dict)
                self.network.register_buffer('fixed_embedding', fixed_embedding)
            else:
                self.network.load_state_dict(state_dict)

    def expand_dictionary(self, words):
        """Add words to the DocReader dictionary if they do not exist. The
        underlying embedding matrix is also expanded (with random embeddings).

        Args:
            words: iterable of tokens to add to the dictionary.
        Output:
            added: set of tokens that were added.
        """
        to_add = {self.word_dict.normalize(w) for w in words
                  if w not in self.word_dict}

        # Add words to dictionary and expand embedding layer
        if len(to_add) > 0:
            logger.info('Adding %d new words to dictionary...' % len(to_add))
            for w in to_add:
                self.word_dict.add(w)
            self.args.vocab_size = len(self.word_dict)
            logger.info('New vocab size: %d' % len(self.word_dict))

            old_embedding = self.network.embedding.weight.data
            self.network.embedding = torch.nn.Embedding(self.args.vocab_size,
                                                         self.args.embedding_dim,
                                                         padding_idx=0)
            new_embedding = self.network.embedding.weight.data
            new_embedding[:old_embedding.size(0)] = old_embedding

        # Return added words
        return to_add

    def load_embeddings(self, words, embedding_file):
        """Load pretrained embeddings for a given list of words, if they exist.

        Args:
            words: iterable of tokens. Only those that are indexed in the
                dictionary are kept.
            embedding_file: path to text file of embeddings, space separated.
        """
        words = {w for w in words if w in self.word_dict}
        logger.info('Loading pre-trained embeddings for %d words from %s' %
                    (len(words), embedding_file))
        embedding = self.network.embedding.weight.data

        # When normalized, some words are duplicated. (Average the embeddings).
        vec_counts = {}
        with open(embedding_file) as f:
            # Skip first line if of form count/dim.
            line = f.readline().rstrip().split(' ')
            if len(line) != 2:
                f.seek(0)
            for line in f:
                parsed = line.rstrip().split(' ')
                assert(len(parsed) == embedding.size(1) + 1)
                w = self.word_dict.normalize(parsed[0])
                if w in words:
                    vec = torch.Tensor([float(i) for i in parsed[1:]])
                    if w not in vec_counts:
                        vec_counts[w] = 1
                        embedding[self.word_dict[w]].copy_(vec)
                    else:
                        logging.warning(
                            'WARN: Duplicate embedding found for %s' % w
                        )
                        vec_counts[w] = vec_counts[w] + 1
                        embedding[self.word_dict[w]].add_(vec)

        for w, c in vec_counts.items():
            embedding[self.word_dict[w]].div_(c)

        logger.info('Loaded %d embeddings (%.2f%%)' %
                    (len(vec_counts), 100 * len(vec_counts) / len(words)))

    def tune_embeddings(self, words):
        """Unfix the embeddings of a list of words. This is only relevant if
        only some of the embeddings are being tuned (tune_partial = N).

        Shuffles the N specified words to the front of the dictionary, and saves
        the original vectors of the other N + 1:vocab words in a fixed buffer.

        Args:
            words: iterable of tokens contained in dictionary.
        """
        words = {w for w in words if w in self.word_dict}

        if len(words) == 0:
            logger.warning('Tried to tune embeddings, but no words given!')
            return

        if len(words) == len(self.word_dict):
            logger.warning('Tuning ALL embeddings in dictionary')
            return

        # Shuffle words and vectors
        embedding = self.network.embedding.weight.data
        for idx, swap_word in enumerate(words, self.word_dict.START):
            # Get current word + embedding for this index
            curr_word = self.word_dict[idx]
            curr_emb = embedding[idx].clone()
            old_idx = self.word_dict[swap_word]

            # Swap embeddings + dictionary indices
            embedding[idx].copy_(embedding[old_idx])
            embedding[old_idx].copy_(curr_emb)
            self.word_dict[swap_word] = idx
            self.word_dict[idx] = swap_word
            self.word_dict[curr_word] = old_idx
            self.word_dict[old_idx] = curr_word

        # Save the original, fixed embeddings
        self.network.register_buffer(
            'fixed_embedding', embedding[idx + 1:].clone()
        )

    def init_optimizer(self, state_dict=None):
        """Initialize an optimizer for the free parameters of the network.

        Args:
            state_dict: network parameters
        """
        if self.args.fix_embeddings:
            for p in self.network.embedding.parameters():
                p.requires_grad = False
        parameters = [p for p in self.network.parameters() if p.requires_grad]
        if self.args.optimizer == 'sgd':
            self.optimizer = optim.SGD(parameters, self.args.learning_rate,
                                       momentum=self.args.momentum,
                                       weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'adamax':
            self.optimizer = optim.Adamax(parameters,
                                          weight_decay=self.args.weight_decay)
        else:
            raise RuntimeError('Unsupported optimizer: %s' %
                               self.args.optimizer)

    # --------------------------------------------------------------------------
    # Learning
    # --------------------------------------------------------------------------

    def update(self, ex):
        """Forward a batch of examples; step the optimizer to update weights."""
        if not self.optimizer:
            raise RuntimeError('No optimizer set.')

        # Train mode
        self.network.train()

        # Transfer to GPU
        if self.use_cuda:
            inputs = [e if e is None else e.cuda(non_blocking=True)
                      for e in ex[:5]]
            target_s = ex[5].cuda(non_blocking=True)
            target_e = ex[6].cuda(non_blocking=True)
        else:
            inputs = [e if e is None else e for e in ex[:5]]
            target_s = ex[5]
            target_e = ex[6]

        # Run forward
        score_s, score_e = self.network(*inputs)

        # Compute loss and accuracies
        loss = F.nll_loss(score_s, target_s) + F.nll_loss(score_e, target_e)

        # Clear gradients and run backward
        self.optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(self.network.parameters(),
                                       self.args.grad_clipping)

        # Update parameters
        self.optimizer.step()
        self.updates += 1

        # Reset any partially fixed parameters (e.g. rare words)
        self.reset_parameters()

        return loss.item(), ex[0].size(0)

    def reset_parameters(self):
        """Reset any partially fixed parameters to original states."""

        # Reset fixed embeddings to original value
        if self.args.tune_partial > 0:
            if self.parallel:
                embedding = self.network.module.embedding.weight.data
                fixed_embedding = self.network.module.fixed_embedding
            else:
                embedding = self.network.embedding.weight.data
                fixed_embedding = self.network.fixed_embedding

            # Embeddings to fix are the last indices
            offset = embedding.size(0) - fixed_embedding.size(0)
            if offset >= 0:
                embedding[offset:] = fixed_embedding
|
256 |
+
|
257 |
+
# --------------------------------------------------------------------------
|
258 |
+
# Prediction
|
259 |
+
# --------------------------------------------------------------------------
|
260 |
+
|
261 |
+
def predict(self, ex, candidates=None, top_n=1, async_pool=None):
|
262 |
+
"""Forward a batch of examples only to get predictions.
|
263 |
+
|
264 |
+
Args:
|
265 |
+
ex: the batch
|
266 |
+
candidates: batch * variable length list of string answer options.
|
267 |
+
The model will only consider exact spans contained in this list.
|
268 |
+
top_n: Number of predictions to return per batch element.
|
269 |
+
async_pool: If provided, non-gpu post-processing will be offloaded
|
270 |
+
to this CPU process pool.
|
271 |
+
Output:
|
272 |
+
pred_s: batch * top_n predicted start indices
|
273 |
+
pred_e: batch * top_n predicted end indices
|
274 |
+
pred_score: batch * top_n prediction scores
|
275 |
+
|
276 |
+
If async_pool is given, these will be AsyncResult handles.
|
277 |
+
"""
|
278 |
+
# Eval mode
|
279 |
+
self.network.eval()
|
280 |
+
|
281 |
+
# Transfer to GPU
|
282 |
+
if self.use_cuda:
|
283 |
+
inputs = [e if e is None else e.cuda(non_blocking=True)
|
284 |
+
for e in ex[:5]]
|
285 |
+
else:
|
286 |
+
inputs = [e for e in ex[:5]]
|
287 |
+
|
288 |
+
# Run forward
|
289 |
+
with torch.no_grad():
|
290 |
+
score_s, score_e = self.network(*inputs)
|
291 |
+
|
292 |
+
# Decode predictions
|
293 |
+
score_s = score_s.data.cpu()
|
294 |
+
score_e = score_e.data.cpu()
|
295 |
+
if candidates:
|
296 |
+
args = (score_s, score_e, candidates, top_n, self.args.max_len)
|
297 |
+
if async_pool:
|
298 |
+
return async_pool.apply_async(self.decode_candidates, args)
|
299 |
+
else:
|
300 |
+
return self.decode_candidates(*args)
|
301 |
+
else:
|
302 |
+
args = (score_s, score_e, top_n, self.args.max_len)
|
303 |
+
if async_pool:
|
304 |
+
return async_pool.apply_async(self.decode, args)
|
305 |
+
else:
|
306 |
+
return self.decode(*args)
|
307 |
+
|
308 |
+
@staticmethod
|
309 |
+
def decode(score_s, score_e, top_n=1, max_len=None):
|
310 |
+
"""Take argmax of constrained score_s * score_e.
|
311 |
+
|
312 |
+
Args:
|
313 |
+
score_s: independent start predictions
|
314 |
+
score_e: independent end predictions
|
315 |
+
top_n: number of top scored pairs to take
|
316 |
+
max_len: max span length to consider
|
317 |
+
"""
|
318 |
+
pred_s = []
|
319 |
+
pred_e = []
|
320 |
+
pred_score = []
|
321 |
+
max_len = max_len or score_s.size(1)
|
322 |
+
for i in range(score_s.size(0)):
|
323 |
+
# Outer product of scores to get full p_s * p_e matrix
|
324 |
+
scores = torch.ger(score_s[i], score_e[i])
|
325 |
+
|
326 |
+
# Zero out negative length and over-length span scores
|
327 |
+
scores.triu_().tril_(max_len - 1)
|
328 |
+
|
329 |
+
# Take argmax or top n
|
330 |
+
scores = scores.numpy()
|
331 |
+
scores_flat = scores.flatten()
|
332 |
+
if top_n == 1:
|
333 |
+
idx_sort = [np.argmax(scores_flat)]
|
334 |
+
elif len(scores_flat) < top_n:
|
335 |
+
idx_sort = np.argsort(-scores_flat)
|
336 |
+
else:
|
337 |
+
idx = np.argpartition(-scores_flat, top_n)[0:top_n]
|
338 |
+
idx_sort = idx[np.argsort(-scores_flat[idx])]
|
339 |
+
s_idx, e_idx = np.unravel_index(idx_sort, scores.shape)
|
340 |
+
pred_s.append(s_idx)
|
341 |
+
pred_e.append(e_idx)
|
342 |
+
pred_score.append(scores_flat[idx_sort])
|
343 |
+
return pred_s, pred_e, pred_score
|
344 |
+
|
345 |
+
@staticmethod
|
346 |
+
def decode_candidates(score_s, score_e, candidates, top_n=1, max_len=None):
|
347 |
+
"""Take argmax of constrained score_s * score_e. Except only consider
|
348 |
+
spans that are in the candidates list.
|
349 |
+
"""
|
350 |
+
pred_s = []
|
351 |
+
pred_e = []
|
352 |
+
pred_score = []
|
353 |
+
for i in range(score_s.size(0)):
|
354 |
+
# Extract original tokens stored with candidates
|
355 |
+
tokens = candidates[i]['input']
|
356 |
+
cands = candidates[i]['cands']
|
357 |
+
|
358 |
+
if not cands:
|
359 |
+
# try getting from globals? (multiprocessing in pipeline mode)
|
360 |
+
from ..pipeline.drqa import PROCESS_CANDS
|
361 |
+
cands = PROCESS_CANDS
|
362 |
+
if not cands:
|
363 |
+
raise RuntimeError('No candidates given.')
|
364 |
+
|
365 |
+
# Score all valid candidates found in text.
|
366 |
+
# Brute force get all ngrams and compare against the candidate list.
|
367 |
+
max_len = max_len or len(tokens)
|
368 |
+
scores, s_idx, e_idx = [], [], []
|
369 |
+
for s, e in tokens.ngrams(n=max_len, as_strings=False):
|
370 |
+
span = tokens.slice(s, e).untokenize()
|
371 |
+
if span in cands or span.lower() in cands:
|
372 |
+
# Match! Record its score.
|
373 |
+
scores.append(score_s[i][s] * score_e[i][e - 1])
|
374 |
+
s_idx.append(s)
|
375 |
+
e_idx.append(e - 1)
|
376 |
+
|
377 |
+
if len(scores) == 0:
|
378 |
+
# No candidates present
|
379 |
+
pred_s.append([])
|
380 |
+
pred_e.append([])
|
381 |
+
pred_score.append([])
|
382 |
+
else:
|
383 |
+
# Rank found candidates
|
384 |
+
scores = np.array(scores)
|
385 |
+
s_idx = np.array(s_idx)
|
386 |
+
e_idx = np.array(e_idx)
|
387 |
+
|
388 |
+
idx_sort = np.argsort(-scores)[0:top_n]
|
389 |
+
pred_s.append(s_idx[idx_sort])
|
390 |
+
pred_e.append(e_idx[idx_sort])
|
391 |
+
pred_score.append(scores[idx_sort])
|
392 |
+
return pred_s, pred_e, pred_score
|
393 |
+
|
394 |
+
# --------------------------------------------------------------------------
|
395 |
+
# Saving and loading
|
396 |
+
# --------------------------------------------------------------------------
|
397 |
+
|
398 |
+
def save(self, filename):
|
399 |
+
if self.parallel:
|
400 |
+
network = self.network.module
|
401 |
+
else:
|
402 |
+
network = self.network
|
403 |
+
state_dict = copy.copy(network.state_dict())
|
404 |
+
if 'fixed_embedding' in state_dict:
|
405 |
+
state_dict.pop('fixed_embedding')
|
406 |
+
params = {
|
407 |
+
'state_dict': state_dict,
|
408 |
+
'word_dict': self.word_dict,
|
409 |
+
'feature_dict': self.feature_dict,
|
410 |
+
'args': self.args,
|
411 |
+
}
|
412 |
+
try:
|
413 |
+
torch.save(params, filename)
|
414 |
+
except BaseException:
|
415 |
+
logger.warning('WARN: Saving failed... continuing anyway.')
|
416 |
+
|
417 |
+
def checkpoint(self, filename, epoch):
|
418 |
+
if self.parallel:
|
419 |
+
network = self.network.module
|
420 |
+
else:
|
421 |
+
network = self.network
|
422 |
+
params = {
|
423 |
+
'state_dict': network.state_dict(),
|
424 |
+
'word_dict': self.word_dict,
|
425 |
+
'feature_dict': self.feature_dict,
|
426 |
+
'args': self.args,
|
427 |
+
'epoch': epoch,
|
428 |
+
'optimizer': self.optimizer.state_dict(),
|
429 |
+
}
|
430 |
+
try:
|
431 |
+
torch.save(params, filename)
|
432 |
+
except BaseException:
|
433 |
+
logger.warning('WARN: Saving failed... continuing anyway.')
|
434 |
+
|
435 |
+
@staticmethod
|
436 |
+
def load(filename, new_args=None, normalize=True):
|
437 |
+
logger.info('Loading model %s' % filename)
|
438 |
+
saved_params = torch.load(
|
439 |
+
filename, map_location=lambda storage, loc: storage
|
440 |
+
)
|
441 |
+
word_dict = saved_params['word_dict']
|
442 |
+
feature_dict = saved_params['feature_dict']
|
443 |
+
state_dict = saved_params['state_dict']
|
444 |
+
args = saved_params['args']
|
445 |
+
if new_args:
|
446 |
+
args = override_model_args(args, new_args)
|
447 |
+
return DocReader(args, word_dict, feature_dict, state_dict, normalize)
|
448 |
+
|
449 |
+
@staticmethod
|
450 |
+
def load_checkpoint(filename, normalize=True):
|
451 |
+
logger.info('Loading model %s' % filename)
|
452 |
+
saved_params = torch.load(
|
453 |
+
filename, map_location=lambda storage, loc: storage
|
454 |
+
)
|
455 |
+
word_dict = saved_params['word_dict']
|
456 |
+
feature_dict = saved_params['feature_dict']
|
457 |
+
state_dict = saved_params['state_dict']
|
458 |
+
epoch = saved_params['epoch']
|
459 |
+
optimizer = saved_params['optimizer']
|
460 |
+
args = saved_params['args']
|
461 |
+
model = DocReader(args, word_dict, feature_dict, state_dict, normalize)
|
462 |
+
model.init_optimizer(optimizer)
|
463 |
+
return model, epoch
|
464 |
+
|
465 |
+
# --------------------------------------------------------------------------
|
466 |
+
# Runtime
|
467 |
+
# --------------------------------------------------------------------------
|
468 |
+
|
469 |
+
def cuda(self):
|
470 |
+
self.use_cuda = True
|
471 |
+
self.network = self.network.cuda()
|
472 |
+
|
473 |
+
def cpu(self):
|
474 |
+
self.use_cuda = False
|
475 |
+
self.network = self.network.cpu()
|
476 |
+
|
477 |
+
def parallelize(self):
|
478 |
+
"""Use data parallel to copy the model across several gpus.
|
479 |
+
This will take all gpus visible with CUDA_VISIBLE_DEVICES.
|
480 |
+
"""
|
481 |
+
self.parallel = True
|
482 |
+
self.network = torch.nn.DataParallel(self.network)
|
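The decoding step above (DocReader.decode) is the heart of span prediction. Below is a minimal standalone sketch of the same idea with made-up scores for a 4-token passage; it is an illustration, not part of the committed files.

import numpy as np
import torch

score_s = torch.tensor([0.1, 0.6, 0.2, 0.1])   # start probabilities over 4 tokens (invented)
score_e = torch.tensor([0.1, 0.1, 0.7, 0.1])   # end probabilities (invented)
max_len = 3                                     # maximum allowed span length

scores = torch.ger(score_s, score_e)            # scores[i, j] = P(start=i) * P(end=j)
scores.triu_().tril_(max_len - 1)               # keep only spans with 0 <= j - i < max_len

flat = scores.numpy().flatten()
best = np.argmax(flat)
s_idx, e_idx = np.unravel_index(best, scores.shape)
print(s_idx, e_idx, flat[best])                 # -> 1 2 0.42: the best span covers tokens 1..2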
drqa/reader/predictor.py
ADDED
@@ -0,0 +1,145 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""DrQA Document Reader predictor"""

import logging

from multiprocessing import Pool as ProcessPool
from multiprocessing.util import Finalize

from .vector import vectorize, batchify
from .model import DocReader
from . import DEFAULTS, utils
from .. import tokenizers

logger = logging.getLogger(__name__)


# ------------------------------------------------------------------------------
# Tokenize + annotate
# ------------------------------------------------------------------------------

PROCESS_TOK = None


def init(tokenizer_class, annotators):
    global PROCESS_TOK
    PROCESS_TOK = tokenizer_class(annotators=annotators)
    Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)


def tokenize(text):
    global PROCESS_TOK
    return PROCESS_TOK.tokenize(text)


# ------------------------------------------------------------------------------
# Predictor class.
# ------------------------------------------------------------------------------


class Predictor(object):
    """Load a pretrained DocReader model and predict inputs on the fly."""

    def __init__(self, model=None, tokenizer=None, normalize=True,
                 embedding_file=None, num_workers=None):
        """
        Args:
            model: path to saved model file.
            tokenizer: option string to select tokenizer class.
            normalize: squash output score to 0-1 probabilities with a softmax.
            embedding_file: if provided, will expand dictionary to use all
                available pretrained vectors in this file.
            num_workers: number of CPU processes to use to preprocess batches.
        """
        logger.info('Initializing model...')
        self.model = DocReader.load(model or DEFAULTS['model'],
                                    normalize=normalize)

        if embedding_file:
            logger.info('Expanding dictionary...')
            words = utils.index_embedding_words(embedding_file)
            added = self.model.expand_dictionary(words)
            self.model.load_embeddings(added, embedding_file)

        logger.info('Initializing tokenizer...')
        annotators = tokenizers.get_annotators_for_model(self.model)
        if not tokenizer:
            tokenizer_class = DEFAULTS['tokenizer']
        else:
            tokenizer_class = tokenizers.get_class(tokenizer)

        if num_workers is None or num_workers > 0:
            self.workers = ProcessPool(
                num_workers,
                initializer=init,
                initargs=(tokenizer_class, annotators),
            )
        else:
            self.workers = None
            self.tokenizer = tokenizer_class(annotators=annotators)

    def predict(self, document, question, candidates=None, top_n=1):
        """Predict a single document - question pair."""
        results = self.predict_batch([(document, question, candidates,)], top_n)
        return results[0]

    def predict_batch(self, batch, top_n=1):
        """Predict a batch of document - question pairs."""
        documents, questions, candidates = [], [], []
        for b in batch:
            documents.append(b[0])
            questions.append(b[1])
            candidates.append(b[2] if len(b) == 3 else None)
        candidates = candidates if any(candidates) else None

        # Tokenize the inputs, perhaps multi-processed.
        if self.workers:
            q_tokens = self.workers.map_async(tokenize, questions)
            d_tokens = self.workers.map_async(tokenize, documents)
            q_tokens = list(q_tokens.get())
            d_tokens = list(d_tokens.get())
        else:
            q_tokens = list(map(self.tokenizer.tokenize, questions))
            d_tokens = list(map(self.tokenizer.tokenize, documents))

        examples = []
        for i in range(len(questions)):
            examples.append({
                'id': i,
                'question': q_tokens[i].words(),
                'qlemma': q_tokens[i].lemmas(),
                'document': d_tokens[i].words(),
                'lemma': d_tokens[i].lemmas(),
                'pos': d_tokens[i].pos(),
                'ner': d_tokens[i].entities(),
            })

        # Stick document tokens in candidates for decoding
        if candidates:
            candidates = [{'input': d_tokens[i], 'cands': candidates[i]}
                          for i in range(len(candidates))]

        # Build the batch and run it through the model
        batch_exs = batchify([vectorize(e, self.model) for e in examples])
        s, e, score = self.model.predict(batch_exs, candidates, top_n)

        # Retrieve the predicted spans
        results = []
        for i in range(len(s)):
            predictions = []
            for j in range(len(s[i])):
                span = d_tokens[i].slice(s[i][j], e[i][j] + 1).untokenize()
                predictions.append((span, score[i][j].item()))
            results.append(predictions)
        return results

    def cuda(self):
        self.model.cuda()

    def cpu(self):
        self.model.cpu()
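A hypothetical usage sketch for the Predictor class above; the model file, tokenizer choice, and inputs are placeholders rather than assets shipped with this Space.

from drqa.reader.predictor import Predictor

# num_workers=0 keeps tokenization in-process instead of spawning a worker pool
predictor = Predictor(model='/path/to/reader.mdl', tokenizer='simple', num_workers=0)
document = 'The Eiffel Tower is a wrought-iron lattice tower in Paris, France.'
question = 'Where is the Eiffel Tower located?'
for span, score in predictor.predict(document, question, top_n=2):
    print(span, score)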
drqa/reader/rnn_reader.py
ADDED
@@ -0,0 +1,135 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Implementation of the RNN based DrQA reader."""

import torch
import torch.nn as nn
from . import layers


# ------------------------------------------------------------------------------
# Network
# ------------------------------------------------------------------------------


class RnnDocReader(nn.Module):
    RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}

    def __init__(self, args, normalize=True):
        super(RnnDocReader, self).__init__()
        # Store config
        self.args = args

        # Word embeddings (+1 for padding)
        self.embedding = nn.Embedding(args.vocab_size,
                                      args.embedding_dim,
                                      padding_idx=0)

        # Projection for attention weighted question
        if args.use_qemb:
            self.qemb_match = layers.SeqAttnMatch(args.embedding_dim)

        # Input size to RNN: word emb + question emb + manual features
        doc_input_size = args.embedding_dim + args.num_features
        if args.use_qemb:
            doc_input_size += args.embedding_dim

        # RNN document encoder
        self.doc_rnn = layers.StackedBRNN(
            input_size=doc_input_size,
            hidden_size=args.hidden_size,
            num_layers=args.doc_layers,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            concat_layers=args.concat_rnn_layers,
            rnn_type=self.RNN_TYPES[args.rnn_type],
            padding=args.rnn_padding,
        )

        # RNN question encoder
        self.question_rnn = layers.StackedBRNN(
            input_size=args.embedding_dim,
            hidden_size=args.hidden_size,
            num_layers=args.question_layers,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            concat_layers=args.concat_rnn_layers,
            rnn_type=self.RNN_TYPES[args.rnn_type],
            padding=args.rnn_padding,
        )

        # Output sizes of rnn encoders
        doc_hidden_size = 2 * args.hidden_size
        question_hidden_size = 2 * args.hidden_size
        if args.concat_rnn_layers:
            doc_hidden_size *= args.doc_layers
            question_hidden_size *= args.question_layers

        # Question merging
        if args.question_merge not in ['avg', 'self_attn']:
            raise NotImplementedError('merge_mode = %s' % args.question_merge)
        if args.question_merge == 'self_attn':
            self.self_attn = layers.LinearSeqAttn(question_hidden_size)

        # Bilinear attention for span start/end
        self.start_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
            normalize=normalize,
        )
        self.end_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
            normalize=normalize,
        )

    def forward(self, x1, x1_f, x1_mask, x2, x2_mask):
        """Inputs:
        x1 = document word indices            [batch * len_d]
        x1_f = document word features indices [batch * len_d * nfeat]
        x1_mask = document padding mask       [batch * len_d]
        x2 = question word indices            [batch * len_q]
        x2_mask = question padding mask       [batch * len_q]
        """
        # Embed both document and question
        x1_emb = self.embedding(x1)
        x2_emb = self.embedding(x2)

        # Dropout on embeddings
        if self.args.dropout_emb > 0:
            x1_emb = nn.functional.dropout(x1_emb, p=self.args.dropout_emb,
                                           training=self.training)
            x2_emb = nn.functional.dropout(x2_emb, p=self.args.dropout_emb,
                                           training=self.training)

        # Form document encoding inputs
        drnn_input = [x1_emb]

        # Add attention-weighted question representation
        if self.args.use_qemb:
            x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask)
            drnn_input.append(x2_weighted_emb)

        # Add manual features
        if self.args.num_features > 0:
            drnn_input.append(x1_f)

        # Encode document with RNN
        doc_hiddens = self.doc_rnn(torch.cat(drnn_input, 2), x1_mask)

        # Encode question with RNN + merge hiddens
        question_hiddens = self.question_rnn(x2_emb, x2_mask)
        if self.args.question_merge == 'avg':
            q_merge_weights = layers.uniform_weights(question_hiddens, x2_mask)
        elif self.args.question_merge == 'self_attn':
            q_merge_weights = self.self_attn(question_hiddens, x2_mask)
        question_hidden = layers.weighted_avg(question_hiddens, q_merge_weights)

        # Predict start and end positions
        start_scores = self.start_attn(doc_hiddens, question_hidden, x1_mask)
        end_scores = self.end_attn(doc_hiddens, question_hidden, x1_mask)
        return start_scores, end_scores
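RnnDocReader scores every document position against the merged question vector through layers.BilinearSeqAttn. A small sketch of that bilinear scoring idea follows; shapes and tensors are invented for illustration and this is not the repository's layers module.

import torch

batch, len_d, hidden = 2, 5, 8
doc_hiddens = torch.randn(batch, len_d, hidden)        # document encoder output
question_hidden = torch.randn(batch, hidden)           # merged question vector
W = torch.nn.Linear(hidden, hidden, bias=False)        # learned bilinear weight

Wq = W(question_hidden)                                # batch x hidden
scores = doc_hiddens.bmm(Wq.unsqueeze(2)).squeeze(2)   # score_i = d_i^T W q -> batch x len_d
start_probs = torch.softmax(scores, dim=1)             # the normalize=True case
print(start_probs.shape)                               # torch.Size([2, 5])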
drqa/reader/utils.py
ADDED
@@ -0,0 +1,288 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright 2017-present, Facebook, Inc.
|
3 |
+
# All rights reserved.
|
4 |
+
#
|
5 |
+
# This source code is licensed under the license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
"""DrQA reader utilities."""
|
8 |
+
|
9 |
+
import json
|
10 |
+
import time
|
11 |
+
import logging
|
12 |
+
import string
|
13 |
+
import regex as re
|
14 |
+
|
15 |
+
from collections import Counter
|
16 |
+
from .data import Dictionary
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
# ------------------------------------------------------------------------------
|
22 |
+
# Data loading
|
23 |
+
# ------------------------------------------------------------------------------
|
24 |
+
|
25 |
+
|
26 |
+
def load_data(args, filename, skip_no_answer=False):
|
27 |
+
"""Load examples from preprocessed file.
|
28 |
+
One example per line, JSON encoded.
|
29 |
+
"""
|
30 |
+
# Load JSON lines
|
31 |
+
with open(filename) as f:
|
32 |
+
examples = [json.loads(line) for line in f]
|
33 |
+
|
34 |
+
# Make case insensitive?
|
35 |
+
if args.uncased_question or args.uncased_doc:
|
36 |
+
for ex in examples:
|
37 |
+
if args.uncased_question:
|
38 |
+
ex['question'] = [w.lower() for w in ex['question']]
|
39 |
+
if args.uncased_doc:
|
40 |
+
ex['document'] = [w.lower() for w in ex['document']]
|
41 |
+
|
42 |
+
# Skip unparsed (start/end) examples
|
43 |
+
if skip_no_answer:
|
44 |
+
examples = [ex for ex in examples if len(ex['answers']) > 0]
|
45 |
+
|
46 |
+
return examples
|
47 |
+
|
48 |
+
|
49 |
+
def load_text(filename):
|
50 |
+
"""Load the paragraphs only of a SQuAD dataset. Store as qid -> text."""
|
51 |
+
# Load JSON file
|
52 |
+
with open(filename) as f:
|
53 |
+
examples = json.load(f)['data']
|
54 |
+
|
55 |
+
texts = {}
|
56 |
+
for article in examples:
|
57 |
+
for paragraph in article['paragraphs']:
|
58 |
+
for qa in paragraph['qas']:
|
59 |
+
texts[qa['id']] = paragraph['context']
|
60 |
+
return texts
|
61 |
+
|
62 |
+
|
63 |
+
def load_answers(filename):
|
64 |
+
"""Load the answers only of a SQuAD dataset. Store as qid -> [answers]."""
|
65 |
+
# Load JSON file
|
66 |
+
with open(filename) as f:
|
67 |
+
examples = json.load(f)['data']
|
68 |
+
|
69 |
+
ans = {}
|
70 |
+
for article in examples:
|
71 |
+
for paragraph in article['paragraphs']:
|
72 |
+
for qa in paragraph['qas']:
|
73 |
+
ans[qa['id']] = list(map(lambda x: x['text'], qa['answers']))
|
74 |
+
return ans
|
75 |
+
|
76 |
+
|
77 |
+
# ------------------------------------------------------------------------------
|
78 |
+
# Dictionary building
|
79 |
+
# ------------------------------------------------------------------------------
|
80 |
+
|
81 |
+
|
82 |
+
def index_embedding_words(embedding_file):
|
83 |
+
"""Put all the words in embedding_file into a set."""
|
84 |
+
words = set()
|
85 |
+
with open(embedding_file) as f:
|
86 |
+
for line in f:
|
87 |
+
w = Dictionary.normalize(line.rstrip().split(' ')[0])
|
88 |
+
words.add(w)
|
89 |
+
return words
|
90 |
+
|
91 |
+
|
92 |
+
def load_words(args, examples):
|
93 |
+
"""Iterate and index all the words in examples (documents + questions)."""
|
94 |
+
def _insert(iterable):
|
95 |
+
for w in iterable:
|
96 |
+
w = Dictionary.normalize(w)
|
97 |
+
if valid_words and w not in valid_words:
|
98 |
+
continue
|
99 |
+
words.add(w)
|
100 |
+
|
101 |
+
if args.restrict_vocab and args.embedding_file:
|
102 |
+
logger.info('Restricting to words in %s' % args.embedding_file)
|
103 |
+
valid_words = index_embedding_words(args.embedding_file)
|
104 |
+
logger.info('Num words in set = %d' % len(valid_words))
|
105 |
+
else:
|
106 |
+
valid_words = None
|
107 |
+
|
108 |
+
words = set()
|
109 |
+
for ex in examples:
|
110 |
+
_insert(ex['question'])
|
111 |
+
_insert(ex['document'])
|
112 |
+
return words
|
113 |
+
|
114 |
+
|
115 |
+
def build_word_dict(args, examples):
|
116 |
+
"""Return a dictionary from question and document words in
|
117 |
+
provided examples.
|
118 |
+
"""
|
119 |
+
word_dict = Dictionary()
|
120 |
+
for w in load_words(args, examples):
|
121 |
+
word_dict.add(w)
|
122 |
+
return word_dict
|
123 |
+
|
124 |
+
|
125 |
+
def top_question_words(args, examples, word_dict):
|
126 |
+
"""Count and return the most common question words in provided examples."""
|
127 |
+
word_count = Counter()
|
128 |
+
for ex in examples:
|
129 |
+
for w in ex['question']:
|
130 |
+
w = Dictionary.normalize(w)
|
131 |
+
if w in word_dict:
|
132 |
+
word_count.update([w])
|
133 |
+
return word_count.most_common(args.tune_partial)
|
134 |
+
|
135 |
+
|
136 |
+
def build_feature_dict(args, examples):
|
137 |
+
"""Index features (one hot) from fields in examples and options."""
|
138 |
+
def _insert(feature):
|
139 |
+
if feature not in feature_dict:
|
140 |
+
feature_dict[feature] = len(feature_dict)
|
141 |
+
|
142 |
+
feature_dict = {}
|
143 |
+
|
144 |
+
# Exact match features
|
145 |
+
if args.use_in_question:
|
146 |
+
_insert('in_question')
|
147 |
+
_insert('in_question_uncased')
|
148 |
+
if args.use_lemma:
|
149 |
+
_insert('in_question_lemma')
|
150 |
+
|
151 |
+
# Part of speech tag features
|
152 |
+
if args.use_pos:
|
153 |
+
for ex in examples:
|
154 |
+
for w in ex['pos']:
|
155 |
+
_insert('pos=%s' % w)
|
156 |
+
|
157 |
+
# Named entity tag features
|
158 |
+
if args.use_ner:
|
159 |
+
for ex in examples:
|
160 |
+
for w in ex['ner']:
|
161 |
+
_insert('ner=%s' % w)
|
162 |
+
|
163 |
+
# Term frequency feature
|
164 |
+
if args.use_tf:
|
165 |
+
_insert('tf')
|
166 |
+
return feature_dict
|
167 |
+
|
168 |
+
|
169 |
+
# ------------------------------------------------------------------------------
|
170 |
+
# Evaluation. Follows official evalutation script for v1.1 of the SQuAD dataset.
|
171 |
+
# ------------------------------------------------------------------------------
|
172 |
+
|
173 |
+
|
174 |
+
def normalize_answer(s):
|
175 |
+
"""Lower text and remove punctuation, articles and extra whitespace."""
|
176 |
+
def remove_articles(text):
|
177 |
+
return re.sub(r'\b(a|an|the)\b', ' ', text)
|
178 |
+
|
179 |
+
def white_space_fix(text):
|
180 |
+
return ' '.join(text.split())
|
181 |
+
|
182 |
+
def remove_punc(text):
|
183 |
+
exclude = set(string.punctuation)
|
184 |
+
return ''.join(ch for ch in text if ch not in exclude)
|
185 |
+
|
186 |
+
def lower(text):
|
187 |
+
return text.lower()
|
188 |
+
|
189 |
+
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
190 |
+
|
191 |
+
|
192 |
+
def f1_score(prediction, ground_truth):
|
193 |
+
"""Compute the geometric mean of precision and recall for answer tokens."""
|
194 |
+
prediction_tokens = normalize_answer(prediction).split()
|
195 |
+
ground_truth_tokens = normalize_answer(ground_truth).split()
|
196 |
+
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
|
197 |
+
num_same = sum(common.values())
|
198 |
+
if num_same == 0:
|
199 |
+
return 0
|
200 |
+
precision = 1.0 * num_same / len(prediction_tokens)
|
201 |
+
recall = 1.0 * num_same / len(ground_truth_tokens)
|
202 |
+
f1 = (2 * precision * recall) / (precision + recall)
|
203 |
+
return f1
|
204 |
+
|
205 |
+
|
206 |
+
def exact_match_score(prediction, ground_truth):
|
207 |
+
"""Check if the prediction is a (soft) exact match with the ground truth."""
|
208 |
+
return normalize_answer(prediction) == normalize_answer(ground_truth)
|
209 |
+
|
210 |
+
|
211 |
+
def regex_match_score(prediction, pattern):
|
212 |
+
"""Check if the prediction matches the given regular expression."""
|
213 |
+
try:
|
214 |
+
compiled = re.compile(
|
215 |
+
pattern,
|
216 |
+
flags=re.IGNORECASE + re.UNICODE + re.MULTILINE
|
217 |
+
)
|
218 |
+
except BaseException:
|
219 |
+
logger.warn('Regular expression failed to compile: %s' % pattern)
|
220 |
+
return False
|
221 |
+
return compiled.match(prediction) is not None
|
222 |
+
|
223 |
+
|
224 |
+
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
225 |
+
"""Given a prediction and multiple valid answers, return the score of
|
226 |
+
the best prediction-answer_n pair given a metric function.
|
227 |
+
"""
|
228 |
+
scores_for_ground_truths = []
|
229 |
+
for ground_truth in ground_truths:
|
230 |
+
score = metric_fn(prediction, ground_truth)
|
231 |
+
scores_for_ground_truths.append(score)
|
232 |
+
return max(scores_for_ground_truths)
|
233 |
+
|
234 |
+
|
235 |
+
# ------------------------------------------------------------------------------
|
236 |
+
# Utility classes
|
237 |
+
# ------------------------------------------------------------------------------
|
238 |
+
|
239 |
+
|
240 |
+
class AverageMeter(object):
|
241 |
+
"""Computes and stores the average and current value."""
|
242 |
+
|
243 |
+
def __init__(self):
|
244 |
+
self.reset()
|
245 |
+
|
246 |
+
def reset(self):
|
247 |
+
self.val = 0
|
248 |
+
self.avg = 0
|
249 |
+
self.sum = 0
|
250 |
+
self.count = 0
|
251 |
+
|
252 |
+
def update(self, val, n=1):
|
253 |
+
self.val = val
|
254 |
+
self.sum += val * n
|
255 |
+
self.count += n
|
256 |
+
self.avg = self.sum / self.count
|
257 |
+
|
258 |
+
|
259 |
+
class Timer(object):
|
260 |
+
"""Computes elapsed time."""
|
261 |
+
|
262 |
+
def __init__(self):
|
263 |
+
self.running = True
|
264 |
+
self.total = 0
|
265 |
+
self.start = time.time()
|
266 |
+
|
267 |
+
def reset(self):
|
268 |
+
self.running = True
|
269 |
+
self.total = 0
|
270 |
+
self.start = time.time()
|
271 |
+
return self
|
272 |
+
|
273 |
+
def resume(self):
|
274 |
+
if not self.running:
|
275 |
+
self.running = True
|
276 |
+
self.start = time.time()
|
277 |
+
return self
|
278 |
+
|
279 |
+
def stop(self):
|
280 |
+
if self.running:
|
281 |
+
self.running = False
|
282 |
+
self.total += time.time() - self.start
|
283 |
+
return self
|
284 |
+
|
285 |
+
def time(self):
|
286 |
+
if self.running:
|
287 |
+
return self.total + time.time() - self.start
|
288 |
+
return self.total
|
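A small worked example of the SQuAD-style metrics defined in drqa/reader/utils.py above; the strings are chosen only for illustration.

from drqa.reader.utils import exact_match_score, f1_score, metric_max_over_ground_truths

prediction = 'the Eiffel Tower'
ground_truths = ['Eiffel Tower', 'a tower in Paris']

em = metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
f1 = metric_max_over_ground_truths(f1_score, prediction, ground_truths)
print(em)  # True: normalize_answer() lowercases and strips the article 'the'
print(f1)  # 1.0: the best reference matches the prediction token-for-token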
drqa/reader/vector.py
ADDED
@@ -0,0 +1,127 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Functions for putting examples into torch format."""

from collections import Counter
import torch


def vectorize(ex, model, single_answer=False):
    """Torchify a single example."""
    args = model.args
    word_dict = model.word_dict
    feature_dict = model.feature_dict

    # Index words
    document = torch.LongTensor([word_dict[w] for w in ex['document']])
    question = torch.LongTensor([word_dict[w] for w in ex['question']])

    # Create extra features vector
    if len(feature_dict) > 0:
        features = torch.zeros(len(ex['document']), len(feature_dict))
    else:
        features = None

    # f_{exact_match}
    if args.use_in_question:
        q_words_cased = {w for w in ex['question']}
        q_words_uncased = {w.lower() for w in ex['question']}
        q_lemma = {w for w in ex['qlemma']} if args.use_lemma else None
        for i in range(len(ex['document'])):
            if ex['document'][i] in q_words_cased:
                features[i][feature_dict['in_question']] = 1.0
            if ex['document'][i].lower() in q_words_uncased:
                features[i][feature_dict['in_question_uncased']] = 1.0
            if q_lemma and ex['lemma'][i] in q_lemma:
                features[i][feature_dict['in_question_lemma']] = 1.0

    # f_{token} (POS)
    if args.use_pos:
        for i, w in enumerate(ex['pos']):
            f = 'pos=%s' % w
            if f in feature_dict:
                features[i][feature_dict[f]] = 1.0

    # f_{token} (NER)
    if args.use_ner:
        for i, w in enumerate(ex['ner']):
            f = 'ner=%s' % w
            if f in feature_dict:
                features[i][feature_dict[f]] = 1.0

    # f_{token} (TF)
    if args.use_tf:
        counter = Counter([w.lower() for w in ex['document']])
        l = len(ex['document'])
        for i, w in enumerate(ex['document']):
            features[i][feature_dict['tf']] = counter[w.lower()] * 1.0 / l

    # Maybe return without target
    if 'answers' not in ex:
        return document, features, question, ex['id']

    # ...or with target(s) (might still be empty if answers is empty)
    if single_answer:
        assert(len(ex['answers']) > 0)
        start = torch.LongTensor(1).fill_(ex['answers'][0][0])
        end = torch.LongTensor(1).fill_(ex['answers'][0][1])
    else:
        start = [a[0] for a in ex['answers']]
        end = [a[1] for a in ex['answers']]

    return document, features, question, start, end, ex['id']


def batchify(batch):
    """Gather a batch of individual examples into one batch."""
    NUM_INPUTS = 3
    NUM_TARGETS = 2
    NUM_EXTRA = 1

    ids = [ex[-1] for ex in batch]
    docs = [ex[0] for ex in batch]
    features = [ex[1] for ex in batch]
    questions = [ex[2] for ex in batch]

    # Batch documents and features
    max_length = max([d.size(0) for d in docs])
    x1 = torch.LongTensor(len(docs), max_length).zero_()
    x1_mask = torch.ByteTensor(len(docs), max_length).fill_(1)
    if features[0] is None:
        x1_f = None
    else:
        x1_f = torch.zeros(len(docs), max_length, features[0].size(1))
    for i, d in enumerate(docs):
        x1[i, :d.size(0)].copy_(d)
        x1_mask[i, :d.size(0)].fill_(0)
        if x1_f is not None:
            x1_f[i, :d.size(0)].copy_(features[i])

    # Batch questions
    max_length = max([q.size(0) for q in questions])
    x2 = torch.LongTensor(len(questions), max_length).zero_()
    x2_mask = torch.ByteTensor(len(questions), max_length).fill_(1)
    for i, q in enumerate(questions):
        x2[i, :q.size(0)].copy_(q)
        x2_mask[i, :q.size(0)].fill_(0)

    # Maybe return without targets
    if len(batch[0]) == NUM_INPUTS + NUM_EXTRA:
        return x1, x1_f, x1_mask, x2, x2_mask, ids

    elif len(batch[0]) == NUM_INPUTS + NUM_EXTRA + NUM_TARGETS:
        # ...Otherwise add targets
        if torch.is_tensor(batch[0][3]):
            y_s = torch.cat([ex[3] for ex in batch])
            y_e = torch.cat([ex[4] for ex in batch])
        else:
            y_s = [ex[3] for ex in batch]
            y_e = [ex[4] for ex in batch]
    else:
        raise RuntimeError('Incorrect number of inputs per example.')

    return x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids
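The padding convention used by batchify above, shown on two tiny documents (index values invented): sequences are right-padded with index 0 and the mask is 1 exactly at the padding positions.

import torch

docs = [torch.LongTensor([4, 9, 7]), torch.LongTensor([5, 2])]
max_length = max(d.size(0) for d in docs)
x1 = torch.LongTensor(len(docs), max_length).zero_()
x1_mask = torch.ByteTensor(len(docs), max_length).fill_(1)
for i, d in enumerate(docs):
    x1[i, :d.size(0)].copy_(d)
    x1_mask[i, :d.size(0)].fill_(0)

print(x1)       # tensor([[4, 9, 7], [5, 2, 0]])
print(x1_mask)  # tensor([[0, 0, 0], [0, 0, 1]], dtype=torch.uint8)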
drqa/retriever/__init__.py
ADDED
@@ -0,0 +1,38 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from .. import DATA_DIR

DEFAULTS = {
    'db_path': os.path.join(DATA_DIR, 'wikipedia/docs.db'),
    'tfidf_path': os.path.join(
        DATA_DIR,
        'wikipedia/docs-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz'
    ),
    'elastic_url': 'localhost:9200'
}


def set_default(key, value):
    global DEFAULTS
    DEFAULTS[key] = value


def get_class(name):
    if name == 'tfidf':
        return TfidfDocRanker
    if name == 'sqlite':
        return DocDB
    if name == 'elasticsearch':
        return ElasticDocRanker
    raise RuntimeError('Invalid retriever class: %s' % name)


from .doc_db import DocDB
from .tfidf_doc_ranker import TfidfDocRanker
from .elastic_doc_ranker import ElasticDocRanker
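A hypothetical way to use the factory above; the TF-IDF .npz path is a placeholder and this assumes the dependencies in requirements.txt are installed.

from drqa import retriever

retriever.set_default('tfidf_path', '/path/to/docs-tfidf.npz')  # override DEFAULTS
ranker = retriever.get_class('tfidf')()                         # -> TfidfDocRanker
doc_ids, doc_scores = ranker.closest_docs('who invented the telephone', k=5)
print(list(zip(doc_ids, doc_scores)))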
drqa/retriever/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (975 Bytes)
drqa/retriever/__pycache__/doc_db.cpython-38.pyc
ADDED
Binary file (2.67 kB)
drqa/retriever/__pycache__/elastic_doc_ranker.cpython-38.pyc
ADDED
Binary file (4.64 kB)
drqa/retriever/__pycache__/tfidf_doc_ranker.cpython-38.pyc
ADDED
Binary file (4.26 kB)
drqa/retriever/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (4.23 kB)
drqa/retriever/doc_db.py
ADDED
@@ -0,0 +1,81 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Documents, in a sqlite database."""

import sqlite3
from . import utils
from . import DEFAULTS


class DocDB(object):
    """Sqlite backed document storage.

    Implements get_doc_text(doc_id).
    """

    def __init__(self, db_path=None):
        self.path = db_path or DEFAULTS['db_path']
        self.connection = sqlite3.connect(self.path, check_same_thread=False)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def path(self):
        """Return the path to the file that backs this database."""
        return self.path

    def close(self):
        """Close the connection to the database."""
        self.connection.close()

    def get_doc_ids(self):
        """Fetch all ids of docs stored in the db."""
        cursor = self.connection.cursor()
        cursor.execute("SELECT id FROM documents")
        results = [r[0] for r in cursor.fetchall()]
        cursor.close()
        return results

    def get_doc_text(self, doc_id):
        """Fetch the raw text of the doc for 'doc_id'."""
        cursor = self.connection.cursor()
        cursor.execute(
            "SELECT text FROM documents WHERE id = ?",
            (utils.normalize(doc_id), )
            # (doc_id, )
        )
        result = cursor.fetchone()
        cursor.close()
        return result if result is None else result[0]


    def get_doc_title(self, doc_id):
        """Fetch the title of the doc for 'doc_id'."""
        cursor = self.connection.cursor()
        cursor.execute(
            "SELECT title FROM documents WHERE id = ?",
            (utils.normalize(doc_id),)
            # (doc_id, )
        )
        result = cursor.fetchone()
        cursor.close()
        return result if result is None else result[0]

    def get_doc_intro(self, doc_id):
        """Fetch the introduction section of the doc for 'doc_id'."""
        cursor = self.connection.cursor()
        cursor.execute(
            "SELECT intro FROM documents WHERE id = ?",  # intro: the introduction of Wikipedia page
            (utils.normalize(doc_id),)
            # (doc_id, )
        )
        result = cursor.fetchone()
        cursor.close()
        return result if result is None else result[0]
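A short usage sketch for DocDB; the sqlite path is a placeholder, and the database is expected to have a documents table with id, text, title and intro columns as queried above.

from drqa.retriever.doc_db import DocDB

with DocDB(db_path='/path/to/docs.db') as db:
    doc_ids = db.get_doc_ids()
    print(len(doc_ids))
    print(db.get_doc_title(doc_ids[0]))
    print(db.get_doc_text(doc_ids[0]))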
drqa/retriever/elastic_doc_ranker.py
ADDED
@@ -0,0 +1,109 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Rank documents with an ElasticSearch index"""

import logging
import scipy.sparse as sp

from multiprocessing.pool import ThreadPool
from functools import partial
from elasticsearch import Elasticsearch

from . import utils
from . import DEFAULTS
from .. import tokenizers

logger = logging.getLogger(__name__)


class ElasticDocRanker(object):
    """Connect to an ElasticSearch index.
    Score pairs based on Elasticsearch.
    """

    def __init__(self, elastic_url=None, elastic_index=None, elastic_fields=None, elastic_field_doc_name=None, strict=True, elastic_field_content=None):
        """
        Args:
            elastic_url: URL of the ElasticSearch server containing port
            elastic_index: Index name of ElasticSearch
            elastic_fields: Fields of the Elasticsearch index to search in
            elastic_field_doc_name: Field containing the name of the document (index)
            strict: fail on empty queries or continue (and return empty result)
            elastic_field_content: Field containing the content of document in plain text
        """
        # Load from disk
        elastic_url = elastic_url or DEFAULTS['elastic_url']
        logger.info('Connecting to %s' % elastic_url)
        self.es = Elasticsearch(hosts=elastic_url)
        self.elastic_index = elastic_index
        self.elastic_fields = elastic_fields
        self.elastic_field_doc_name = elastic_field_doc_name
        self.elastic_field_content = elastic_field_content
        self.strict = strict

    # Elastic Ranker

    def get_doc_index(self, doc_id):
        """Convert doc_id --> doc_index"""
        field_index = self.elastic_field_doc_name
        if isinstance(field_index, list):
            field_index = '.'.join(field_index)
        result = self.es.search(index=self.elastic_index,
                                body={'query': {'match': {field_index: doc_id}}})
        return result['hits']['hits'][0]['_id']

    def get_doc_id(self, doc_index):
        """Convert doc_index --> doc_id"""
        result = self.es.search(index=self.elastic_index,
                                body={'query': {'match': {"_id": doc_index}}})
        source = result['hits']['hits'][0]['_source']
        return utils.get_field(source, self.elastic_field_doc_name)

    def closest_docs(self, query, k=1):
        """Closest docs by using ElasticSearch."""
        results = self.es.search(index=self.elastic_index,
                                 body={'size': k,
                                       'query': {'multi_match': {
                                           'query': query,
                                           'type': 'most_fields',
                                           'fields': self.elastic_fields}}})
        hits = results['hits']['hits']
        doc_ids = [utils.get_field(row['_source'], self.elastic_field_doc_name) for row in hits]
        doc_scores = [row['_score'] for row in hits]
        return doc_ids, doc_scores

    def batch_closest_docs(self, queries, k=1, num_workers=None):
        """Process a batch of closest_docs requests multithreaded.
        Note: we can use plain threads here as scipy is outside of the GIL.
        """
        with ThreadPool(num_workers) as threads:
            closest_docs = partial(self.closest_docs, k=k)
            results = threads.map(closest_docs, queries)
        return results

    # Elastic DB

    def __enter__(self):
        return self

    def close(self):
        """Close the connection to the database."""
        self.es = None

    def get_doc_ids(self):
        """Fetch all ids of docs stored in the db."""
        results = self.es.search(index=self.elastic_index,
                                 body={"query": {"match_all": {}}})
        doc_ids = [utils.get_field(result['_source'], self.elastic_field_doc_name) for result in results['hits']['hits']]
        return doc_ids

    def get_doc_text(self, doc_id):
        """Fetch the raw text of the doc for 'doc_id'."""
        idx = self.get_doc_index(doc_id)
        result = self.es.get(index=self.elastic_index, doc_type='_doc', id=idx)
        return result if result is None else result['_source'][self.elastic_field_content]
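For reference, the multi_match request that closest_docs above sends can be issued directly against an index; the index name, field names, and query string below are placeholders.

from elasticsearch import Elasticsearch

es = Elasticsearch(hosts='localhost:9200')
body = {
    'size': 5,
    'query': {
        'multi_match': {
            'query': 'who invented the telephone',
            'type': 'most_fields',
            'fields': ['title', 'text'],
        }
    },
}
hits = es.search(index='my-doc-index', body=body)['hits']['hits']
print([(h['_score'], h['_id']) for h in hits])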
drqa/retriever/tfidf_doc_ranker.py
ADDED
@@ -0,0 +1,121 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Rank documents with TF-IDF scores"""

import logging
import numpy as np
import scipy.sparse as sp

from multiprocessing.pool import ThreadPool
from functools import partial

from . import utils
from . import DEFAULTS
from .. import tokenizers

logger = logging.getLogger(__name__)


class TfidfDocRanker(object):
    """Loads a pre-weighted inverted index of token/document terms.
    Scores new queries by taking sparse dot products.
    """

    def __init__(self, tfidf_path=None, strict=True):
        """
        Args:
            tfidf_path: path to saved model file
            strict: fail on empty queries or continue (and return empty result)
        """
        # Load from disk
        tfidf_path = tfidf_path or DEFAULTS['tfidf_path']
        logger.info('Loading %s' % tfidf_path)
        matrix, metadata = utils.load_sparse_csr(tfidf_path)
        self.doc_mat = matrix
        self.ngrams = metadata['ngram']
        self.hash_size = metadata['hash_size']
        self.tokenizer = tokenizers.get_class(metadata['tokenizer'])()
        self.doc_freqs = metadata['doc_freqs'].squeeze()
        self.doc_dict = metadata['doc_dict']
        self.num_docs = len(self.doc_dict[0])
        self.strict = strict

    def get_doc_index(self, doc_id):
        """Convert doc_id --> doc_index"""
        return self.doc_dict[0][doc_id]

    def get_doc_id(self, doc_index):
        """Convert doc_index --> doc_id"""
        return self.doc_dict[1][doc_index]

    def closest_docs(self, query, k=1):
        """Closest docs by dot product between query and documents
        in tfidf weighted word vector space.
        """
        spvec = self.text2spvec(query)
        res = spvec * self.doc_mat

        if len(res.data) <= k:
            o_sort = np.argsort(-res.data)
        else:
            o = np.argpartition(-res.data, k)[0:k]
            o_sort = o[np.argsort(-res.data[o])]

        doc_scores = res.data[o_sort]
        doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]]
        return doc_ids, doc_scores

    def batch_closest_docs(self, queries, k=1, num_workers=None):
        """Process a batch of closest_docs requests multithreaded.
        Note: we can use plain threads here as scipy is outside of the GIL.
        """
        with ThreadPool(num_workers) as threads:
            closest_docs = partial(self.closest_docs, k=k)
            results = threads.map(closest_docs, queries)
        return results

    def parse(self, query):
        """Parse the query into tokens (either ngrams or tokens)."""
        tokens = self.tokenizer.tokenize(query)
        return tokens.ngrams(n=self.ngrams, uncased=True,
                             filter_fn=utils.filter_ngram)

    def text2spvec(self, query):
        """Create a sparse tfidf-weighted word vector from query.

        tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5))
        """
        # Get hashed ngrams
        words = self.parse(utils.normalize(query))
        wids = [utils.hash(w, self.hash_size) for w in words]

        if len(wids) == 0:
            if self.strict:
                raise RuntimeError('No valid word in: %s' % query)
            else:
                logger.warning('No valid word in: %s' % query)
                return sp.csr_matrix((1, self.hash_size))

        # Count TF
        wids_unique, wids_counts = np.unique(wids, return_counts=True)
        tfs = np.log1p(wids_counts)

        # Count IDF
        Ns = self.doc_freqs[wids_unique]
        idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5))
        idfs[idfs < 0] = 0

        # TF-IDF
        data = np.multiply(tfs, idfs)

        # One row, sparse csr matrix
        indptr = np.array([0, len(wids_unique)])
        spvec = sp.csr_matrix(
            (data, wids_unique, indptr), shape=(1, self.hash_size)
        )

        return spvec
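A numeric sketch of the weighting used in text2spvec above, tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)); the document counts below are invented.

import numpy as np

num_docs = 5_000_000                    # N: documents in the index
doc_freqs = np.array([120000, 800])     # Nt for two hashed query terms
term_counts = np.array([2, 1])          # tf of each term within the query

tfs = np.log1p(term_counts)
idfs = np.log((num_docs - doc_freqs + 0.5) / (doc_freqs + 0.5))
idfs[idfs < 0] = 0
print(tfs * idfs)                       # the rarer term gets the larger weight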
drqa/retriever/utils.py
ADDED
@@ -0,0 +1,120 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Various retriever utilities."""

import regex
import unicodedata
import numpy as np
import scipy.sparse as sp
from sklearn.utils import murmurhash3_32


# ------------------------------------------------------------------------------
# Sparse matrix saving/loading helpers.
# ------------------------------------------------------------------------------


def save_sparse_csr(filename, matrix, metadata=None):
    data = {
        'data': matrix.data,
        'indices': matrix.indices,
        'indptr': matrix.indptr,
        'shape': matrix.shape,
        'metadata': metadata,
    }
    np.savez(filename, **data)


def load_sparse_csr(filename):
    loader = np.load(filename, allow_pickle=True)
    matrix = sp.csr_matrix((loader['data'], loader['indices'],
                            loader['indptr']), shape=loader['shape'])
    return matrix, loader['metadata'].item(0) if 'metadata' in loader else None


# ------------------------------------------------------------------------------
# Token hashing.
# ------------------------------------------------------------------------------


def hash(token, num_buckets):
    """Unsigned 32 bit murmurhash for feature hashing."""
    return murmurhash3_32(token, positive=True) % num_buckets


# ------------------------------------------------------------------------------
# Text cleaning.
# ------------------------------------------------------------------------------


STOPWORDS = {
    'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your',
    'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she',
    'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their',
    'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that',
    'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
    'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an',
    'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of',
    'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through',
    'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down',
    'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then',
    'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any',
    'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor',
    'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can',
    'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've',
    'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven',
    'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren',
    'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``"
}


def normalize(text):
    """Resolve different type of unicode encodings."""
    return unicodedata.normalize('NFD', text)


def filter_word(text):
    """Take out english stopwords, punctuation, and compound endings."""
    text = normalize(text)
    if regex.match(r'^\p{P}+$', text):
84 |
+
return True
|
85 |
+
if text.lower() in STOPWORDS:
|
86 |
+
return True
|
87 |
+
return False
|
88 |
+
|
89 |
+
|
90 |
+
def filter_ngram(gram, mode='any'):
|
91 |
+
"""Decide whether to keep or discard an n-gram.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
gram: list of tokens (length N)
|
95 |
+
mode: Option to throw out ngram if
|
96 |
+
'any': any single token passes filter_word
|
97 |
+
'all': all tokens pass filter_word
|
98 |
+
'ends': book-ended by filterable tokens
|
99 |
+
"""
|
100 |
+
filtered = [filter_word(w) for w in gram]
|
101 |
+
if mode == 'any':
|
102 |
+
return any(filtered)
|
103 |
+
elif mode == 'all':
|
104 |
+
return all(filtered)
|
105 |
+
elif mode == 'ends':
|
106 |
+
return filtered[0] or filtered[-1]
|
107 |
+
else:
|
108 |
+
raise ValueError('Invalid mode: %s' % mode)
|
109 |
+
|
110 |
+
def get_field(d, field_list):
|
111 |
+
"""get the subfield associated to a list of elastic fields
|
112 |
+
E.g. ['file', 'filename'] to d['file']['filename']
|
113 |
+
"""
|
114 |
+
if isinstance(field_list, str):
|
115 |
+
return d[field_list]
|
116 |
+
else:
|
117 |
+
idx = d.copy()
|
118 |
+
for field in field_list:
|
119 |
+
idx = idx[field]
|
120 |
+
return idx
|
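A small, self-contained sketch of how these helpers behave (toy data; file path and values chosen only for illustration, requires scikit-learn, numpy and scipy from requirements.txt):

import numpy as np
import scipy.sparse as sp
from sklearn.utils import murmurhash3_32

# Feature hashing as in hash(): deterministic bucket index for a token.
bucket = murmurhash3_32('averitec', positive=True) % (2 ** 24)
print(bucket)  # integer in [0, 2**24)

# Round-trip a CSR matrix the way save_sparse_csr / load_sparse_csr do.
m = sp.csr_matrix(np.eye(3))
np.savez('/tmp/demo.npz', data=m.data, indices=m.indices,
         indptr=m.indptr, shape=m.shape)
loader = np.load('/tmp/demo.npz', allow_pickle=True)
m2 = sp.csr_matrix((loader['data'], loader['indices'], loader['indptr']),
                   shape=loader['shape'])
print(np.array_equal(m.toarray(), m2.toarray()))  # True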
drqa/tokenizers/__init__.py
ADDED
@@ -0,0 +1,56 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os

DEFAULTS = {
    'corenlp_classpath': os.getenv('CLASSPATH')
}


def set_default(key, value):
    global DEFAULTS
    DEFAULTS[key] = value


from .corenlp_tokenizer import CoreNLPTokenizer
from .regexp_tokenizer import RegexpTokenizer
from .simple_tokenizer import SimpleTokenizer

# Spacy is optional
try:
    from .spacy_tokenizer import SpacyTokenizer
except ImportError:
    pass


def get_class(name):
    if name == 'spacy':
        return SpacyTokenizer
    if name == 'corenlp':
        return CoreNLPTokenizer
    if name == 'regexp':
        return RegexpTokenizer
    if name == 'simple':
        return SimpleTokenizer

    raise RuntimeError('Invalid tokenizer: %s' % name)


def get_annotators_for_args(args):
    annotators = set()
    if args.use_pos:
        annotators.add('pos')
    if args.use_lemma:
        annotators.add('lemma')
    if args.use_ner:
        annotators.add('ner')
    return annotators


def get_annotators_for_model(model):
    return get_annotators_for_args(model.args)
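A minimal usage sketch of the tokenizer registry, assuming the drqa package bundled in this Space (and its dependencies such as regex and pexpect) is importable:

from drqa import tokenizers

# Resolve a tokenizer class by name and tokenize a sentence.
tok_class = tokenizers.get_class('simple')
tokens = tok_class().tokenize('Fact checking is fun.')
print(tokens.words())  # ['Fact', 'checking', 'is', 'fun', '.']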
drqa/tokenizers/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (1.33 kB)
drqa/tokenizers/__pycache__/corenlp_tokenizer.cpython-38.pyc
ADDED
Binary file (3.49 kB)
drqa/tokenizers/__pycache__/regexp_tokenizer.cpython-38.pyc
ADDED
Binary file (3.31 kB)
drqa/tokenizers/__pycache__/simple_tokenizer.cpython-38.pyc
ADDED
Binary file (1.77 kB)
drqa/tokenizers/__pycache__/spacy_tokenizer.cpython-38.pyc
ADDED
Binary file (2.05 kB)
drqa/tokenizers/__pycache__/tokenizer.cpython-38.pyc
ADDED
Binary file (5.83 kB)
drqa/tokenizers/corenlp_tokenizer.py
ADDED
@@ -0,0 +1,122 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Simple wrapper around the Stanford CoreNLP pipeline.

Serves commands to a java subprocess running the jar. Requires java 8.
"""

import copy
import json
import pexpect

from .tokenizer import Tokens, Tokenizer
from . import DEFAULTS


class CoreNLPTokenizer(Tokenizer):

    def __init__(self, **kwargs):
        """
        Args:
            annotators: set that can include pos, lemma, and ner.
            classpath: Path to the corenlp directory of jars
            mem: Java heap memory
        """
        self.classpath = (kwargs.get('classpath') or
                          DEFAULTS['corenlp_classpath'])
        self.annotators = copy.deepcopy(kwargs.get('annotators', set()))
        self.mem = kwargs.get('mem', '2g')
        self._launch()

    def _launch(self):
        """Start the CoreNLP jar with pexpect."""
        annotators = ['tokenize', 'ssplit']
        if 'ner' in self.annotators:
            annotators.extend(['pos', 'lemma', 'ner'])
        elif 'lemma' in self.annotators:
            annotators.extend(['pos', 'lemma'])
        elif 'pos' in self.annotators:
            annotators.extend(['pos'])
        annotators = ','.join(annotators)
        options = ','.join(['untokenizable=noneDelete',
                            'invertible=true'])
        cmd = ['java', '-mx' + self.mem, '-cp', '"%s"' % self.classpath,
               'edu.stanford.nlp.pipeline.StanfordCoreNLP', '-annotators',
               annotators, '-tokenize.options', options,
               '-outputFormat', 'json', '-prettyPrint', 'false']

        # We use pexpect to keep the subprocess alive and feed it commands.
        # Because we don't want to get hit by the max terminal buffer size,
        # we turn off canonical input processing to have unlimited bytes.
        self.corenlp = pexpect.spawn('/bin/bash', maxread=100000, timeout=60)
        self.corenlp.setecho(False)
        self.corenlp.sendline('stty -icanon')
        self.corenlp.sendline(' '.join(cmd))
        self.corenlp.delaybeforesend = 0
        self.corenlp.delayafterread = 0
        self.corenlp.expect_exact('NLP>', searchwindowsize=100)

    @staticmethod
    def _convert(token):
        if token == '-LRB-':
            return '('
        if token == '-RRB-':
            return ')'
        if token == '-LSB-':
            return '['
        if token == '-RSB-':
            return ']'
        if token == '-LCB-':
            return '{'
        if token == '-RCB-':
            return '}'
        return token

    def tokenize(self, text):
        # Since we're feeding text to the commandline, we're waiting on seeing
        # the NLP> prompt. Hacky!
        if 'NLP>' in text:
            raise RuntimeError('Bad token (NLP>) in text!')

        # Sending q will cause the process to quit -- manually override
        if text.lower().strip() == 'q':
            token = text.strip()
            index = text.index(token)
            data = [(token, text[index:], (index, index + 1), 'NN', 'q', 'O')]
            return Tokens(data, self.annotators)

        # Minor cleanup before tokenizing.
        clean_text = text.replace('\n', ' ')

        self.corenlp.sendline(clean_text.encode('utf-8'))
        self.corenlp.expect_exact('NLP>', searchwindowsize=100)

        # Skip to start of output (may have been stderr logging messages)
        output = self.corenlp.before
        start = output.find(b'{"sentences":')
        output = json.loads(output[start:].decode('utf-8'))

        data = []
        tokens = [t for s in output['sentences'] for t in s['tokens']]
        for i in range(len(tokens)):
            # Get whitespace
            start_ws = tokens[i]['characterOffsetBegin']
            if i + 1 < len(tokens):
                end_ws = tokens[i + 1]['characterOffsetBegin']
            else:
                end_ws = tokens[i]['characterOffsetEnd']

            data.append((
                self._convert(tokens[i]['word']),
                text[start_ws: end_ws],
                (tokens[i]['characterOffsetBegin'],
                 tokens[i]['characterOffsetEnd']),
                tokens[i].get('pos', None),
                tokens[i].get('lemma', None),
                tokens[i].get('ner', None)
            ))
        return Tokens(data, self.annotators)
drqa/tokenizers/regexp_tokenizer.py
ADDED
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Regex based tokenizer that emulates the Stanford/NLTK PTB tokenizers.

However it is purely in Python, supports robust untokenization, unicode,
and requires minimal dependencies.
"""

import regex
import logging
from .tokenizer import Tokens, Tokenizer

logger = logging.getLogger(__name__)


class RegexpTokenizer(Tokenizer):
    DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*'
    TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)'
             r'\.(?=\p{Z})')
    ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)'
    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++'
    HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM)
    NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't"
    CONTRACTION1 = r"can(?=not\b)"
    CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b"
    START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})'
    START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})'
    END_DQUOTE = r'(?<!\p{Z})(\'\'|["\u0094\u201D\u00BB])'
    END_SQUOTE = r'(?<!\p{Z})[\'\u0092\u2019\u203A]'
    DASH = r'--|[\u0096\u0097\u2013\u2014\u2015]'
    ELLIPSES = r'\.\.\.|\u2026'
    PUNCT = r'\p{P}'
    NON_WS = r'[^\p{Z}\p{C}]'

    def __init__(self, **kwargs):
        """
        Args:
            annotators: None or empty set (only tokenizes).
            substitutions: if true, normalizes some token types (e.g. quotes).
        """
        self._regexp = regex.compile(
            '(?P<digit>%s)|(?P<title>%s)|(?P<abbr>%s)|(?P<neg>%s)|(?P<hyph>%s)|'
            '(?P<contr1>%s)|(?P<alphanum>%s)|(?P<contr2>%s)|(?P<sdquote>%s)|'
            '(?P<edquote>%s)|(?P<ssquote>%s)|(?P<esquote>%s)|(?P<dash>%s)|'
            '(?<ellipses>%s)|(?P<punct>%s)|(?P<nonws>%s)' %
            (self.DIGIT, self.TITLE, self.ABBRV, self.NEGATION, self.HYPHEN,
             self.CONTRACTION1, self.ALPHA_NUM, self.CONTRACTION2,
             self.START_DQUOTE, self.END_DQUOTE, self.START_SQUOTE,
             self.END_SQUOTE, self.DASH, self.ELLIPSES, self.PUNCT,
             self.NON_WS),
            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
        )
        if len(kwargs.get('annotators', {})) > 0:
            logger.warning('%s only tokenizes! Skipping annotators: %s' %
                           (type(self).__name__, kwargs.get('annotators')))
        self.annotators = set()
        self.substitutions = kwargs.get('substitutions', True)

    def tokenize(self, text):
        data = []
        matches = [m for m in self._regexp.finditer(text)]
        for i in range(len(matches)):
            # Get text
            token = matches[i].group()

            # Make normalizations for special token types
            if self.substitutions:
                groups = matches[i].groupdict()
                if groups['sdquote']:
                    token = "``"
                elif groups['edquote']:
                    token = "''"
                elif groups['ssquote']:
                    token = "`"
                elif groups['esquote']:
                    token = "'"
                elif groups['dash']:
                    token = '--'
                elif groups['ellipses']:
                    token = '...'

            # Get whitespace
            span = matches[i].span()
            start_ws = span[0]
            if i + 1 < len(matches):
                end_ws = matches[i + 1].span()[0]
            else:
                end_ws = span[1]

            # Format data
            data.append((
                token,
                text[start_ws: end_ws],
                span,
            ))
        return Tokens(data, self.annotators)
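A short, hedged usage sketch (again assuming the drqa package and the regex dependency are importable), showing the quote and ellipsis normalization and the whitespace-preserving untokenize:

from drqa.tokenizers import RegexpTokenizer

t = RegexpTokenizer().tokenize('He said: "It can\'t be done..."')
print(t.words())       # quotes normalized to `` / '', ellipses to ...
print(t.untokenize())  # original text rebuilt from the stored whitespace spans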
drqa/tokenizers/simple_tokenizer.py
ADDED
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Basic tokenizer that splits text into alpha-numeric tokens and
non-whitespace tokens.
"""

import regex
import logging
from .tokenizer import Tokens, Tokenizer

logger = logging.getLogger(__name__)


class SimpleTokenizer(Tokenizer):
    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
    NON_WS = r'[^\p{Z}\p{C}]'

    def __init__(self, **kwargs):
        """
        Args:
            annotators: None or empty set (only tokenizes).
        """
        self._regexp = regex.compile(
            '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
        )
        if len(kwargs.get('annotators', {})) > 0:
            logger.warning('%s only tokenizes! Skipping annotators: %s' %
                           (type(self).__name__, kwargs.get('annotators')))
        self.annotators = set()

    def tokenize(self, text):
        data = []
        matches = [m for m in self._regexp.finditer(text)]
        for i in range(len(matches)):
            # Get text
            token = matches[i].group()

            # Get whitespace
            span = matches[i].span()
            start_ws = span[0]
            if i + 1 < len(matches):
                end_ws = matches[i + 1].span()[0]
            else:
                end_ws = span[1]

            # Format data
            data.append((
                token,
                text[start_ws: end_ws],
                span,
            ))
        return Tokens(data, self.annotators)
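A quick usage sketch of this class (assuming the drqa package and the regex dependency are available):

from drqa.tokenizers import SimpleTokenizer

tokens = SimpleTokenizer().tokenize('Fact checking is fun.')
print(tokens.words())               # ['Fact', 'checking', 'is', 'fun', '.']
print(tokens.words(uncased=True))   # lowercased variants
print(tokens.offsets())             # [(0, 4), (5, 13), (14, 16), (17, 20), (20, 21)]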
drqa/tokenizers/spacy_tokenizer.py
ADDED
@@ -0,0 +1,62 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Tokenizer that is backed by spaCy (spacy.io).

Requires spaCy package and the spaCy english model.
"""

import spacy
import copy
from .tokenizer import Tokens, Tokenizer


class SpacyTokenizer(Tokenizer):

    def __init__(self, **kwargs):
        """
        Args:
            annotators: set that can include pos, lemma, and ner.
            model: spaCy model to use (either path, or keyword like 'en').
        """
        model = kwargs.get('model', 'en')
        self.annotators = copy.deepcopy(kwargs.get('annotators', set()))
        nlp_kwargs = {'parser': False}
        if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
            nlp_kwargs['tagger'] = False
        if 'ner' not in self.annotators:
            nlp_kwargs['entity'] = False
        self.nlp = spacy.load(model, **nlp_kwargs)

    def tokenize(self, text):
        # We don't treat new lines as tokens.
        clean_text = text.replace('\n', ' ')
        tokens = self.nlp.tokenizer(clean_text)
        if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
            self.nlp.tagger(tokens)
            if 'ner' in self.annotators:
                self.nlp.entity(tokens)

        data = []
        for i in range(len(tokens)):
            # Get whitespace
            start_ws = tokens[i].idx
            if i + 1 < len(tokens):
                end_ws = tokens[i + 1].idx
            else:
                end_ws = tokens[i].idx + len(tokens[i].text)

            data.append((
                tokens[i].text,
                text[start_ws: end_ws],
                (tokens[i].idx, tokens[i].idx + len(tokens[i].text)),
                tokens[i].tag_,
                tokens[i].lemma_,
                tokens[i].ent_type_,
            ))

        # Set special option for non-entity tag: '' vs 'O' in spaCy
        return Tokens(data, self.annotators, opts={'non_ent': ''})
drqa/tokenizers/tokenizer.py
ADDED
@@ -0,0 +1,139 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Base tokenizer/tokens classes and utilities."""

import copy


class Tokens(object):
    """A class to represent a list of tokenized text."""
    TEXT = 0
    TEXT_WS = 1
    SPAN = 2
    POS = 3
    LEMMA = 4
    NER = 5

    def __init__(self, data, annotators, opts=None):
        self.data = data
        self.annotators = annotators
        self.opts = opts or {}

    def __len__(self):
        """The number of tokens."""
        return len(self.data)

    def slice(self, i=None, j=None):
        """Return a view of the list of tokens from [i, j)."""
        new_tokens = copy.copy(self)
        new_tokens.data = self.data[i: j]
        return new_tokens

    def untokenize(self):
        """Returns the original text (with whitespace reinserted)."""
        return ''.join([t[self.TEXT_WS] for t in self.data]).strip()

    def words(self, uncased=False):
        """Returns a list of the text of each token

        Args:
            uncased: lower cases text
        """
        if uncased:
            return [t[self.TEXT].lower() for t in self.data]
        else:
            return [t[self.TEXT] for t in self.data]

    def offsets(self):
        """Returns a list of [start, end) character offsets of each token."""
        return [t[self.SPAN] for t in self.data]

    def pos(self):
        """Returns a list of part-of-speech tags of each token.
        Returns None if this annotation was not included.
        """
        if 'pos' not in self.annotators:
            return None
        return [t[self.POS] for t in self.data]

    def lemmas(self):
        """Returns a list of the lemmatized text of each token.
        Returns None if this annotation was not included.
        """
        if 'lemma' not in self.annotators:
            return None
        return [t[self.LEMMA] for t in self.data]

    def entities(self):
        """Returns a list of named-entity-recognition tags of each token.
        Returns None if this annotation was not included.
        """
        if 'ner' not in self.annotators:
            return None
        return [t[self.NER] for t in self.data]

    def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
        """Returns a list of all ngrams from length 1 to n.

        Args:
            n: upper limit of ngram length
            uncased: lower cases text
            filter_fn: user function that takes in an ngram list and returns
              True or False to keep or not keep the ngram
            as_string: return the ngram as a string vs list
        """
        def _skip(gram):
            if not filter_fn:
                return False
            return filter_fn(gram)

        words = self.words(uncased)
        ngrams = [(s, e + 1)
                  for s in range(len(words))
                  for e in range(s, min(s + n, len(words)))
                  if not _skip(words[s:e + 1])]

        # Concatenate into strings
        if as_strings:
            ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]

        return ngrams

    def entity_groups(self):
        """Group consecutive entity tokens with the same NER tag."""
        entities = self.entities()
        if not entities:
            return None
        non_ent = self.opts.get('non_ent', 'O')
        groups = []
        idx = 0
        while idx < len(entities):
            ner_tag = entities[idx]
            # Check for entity tag
            if ner_tag != non_ent:
                # Chomp the sequence
                start = idx
                while (idx < len(entities) and entities[idx] == ner_tag):
                    idx += 1
                groups.append((self.slice(start, idx).untokenize(), ner_tag))
            else:
                idx += 1
        return groups


class Tokenizer(object):
    """Base tokenizer class.
    Tokenizers implement tokenize, which should return a Tokens class.
    """
    def tokenize(self, text):
        raise NotImplementedError

    def shutdown(self):
        pass

    def __del__(self):
        self.shutdown()
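A short sketch of the Tokens container itself, built by hand to show ngrams() and entity_groups(); the data tuples follow the (TEXT, TEXT_WS, SPAN, POS, LEMMA, NER) layout above, and the values are illustrative (assumes drqa is importable):

from drqa.tokenizers.tokenizer import Tokens

data = [
    ('Joe',   'Joe ',   (0, 3),   'NNP', 'Joe',   'PERSON'),
    ('Biden', 'Biden ', (4, 9),   'NNP', 'Biden', 'PERSON'),
    ('spoke', 'spoke ', (10, 15), 'VBD', 'speak', 'O'),
    ('today', 'today',  (16, 21), 'NN',  'today', 'O'),
]
tokens = Tokens(data, annotators={'pos', 'lemma', 'ner'})

print(tokens.ngrams(n=2))      # unigrams and bigrams as strings
print(tokens.entity_groups())  # [('Joe Biden', 'PERSON')]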
html2lines.py
ADDED
@@ -0,0 +1,72 @@
from distutils.command.config import config
import requests
from time import sleep
import trafilatura
from trafilatura.meta import reset_caches
from trafilatura.settings import DEFAULT_CONFIG
import spacy
import os
os.system("python -m spacy download en_core_web_sm")
nlp = spacy.load('en_core_web_sm')
import sys

DEFAULT_CONFIG.MAX_FILE_SIZE = 50000

def get_page(url):
    page = None
    for i in range(3):
        try:
            page = trafilatura.fetch_url(url, config=DEFAULT_CONFIG)
            assert page is not None
            print("Fetched "+url, file=sys.stderr)
            break
        except:
            sleep(3)
    return page

def url2lines(url):
    page = get_page(url)

    if page is None:
        return []

    lines = html2lines(page)
    return lines

def line_correction(lines, max_size=100):
    out_lines = []
    for line in lines:
        if len(line) < 4:
            continue

        if len(line) > max_size:
            doc = nlp(line[:5000])  # We split lines into sentences, but for performance we take only the first 5k characters per line
            stack = ""
            for sent in doc.sents:
                if len(stack) > 0:
                    stack += " "
                stack += str(sent).strip()
                if len(stack) > max_size:
                    out_lines.append(stack)
                    stack = ""

            if len(stack) > 0:
                out_lines.append(stack)
        else:
            out_lines.append(line)

    return out_lines

def html2lines(page):
    out_lines = []

    # Check for None before calling strip(), otherwise a None page would raise.
    if page is None or len(page.strip()) == 0:
        return out_lines

    text = trafilatura.extract(page, config=DEFAULT_CONFIG)
    reset_caches()

    if text is None:
        return out_lines

    return text.split("\n")  # We just spit out the entire page, so need to reformat later.
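A minimal usage sketch (the URL is illustrative; importing the module downloads the spaCy model and fetching the page needs network access):

from html2lines import url2lines, line_correction

lines = url2lines('https://en.wikipedia.org/wiki/Fact-checking')
sentences = line_correction(lines, max_size=100)
print(sentences[:3])  # first few cleaned, sentence-chunked lines of the page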
requirements.txt
ADDED
@@ -0,0 +1,22 @@
gradio
nltk
rank_bm25
accelerate
trafilatura
spacy
pytorch_lightning
transformers==4.29.2
datasets
leven
scikit-learn
pexpect
elasticsearch
torch
huggingface_hub
google-api-python-client
wikipedia-api
beautifulsoup4
azure-storage-file-share
azure-storage-blob
bm25s
PyStemmer