File size: 5,047 Bytes
83870cc 8bbe3aa 83870cc 2827202 8bbe3aa 83870cc 8bbe3aa 83870cc 8bbe3aa aa426fb 8bbe3aa 83870cc 8bbe3aa 83870cc 8bbe3aa 83870cc 8bbe3aa 83870cc 8bbe3aa aa426fb 8bbe3aa aa426fb 8bbe3aa 83870cc 8bbe3aa aa426fb 2827202 8bbe3aa aa426fb 83870cc aa426fb 83870cc 8bbe3aa 83870cc aa426fb 8bbe3aa 83870cc 8bbe3aa 83870cc 8bbe3aa 83870cc 2827202 8bbe3aa 2827202 8fe5a80 2827202 aa426fb 2827202 aa426fb 2827202 8fe5a80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
from transformers import (
DPRContextEncoder,
DPRContextEncoderTokenizer,
DPRQuestionEncoder,
DPRQuestionEncoderTokenizer,
)
from datasets import load_dataset
import torch
import os.path
import evaluate
# HACK: allow a duplicate OpenMP runtime to be loaded into the process
# (KMP_DUPLICATE_LIB_OK). This works around a FAISS error seen on macOS;
# note it is applied unconditionally, on every platform.
# See https://stackoverflow.com/a/63374568/4545692
import os
os.environ.update({"KMP_DUPLICATE_LIB_OK": "True"})
class Retriever:
    """Retrieve the paragraphs most relevant to a query using DPR + FAISS.

    Paragraphs of a HuggingFace dataset are embedded with a DPR context
    encoder and indexed with FAISS; queries are embedded with the matching
    DPR question encoder and looked up in that index.
    Based on https://huggingface.co/docs/datasets/faiss_es#faiss.
    """

    def __init__(self, dataset_name: str = "GroNLP/ik-nlp-22_slp") -> None:
        """Load the DPR encoders/tokenizers and build the indexed dataset.

        Args:
            dataset_name (str, optional): The HuggingFace dataset to index.
                Assumes its "paragraphs" config has a 'train' split whose
                text is stored in a column named 'text'. Defaults to
                "GroNLP/ik-nlp-22_slp".
        """
        # NOTE: this disables autograd for the whole process, not just for
        # this object. It is kept for backward compatibility; `_encode` also
        # runs under torch.no_grad() so encoding works even if grads are
        # re-enabled later by other code.
        torch.set_grad_enabled(False)

        # Context encoding and tokenization
        self.ctx_encoder = DPRContextEncoder.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )
        self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )

        # Question encoding and tokenization
        self.q_encoder = DPRQuestionEncoder.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )

        # Dataset building
        self.dataset_name = dataset_name
        self.dataset = self._init_dataset(dataset_name)

    @staticmethod
    def _encode(tokenizer, encoder, text: str):
        """Embed a single string and return the result as a numpy array.

        The text is tokenized (truncated to the model's maximum length) into
        a batch of one, run through *encoder*, and the first row of the
        encoder's first output is returned (for DPR encoders that output is
        the pooled embedding).
        """
        # no_grad so `.numpy()` never fails on a tensor that requires grad.
        with torch.no_grad():
            tokens = tokenizer(text, return_tensors="pt", truncation=True)
            return encoder(**tokens)[0][0].numpy()

    @staticmethod
    def _mean(values) -> float:
        """Return the arithmetic mean of *values*, or 0.0 if it is empty."""
        return sum(values) / len(values) if values else 0.0

    def _init_dataset(self,
                      dataset_name: str,
                      embedding_path: str = "./models/paragraphs_embedding.faiss"):
        """Load the paragraphs dataset and attach a FAISS index named 'embeddings'.

        Args:
            dataset_name (str): A HuggingFace dataset name; the 'train' split
                of its "paragraphs" config is used.
            embedding_path (str, optional): File the FAISS index is saved to
                on the first run and loaded from on later runs, for faster
                start-up.

        Returns:
            Dataset: The dataset with a FAISS index named 'embeddings'. When
            the index is built fresh, the dataset also carries an
            'embeddings' column; when it is loaded from disk, only the index
            is attached.
        """
        ds = load_dataset(dataset_name, name="paragraphs")["train"]
        print(ds)

        if os.path.exists(embedding_path):
            # NOTE(review): the cache is keyed only on embedding_path, not on
            # dataset_name, so an index built for another dataset would be
            # silently reused — confirm this is acceptable.
            ds.load_faiss_index("embeddings", embedding_path)
            return ds

        # No cached index: embed every paragraph and build the index.
        def embed(row):
            return {
                "embeddings": self._encode(
                    self.ctx_tokenizer, self.ctx_encoder, row["text"]
                )
            }

        ds_with_embeddings = ds.map(embed)
        ds_with_embeddings.add_faiss_index(column="embeddings")

        # Create the index file's parent directory (if it has one) so that
        # custom embedding paths outside ./models/ can be saved too.
        embedding_dir = os.path.dirname(embedding_path)
        if embedding_dir:
            os.makedirs(embedding_dir, exist_ok=True)
        ds_with_embeddings.save_faiss_index("embeddings", embedding_path)
        return ds_with_embeddings

    def retrieve(self, query: str, k: int = 5):
        """Retrieve the top k matches for a search query.

        Args:
            query (str): A search query.
            k (int, optional): The number of documents to retrieve.
                Defaults to 5.

        Returns:
            tuple: (scores, results) as returned by the dataset's
            `get_nearest_examples`: the scores of the k nearest paragraphs
            and the matching examples as a mapping of column name to a list
            of values.
        """
        question_embedding = self._encode(self.q_tokenizer, self.q_encoder, query)
        scores, results = self.dataset.get_nearest_examples(
            "embeddings", question_embedding, k=k
        )
        return scores, results

    def evaluate(self):
        """Compute exact match and F1 of top-1 retrieval on the test questions.

        Uses the 'test' split of the dataset's "questions" config, which is
        expected to have 'question' and 'answer' columns.

        Returns:
            float: mean exact match over all questions (0.0 if there are none)
            float: mean F1-score over all questions (0.0 if there are none)
        """
        questions_ds = load_dataset(self.dataset_name, name="questions")["test"]
        questions = questions_ds["question"]
        answers = questions_ds["answer"]

        # Only the text of the single best match is used as the prediction;
        # retrieval scores are not taken into account yet.
        predictions = [
            self.retrieve(question, 1)[1]["text"][0] for question in questions
        ]

        # `evaluate` below is the module imported at file level, not this
        # method (method names are not in scope inside method bodies).
        # NOTE(review): presumably a project-local evaluate.py providing
        # compute_exact_match/compute_f1 — confirm.
        exact_matches = [
            evaluate.compute_exact_match(prediction, answer)
            for prediction, answer in zip(predictions, answers)
        ]
        f1_scores = [
            evaluate.compute_f1(prediction, answer)
            for prediction, answer in zip(predictions, answers)
        ]
        return self._mean(exact_matches), self._mean(f1_scores)
|