Ramon Meffert
committed on
Commit
·
be1f224
1
Parent(s):
b06298d
Add longformer
Browse files- .gitattributes +2 -0
- README.md +7 -2
- query.py +61 -18
- src/models/{paragraphs_embedding.faiss → dpr.faiss} +1 -1
- src/models/longformer.faiss +3 -0
- src/readers/base_reader.py +9 -0
- src/readers/dpr_reader.py +3 -1
- src/readers/longformer_reader.py +41 -0
- src/retrievers/faiss_retriever.py +89 -33
.gitattributes
CHANGED
|
@@ -28,3 +28,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 28 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 29 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
| 30 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 28 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 29 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
| 30 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
src/models/dpr.faiss filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
src/models/longformer.faiss filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -75,7 +75,10 @@ By default, the best answer along with its location in the book will be
|
|
| 75 |
returned. If you want to generate more answers (say, a top-5), you can supply
|
| 76 |
the `--top=5` option. The default retriever uses [FAISS](https://faiss.ai/), but
|
| 77 |
you can also use [ElasticSearch](https://www.elastic.co/elastic-stack/) using
|
| 78 |
-
the `--retriever=es` option.
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
### CLI overview
|
| 81 |
|
|
@@ -83,7 +86,7 @@ To get an overview of all available options, run `python query.py --help`. The
|
|
| 83 |
options are also printed below.
|
| 84 |
|
| 85 |
```sh
|
| 86 |
-
usage: query.py [-h] [--top int] [--retriever {faiss,es}] str
|
| 87 |
|
| 88 |
positional arguments:
|
| 89 |
str The question to feed to the QA system
|
|
@@ -93,6 +96,8 @@ options:
|
|
| 93 |
--top int, -t int The number of answers to retrieve
|
| 94 |
--retriever {faiss,es}, -r {faiss,es}
|
| 95 |
The retrieval method to use
|
|
|
|
|
|
|
| 96 |
```
|
| 97 |
|
| 98 |
|
|
|
|
| 75 |
returned. If you want to generate more answers (say, a top-5), you can supply
|
| 76 |
the `--top=5` option. The default retriever uses [FAISS](https://faiss.ai/), but
|
| 77 |
you can also use [ElasticSearch](https://www.elastic.co/elastic-stack/) using
|
| 78 |
+
the `--retriever=es` option. You can also pick a language model using the
|
| 79 |
+
`--lm` option, which accepts either `dpr` (Dense Passage Retrieval) or
|
| 80 |
+
`longformer`. The language model is used to generate embeddings for FAISS, and
|
| 81 |
+
is used to generate the answer.
|
| 82 |
|
| 83 |
### CLI overview
|
| 84 |
|
|
|
|
| 86 |
options are also printed below.
|
| 87 |
|
| 88 |
```sh
|
| 89 |
+
usage: query.py [-h] [--top int] [--retriever {faiss,es}] [--lm {dpr,longformer}] str
|
| 90 |
|
| 91 |
positional arguments:
|
| 92 |
str The question to feed to the QA system
|
|
|
|
| 96 |
--top int, -t int The number of answers to retrieve
|
| 97 |
--retriever {faiss,es}, -r {faiss,es}
|
| 98 |
The retrieval method to use
|
| 99 |
+
--lm {dpr,longformer}, -l {dpr,longformer}
|
| 100 |
+
The language model to use for the FAISS retriever
|
| 101 |
```
|
| 102 |
|
| 103 |
|
query.py
CHANGED
|
@@ -2,21 +2,48 @@ import argparse
|
|
| 2 |
import torch
|
| 3 |
import transformers
|
| 4 |
|
| 5 |
-
from typing import List, Literal,
|
| 6 |
from datasets import load_dataset, DatasetDict
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
|
|
|
|
|
|
|
| 9 |
from src.readers.dpr_reader import DprReader
|
| 10 |
from src.retrievers.base_retriever import Retriever
|
| 11 |
from src.retrievers.es_retriever import ESRetriever
|
| 12 |
-
from src.retrievers.faiss_retriever import
|
|
|
|
|
|
|
|
|
|
| 13 |
from src.utils.preprocessing import context_to_reader_input
|
| 14 |
from src.utils.log import get_logger
|
| 15 |
|
| 16 |
|
| 17 |
-
def get_retriever(
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
def print_name(contexts: dict, section: str, id: int):
|
|
@@ -51,7 +78,11 @@ def print_answers(answers: List[tuple], scores: List[float], contexts: dict):
|
|
| 51 |
print()
|
| 52 |
|
| 53 |
|
| 54 |
-
def probe(query: str,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
scores, contexts = retriever.retrieve(query)
|
| 56 |
reader_input = context_to_reader_input(contexts)
|
| 57 |
answers = reader.read(query, reader_input, num_answers)
|
|
@@ -63,7 +94,7 @@ def default_probe(query: str):
|
|
| 63 |
# default probe is a probe that prints 5 answers with faiss
|
| 64 |
paragraphs = cast(DatasetDict, load_dataset(
|
| 65 |
"GroNLP/ik-nlp-22_slp", "paragraphs"))
|
| 66 |
-
retriever = get_retriever("faiss",
|
| 67 |
reader = DprReader()
|
| 68 |
|
| 69 |
return probe(query, retriever, reader)
|
|
@@ -75,13 +106,20 @@ def main(args: argparse.Namespace):
|
|
| 75 |
"GroNLP/ik-nlp-22_slp", "paragraphs"))
|
| 76 |
|
| 77 |
# Retrieve
|
| 78 |
-
retriever = get_retriever(args.retriever,
|
| 79 |
-
reader =
|
| 80 |
answers, scores, contexts = probe(
|
| 81 |
-
args.query, retriever, reader, args.
|
| 82 |
|
| 83 |
# Print output
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
|
| 87 |
if __name__ == "__main__":
|
|
@@ -94,13 +132,18 @@ if __name__ == "__main__":
|
|
| 94 |
parser = argparse.ArgumentParser(
|
| 95 |
formatter_class=argparse.MetavarTypeHelpFormatter
|
| 96 |
)
|
| 97 |
-
parser.add_argument(
|
| 98 |
-
|
| 99 |
-
parser.add_argument(
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
args = parser.parse_args()
|
| 106 |
main(args)
|
|
|
|
| 2 |
import torch
|
| 3 |
import transformers
|
| 4 |
|
| 5 |
+
from typing import Dict, List, Literal, Tuple, cast
|
| 6 |
from datasets import load_dataset, DatasetDict
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
|
| 9 |
+
from src.readers.base_reader import Reader
|
| 10 |
+
from src.readers.longformer_reader import LongformerReader
|
| 11 |
from src.readers.dpr_reader import DprReader
|
| 12 |
from src.retrievers.base_retriever import Retriever
|
| 13 |
from src.retrievers.es_retriever import ESRetriever
|
| 14 |
+
from src.retrievers.faiss_retriever import (
|
| 15 |
+
FaissRetriever,
|
| 16 |
+
FaissRetrieverOptions
|
| 17 |
+
)
|
| 18 |
from src.utils.preprocessing import context_to_reader_input
|
| 19 |
from src.utils.log import get_logger
|
| 20 |
|
| 21 |
|
def get_retriever(paragraphs: DatasetDict,
                  r: Literal["es", "faiss"],
                  lm: Literal["dpr", "longformer"]) -> Retriever:
    """Build the retriever for the requested backend.

    Args:
        paragraphs: the passage dataset to search over (FAISS only).
        r: retrieval backend, either "es" (ElasticSearch) or "faiss".
        lm: language model used for FAISS embeddings ("dpr" or "longformer").

    Returns:
        A configured Retriever instance.

    Raises:
        ValueError: if the backend/language-model combination is unknown.
    """
    # ElasticSearch does its own indexing; the language model is irrelevant.
    if r == "es":
        return ESRetriever()
    if r == "faiss":
        if lm == "dpr":
            options = FaissRetrieverOptions.dpr("./src/models/dpr.faiss")
            return FaissRetriever(paragraphs, options)
        if lm == "longformer":
            options = FaissRetrieverOptions.longformer(
                "./src/models/longformer.faiss")
            return FaissRetriever(paragraphs, options)
    raise ValueError("Retriever options not recognized")
|
| 37 |
+
|
| 38 |
+
|
def get_reader(lm: Literal["dpr", "longformer"]) -> Reader:
    """Instantiate the reader matching the chosen language model.

    Args:
        lm: "dpr" or "longformer".

    Raises:
        ValueError: if `lm` names an unknown language model.
    """
    reader_classes = {
        "dpr": DprReader,
        "longformer": LongformerReader,
    }
    try:
        reader_cls = reader_classes[lm]
    except KeyError:
        raise ValueError("Language model not recognized") from None
    return reader_cls()
|
| 47 |
|
| 48 |
|
| 49 |
def print_name(contexts: dict, section: str, id: int):
|
|
|
|
| 78 |
print()
|
| 79 |
|
| 80 |
|
| 81 |
+
def probe(query: str,
|
| 82 |
+
retriever: Retriever,
|
| 83 |
+
reader: Reader,
|
| 84 |
+
num_answers: int = 5) \
|
| 85 |
+
-> Tuple[List[tuple], List[float], Dict[str, List[str]]]:
|
| 86 |
scores, contexts = retriever.retrieve(query)
|
| 87 |
reader_input = context_to_reader_input(contexts)
|
| 88 |
answers = reader.read(query, reader_input, num_answers)
|
|
|
|
| 94 |
# default probe is a probe that prints 5 answers with faiss
|
| 95 |
paragraphs = cast(DatasetDict, load_dataset(
|
| 96 |
"GroNLP/ik-nlp-22_slp", "paragraphs"))
|
| 97 |
+
retriever = get_retriever(paragraphs, "faiss", "dpr")
|
| 98 |
reader = DprReader()
|
| 99 |
|
| 100 |
return probe(query, retriever, reader)
|
|
|
|
| 106 |
"GroNLP/ik-nlp-22_slp", "paragraphs"))
|
| 107 |
|
| 108 |
# Retrieve
|
| 109 |
+
retriever = get_retriever(paragraphs, args.retriever, args.lm)
|
| 110 |
+
reader = get_reader(args.lm)
|
| 111 |
answers, scores, contexts = probe(
|
| 112 |
+
args.query, retriever, reader, args.top)
|
| 113 |
|
| 114 |
# Print output
|
| 115 |
+
print("Question: " + args.query)
|
| 116 |
+
print("Answer(s):")
|
| 117 |
+
if args.lm == "dpr":
|
| 118 |
+
print_answers(answers, scores, contexts)
|
| 119 |
+
else:
|
| 120 |
+
answers = filter(lambda a: len(a[0].strip()) > 0, answers)
|
| 121 |
+
for pos, answer in enumerate(answers, start=1):
|
| 122 |
+
print(f" - {answer[0].strip()}")
|
| 123 |
|
| 124 |
|
| 125 |
if __name__ == "__main__":
|
|
|
|
| 132 |
parser = argparse.ArgumentParser(
|
| 133 |
formatter_class=argparse.MetavarTypeHelpFormatter
|
| 134 |
)
|
| 135 |
+
parser.add_argument(
|
| 136 |
+
"query", type=str, help="The question to feed to the QA system")
|
| 137 |
+
parser.add_argument(
|
| 138 |
+
"--top", "-t", type=int, default=1,
|
| 139 |
+
help="The number of answers to retrieve")
|
| 140 |
+
parser.add_argument(
|
| 141 |
+
"--retriever", "-r", type=str.lower, choices=["faiss", "es"],
|
| 142 |
+
default="faiss", help="The retrieval method to use")
|
| 143 |
+
parser.add_argument(
|
| 144 |
+
"--lm", "-l", type=str.lower,
|
| 145 |
+
choices=["dpr", "longformer"], default="dpr",
|
| 146 |
+
help="The language model to use for the FAISS retriever")
|
| 147 |
|
| 148 |
args = parser.parse_args()
|
| 149 |
main(args)
|
src/models/{paragraphs_embedding.faiss → dpr.faiss}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 5213229
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6bc0e5c38ddeb0a6a4daaf3ae98cd3e564f22ff9a263bc8867d0b363e828ccce
|
| 3 |
size 5213229
|
src/models/longformer.faiss
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:56b2616392540f4d2d8fa34d313a59c41572dca3ef5a683c7a8dbd2691418ea6
|
| 3 |
+
size 5213229
|
src/readers/base_reader.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Tuple
|
| 2 |
+
|
| 3 |
+
|
class Reader():
    """Abstract interface for question-answering readers.

    Concrete readers (e.g. the DPR or Longformer readers) turn a query plus
    retrieved context passages into a list of candidate answers.
    """

    def read(self,
             query: str,
             context: Dict[str, List[str]],
             num_answers: int) -> List[Tuple]:
        # Subclasses must override this: produce candidate answer tuples
        # for `query` given the column-oriented retrieved `context`.
        raise NotImplementedError()
|
src/readers/dpr_reader.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
from transformers import DPRReader, DPRReaderTokenizer
|
| 2 |
from typing import List, Dict, Tuple
|
| 3 |
|
|
|
|
| 4 |
|
| 5 |
-
|
|
|
|
| 6 |
def __init__(self) -> None:
|
| 7 |
self._tokenizer = DPRReaderTokenizer.from_pretrained(
|
| 8 |
"facebook/dpr-reader-single-nq-base")
|
|
|
|
| 1 |
from transformers import DPRReader, DPRReaderTokenizer
|
| 2 |
from typing import List, Dict, Tuple
|
| 3 |
|
| 4 |
+
from src.readers.base_reader import Reader
|
| 5 |
|
| 6 |
+
|
| 7 |
+
class DprReader(Reader):
|
| 8 |
def __init__(self) -> None:
|
| 9 |
self._tokenizer = DPRReaderTokenizer.from_pretrained(
|
| 10 |
"facebook/dpr-reader-single-nq-base")
|
src/readers/longformer_reader.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import (
|
| 3 |
+
LongformerTokenizerFast,
|
| 4 |
+
LongformerForQuestionAnswering
|
| 5 |
+
)
|
| 6 |
+
from typing import List, Dict, Tuple
|
| 7 |
+
|
| 8 |
+
from src.readers.base_reader import Reader
|
| 9 |
+
|
| 10 |
+
|
class LongformerReader(Reader):
    """Reader that extracts answer spans with a SQuAD-fine-tuned Longformer."""

    def __init__(self) -> None:
        # Longformer base model fine-tuned on SQuAD v1 for extractive QA.
        checkpoint = "valhalla/longformer-base-4096-finetuned-squadv1"
        self.tokenizer = LongformerTokenizerFast.from_pretrained(checkpoint)
        self.model = LongformerForQuestionAnswering.from_pretrained(checkpoint)

    def read(self,
             query: str,
             context: Dict[str, List[str]],
             num_answers=5) -> List[Tuple]:
        """Extract one answer span from each retrieved context passage.

        NOTE(review): `num_answers` is currently unused — exactly one answer
        is produced per passage in `context['texts']`; confirm intended.

        Args:
            query: the question to answer.
            context: column-oriented retrieval results; only the 'texts'
                column is consumed here.
            num_answers: unused (kept for Reader interface compatibility).

        Returns:
            A list with one `[answer_text, [], []]` entry per passage; the
            empty lists keep the shape compatible with the DPR reader output.
        """
        answers = []

        for text in context['texts']:
            encoding = self.tokenizer(
                query, text, return_tensors="pt")
            input_ids = encoding["input_ids"]
            attention_mask = encoding["attention_mask"]
            outputs = self.model(input_ids, attention_mask=attention_mask)

            # Greedy span selection: take the single best start and end
            # position. If the best end precedes the best start the slice is
            # empty, yielding an empty answer string (callers filter these).
            start_logits = outputs.start_logits
            end_logits = outputs.end_logits
            all_tokens = self.tokenizer.convert_ids_to_tokens(
                input_ids[0].tolist())
            answer_tokens = all_tokens[
                torch.argmax(start_logits):torch.argmax(end_logits) + 1]
            answer = self.tokenizer.decode(
                self.tokenizer.convert_tokens_to_ids(answer_tokens)
            )
            answers.append([answer, [], []])

        return answers
|
src/retrievers/faiss_retriever.py
CHANGED
|
@@ -1,14 +1,19 @@
|
|
| 1 |
import os
|
| 2 |
import os.path
|
| 3 |
-
|
| 4 |
import torch
|
| 5 |
-
|
|
|
|
|
|
|
| 6 |
from transformers import (
|
| 7 |
DPRContextEncoder,
|
| 8 |
-
|
| 9 |
DPRQuestionEncoder,
|
| 10 |
-
|
|
|
|
|
|
|
| 11 |
)
|
|
|
|
|
|
|
| 12 |
|
| 13 |
from src.retrievers.base_retriever import RetrieveType, Retriever
|
| 14 |
from src.utils.log import get_logger
|
|
@@ -23,35 +28,99 @@ os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
|
| 23 |
logger = get_logger()
|
| 24 |
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
class FaissRetriever(Retriever):
|
| 27 |
"""A class used to retrieve relevant documents based on some query.
|
| 28 |
based on https://huggingface.co/docs/datasets/faiss_es#faiss.
|
| 29 |
"""
|
| 30 |
|
| 31 |
-
def __init__(self, paragraphs: DatasetDict,
|
|
|
|
| 32 |
torch.set_grad_enabled(False)
|
| 33 |
|
|
|
|
|
|
|
| 34 |
# Context encoding and tokenization
|
| 35 |
-
self.ctx_encoder =
|
| 36 |
-
|
| 37 |
-
)
|
| 38 |
-
self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
|
| 39 |
-
"facebook/dpr-ctx_encoder-single-nq-base"
|
| 40 |
-
)
|
| 41 |
|
| 42 |
# Question encoding and tokenization
|
| 43 |
-
self.q_encoder =
|
| 44 |
-
|
| 45 |
-
)
|
| 46 |
-
self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
|
| 47 |
-
"facebook/dpr-question_encoder-single-nq-base"
|
| 48 |
-
)
|
| 49 |
|
| 50 |
self.paragraphs = paragraphs
|
| 51 |
-
self.embedding_path = embedding_path
|
| 52 |
|
| 53 |
self.index = self._init_index()
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
def _init_index(
|
| 56 |
self,
|
| 57 |
force_new_embedding: bool = False):
|
|
@@ -64,16 +133,8 @@ class FaissRetriever(Retriever):
|
|
| 64 |
'embeddings', self.embedding_path) # type: ignore
|
| 65 |
return ds
|
| 66 |
else:
|
| 67 |
-
def embed(row):
|
| 68 |
-
# Inline helper function to perform embedding
|
| 69 |
-
p = row["text"]
|
| 70 |
-
tok = self.ctx_tokenizer(
|
| 71 |
-
p, return_tensors="pt", truncation=True)
|
| 72 |
-
enc = self.ctx_encoder(**tok)[0][0].numpy()
|
| 73 |
-
return {"embeddings": enc}
|
| 74 |
-
|
| 75 |
# Add FAISS embeddings
|
| 76 |
-
index = ds.map(
|
| 77 |
|
| 78 |
index.add_faiss_index(column="embeddings")
|
| 79 |
|
|
@@ -86,12 +147,7 @@ class FaissRetriever(Retriever):
|
|
| 86 |
|
| 87 |
@timeit("faissretriever.retrieve")
|
| 88 |
def retrieve(self, query: str, k: int = 5) -> RetrieveType:
|
| 89 |
-
|
| 90 |
-
# Inline helper function to perform embedding
|
| 91 |
-
tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
|
| 92 |
-
return self.q_encoder(**tok)[0][0].numpy()
|
| 93 |
-
|
| 94 |
-
question_embedding = embed(query)
|
| 95 |
scores, results = self.index.get_nearest_examples(
|
| 96 |
"embeddings", question_embedding, k=k
|
| 97 |
)
|
|
|
|
| 1 |
import os
|
| 2 |
import os.path
|
|
|
|
| 3 |
import torch
|
| 4 |
+
|
| 5 |
+
from datasets import DatasetDict
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
from transformers import (
|
| 8 |
DPRContextEncoder,
|
| 9 |
+
DPRContextEncoderTokenizerFast,
|
| 10 |
DPRQuestionEncoder,
|
| 11 |
+
DPRQuestionEncoderTokenizerFast,
|
| 12 |
+
LongformerModel,
|
| 13 |
+
LongformerTokenizerFast
|
| 14 |
)
|
| 15 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 16 |
+
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
| 17 |
|
| 18 |
from src.retrievers.base_retriever import RetrieveType, Retriever
|
| 19 |
from src.utils.log import get_logger
|
|
|
|
| 28 |
logger = get_logger()
|
| 29 |
|
| 30 |
|
@dataclass
class FaissRetrieverOptions:
    """Encoder/tokenizer bundle plus index location for FaissRetriever.

    Use the `dpr` or `longformer` factory methods rather than constructing
    this directly.
    """
    ctx_encoder: PreTrainedModel            # embeds context passages
    ctx_tokenizer: PreTrainedTokenizerFast  # tokenizes context passages
    q_encoder: PreTrainedModel              # embeds questions
    q_tokenizer: PreTrainedTokenizerFast    # tokenizes questions
    embedding_path: str                     # on-disk FAISS embedding file
    lm: str                                 # "dpr" or "longformer" tag

    @staticmethod
    def dpr(embedding_path: str):
        """Options using Facebook's single-NQ DPR encoders.

        DPR uses distinct models for the question and context sides.
        """
        return FaissRetrieverOptions(
            ctx_encoder=DPRContextEncoder.from_pretrained(
                "facebook/dpr-ctx_encoder-single-nq-base"
            ),
            ctx_tokenizer=DPRContextEncoderTokenizerFast.from_pretrained(
                "facebook/dpr-ctx_encoder-single-nq-base"
            ),
            q_encoder=DPRQuestionEncoder.from_pretrained(
                "facebook/dpr-question_encoder-single-nq-base"
            ),
            q_tokenizer=DPRQuestionEncoderTokenizerFast.from_pretrained(
                "facebook/dpr-question_encoder-single-nq-base"
            ),
            embedding_path=embedding_path,
            lm="dpr"
        )

    @staticmethod
    def longformer(embedding_path: str):
        """Options sharing one Longformer model for questions and contexts."""
        encoder = LongformerModel.from_pretrained(
            "allenai/longformer-base-4096"
        )
        tokenizer = LongformerTokenizerFast.from_pretrained(
            "allenai/longformer-base-4096"
        )
        return FaissRetrieverOptions(
            # One encoder serves both sides, unlike DPR's paired models.
            ctx_encoder=encoder,
            ctx_tokenizer=tokenizer,
            q_encoder=encoder,
            q_tokenizer=tokenizer,
            embedding_path=embedding_path,
            lm="longformer"
        )
|
| 75 |
+
|
| 76 |
+
|
| 77 |
class FaissRetriever(Retriever):
|
| 78 |
"""A class used to retrieve relevant documents based on some query.
|
| 79 |
based on https://huggingface.co/docs/datasets/faiss_es#faiss.
|
| 80 |
"""
|
| 81 |
|
    def __init__(self, paragraphs: DatasetDict,
                 options: FaissRetrieverOptions) -> None:
        """Store the encoders and build (or load) the FAISS index.

        Args:
            paragraphs: dataset of passages to index and search.
            options: encoder/tokenizer bundle plus embedding path and lm tag.
        """
        # Inference only — globally disables autograd for this process.
        torch.set_grad_enabled(False)

        # "dpr" or "longformer": selects the embedding strategy used by the
        # _embed_question/_embed_context helpers.
        self.lm = options.lm

        # Context encoding and tokenization
        self.ctx_encoder = options.ctx_encoder
        self.ctx_tokenizer = options.ctx_tokenizer

        # Question encoding and tokenization
        self.q_encoder = options.q_encoder
        self.q_tokenizer = options.q_tokenizer

        self.paragraphs = paragraphs
        self.embedding_path = options.embedding_path

        # NOTE(review): this may embed the whole corpus on first run when no
        # saved embedding exists (see _init_index) — potentially slow.
        self.index = self._init_index()
|
| 100 |
|
| 101 |
+
def _embed_question(self, q):
|
| 102 |
+
match self.lm:
|
| 103 |
+
case "dpr":
|
| 104 |
+
tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
|
| 105 |
+
return self.q_encoder(**tok)[0][0].numpy()
|
| 106 |
+
case "longformer":
|
| 107 |
+
tok = self.q_tokenizer(q, return_tensors="pt")
|
| 108 |
+
return self.q_encoder(**tok).last_hidden_state[0][0].numpy()
|
| 109 |
+
|
| 110 |
+
def _embed_context(self, row):
|
| 111 |
+
p = row["text"]
|
| 112 |
+
|
| 113 |
+
match self.lm:
|
| 114 |
+
case "dpr":
|
| 115 |
+
tok = self.ctx_tokenizer(
|
| 116 |
+
p, return_tensors="pt", truncation=True)
|
| 117 |
+
enc = self.ctx_encoder(**tok)[0][0].numpy()
|
| 118 |
+
return {"embeddings": enc}
|
| 119 |
+
case "longformer":
|
| 120 |
+
tok = self.ctx_tokenizer(p, return_tensors="pt")
|
| 121 |
+
enc = self.ctx_encoder(**tok).last_hidden_state[0][0].numpy()
|
| 122 |
+
return {"embeddings": enc}
|
| 123 |
+
|
| 124 |
def _init_index(
|
| 125 |
self,
|
| 126 |
force_new_embedding: bool = False):
|
|
|
|
| 133 |
'embeddings', self.embedding_path) # type: ignore
|
| 134 |
return ds
|
| 135 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
# Add FAISS embeddings
|
| 137 |
+
index = ds.map(self._embed_context) # type: ignore
|
| 138 |
|
| 139 |
index.add_faiss_index(column="embeddings")
|
| 140 |
|
|
|
|
| 147 |
|
| 148 |
@timeit("faissretriever.retrieve")
|
| 149 |
def retrieve(self, query: str, k: int = 5) -> RetrieveType:
|
| 150 |
+
question_embedding = self._embed_question(query)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
scores, results = self.index.get_nearest_examples(
|
| 152 |
"embeddings", question_embedding, k=k
|
| 153 |
)
|