from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)
from datasets import load_dataset
import torch
import os.path

# Local evaluation module providing compute_exact_match and compute_f1
# (assumed to live alongside this file; not the HuggingFace `evaluate` package)
import evaluate

# Hacky fix for FAISS error on macOS
# See https://stackoverflow.com/a/63374568/4545692
import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


class Retriever:
    """A class used to retrieve relevant documents based on some query.
    based on https://huggingface.co/docs/datasets/faiss_es#faiss.
    """

    def __init__(self, dataset_name: str = "GroNLP/ik-nlp-22_slp") -> None:
        """Initialize the retriever

        Args:
            dataset (str, optional): The dataset to train on. Assumes the
            information is stored in a column named 'text'. Defaults to
            "GroNLP/ik-nlp-22_slp".
        """
        torch.set_grad_enabled(False)

        # Context encoding and tokenization
        self.ctx_encoder = DPRContextEncoder.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )
        self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )

        # Question encoding and tokenization
        self.q_encoder = DPRQuestionEncoder.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )

        # Dataset building
        self.dataset_name = dataset_name
        self.dataset = self._init_dataset(dataset_name)

    def _init_dataset(self,
                      dataset_name: str,
                      embedding_path: str = "./models/paragraphs_embedding.faiss"):
        """Loads the dataset and adds FAISS embeddings.

        Args:
            dataset (str): A HuggingFace dataset name.
            fname (str): The name to use to save the embeddings to disk for 
            faster loading after the first run.

        Returns:
            Dataset: A dataset with a new column 'embeddings' containing FAISS
            embeddings.
        """
        # Load dataset
        ds = load_dataset(dataset_name, name="paragraphs")["train"]
        print(ds)

        if os.path.exists(embedding_path):
            # If we already have FAISS embeddings, load them from disk
            ds.load_faiss_index('embeddings', embedding_path)
            return ds
        else:
            # If there are no FAISS embeddings, generate them
            def embed(row):
                # Inline helper function to perform embedding
                p = row["text"]
                tok = self.ctx_tokenizer(
                    p, return_tensors="pt", truncation=True)
                enc = self.ctx_encoder(**tok)[0][0].numpy()
                return {"embeddings": enc}

            # Add FAISS embeddings
            ds_with_embeddings = ds.map(embed)

            ds_with_embeddings.add_faiss_index(column="embeddings")

            # Save the FAISS index to disk for reuse on later runs
            os.makedirs("./models/", exist_ok=True)
            ds_with_embeddings.save_faiss_index("embeddings", embedding_path)

            return ds_with_embeddings

    def retrieve(self, query: str, k: int = 5):
        """Retrieve the top k matches for a search query.

        Args:
            query (str): A search query.
            k (int, optional): The number of documents to retrieve. Defaults to
            5.

        Returns:
            tuple: A tuple (scores, results) with the retrieval scores and the
            matching examples.
        """

        def embed(q):
            # Inline helper function to perform embedding
            tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
            return self.q_encoder(**tok)[0][0].numpy()

        question_embedding = embed(query)
        scores, results = self.dataset.get_nearest_examples(
            "embeddings", question_embedding, k=k
        )

        return scores, results

    def evaluate(self):
        """Evaluates the entire model by computing F1-score and exact match on the
        entire dataset.

        Returns:
            float: overall exact match
            float: overall F1-score
        """
        questions_ds = load_dataset(
            self.dataset_name, name="questions")['test']
        questions = questions_ds['question']
        answers = questions_ds['answer']

        predictions = []
        scores = 0

        # Currently just takes the top retrieved paragraph as the prediction
        # and does not use the retrieval scores yet
        for question in questions:
            score, result = self.retrieve(question, 1)
            scores += score[0]
            predictions.append(result['text'][0])

        exact_matches = [evaluate.compute_exact_match(prediction, answer)
                         for prediction, answer in zip(predictions, answers)]
        f1_scores = [evaluate.compute_f1(prediction, answer)
                     for prediction, answer in zip(predictions, answers)]

        return sum(exact_matches) / len(exact_matches), sum(f1_scores) / len(f1_scores)
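

# Minimal usage sketch (not part of the original module): instantiate the
# retriever, run one example query, and report the evaluation metrics. The
# query string below is an arbitrary illustration; the dataset defaults to
# "GroNLP/ik-nlp-22_slp".
if __name__ == "__main__":
    retriever = Retriever()

    # Retrieve the top 3 paragraphs for an example query
    scores, results = retriever.retrieve("What is a language model?", k=3)
    for score, text in zip(scores, results["text"]):
        print(f"{score:.2f}\t{text[:80]}")

    # Average exact match and F1 over the questions test split
    exact_match, f1 = retriever.evaluate()
    print(f"Exact match: {exact_match:.3f}, F1: {f1:.3f}")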