import logging
import time
from pathlib import Path
from typing import List, Set

import lightning as pl
import torch
from lightning.pytorch.trainer.states import RunningStage
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from tqdm import tqdm

from relik.common.log import get_logger
from relik.retriever.callbacks.base import NLPTemplateCallback, PredictionCallback
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset
from relik.retriever.data.datasets import GoldenRetrieverDataset
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.pytorch_modules.model import GoldenRetriever

logger = get_logger(__name__, level=logging.INFO)


class GoldenRetrieverPredictionCallback(PredictionCallback):
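    """Prediction callback for the GoldenRetriever: it indexes the passages
    and retrieves the top-k predictions for each question in the
    validation/test datasets.

    Args:
        k: number of passages to retrieve per question (``None`` lets the
            retriever fall back to its default).
        batch_size: batch size used for indexing and retrieval.
        num_workers: number of dataloader workers.
        document_index: index used to store and search the passage embeddings.
        precision: numerical precision used to compute the embeddings.
        force_reindex: whether to recompute the passage index even when one
            is already available.
        retriever_dir: optional directory of a saved retriever to load at
            test time.
        stages: trainer stages in which the callback should run.
        other_callbacks: callbacks to run on the computed predictions.
        dataset: dataset(s) to compute predictions for.
        dataloader: dataloader(s) to compute predictions for.
    """
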
    def __init__(
        self,
        k: int | None = None,
        batch_size: int = 32,
        num_workers: int = 8,
        document_index: BaseDocumentIndex | None = None,
        precision: str | int = 32,
        force_reindex: bool = True,
        retriever_dir: Path | None = None,
        stages: Set[str | RunningStage] | None = None,
        other_callbacks: List[DictConfig] | List[NLPTemplateCallback] | None = None,
        dataset: DictConfig | BaseDataset | None = None,
        dataloader: DataLoader | None = None,
        *args,
        **kwargs,
    ):
        super().__init__(batch_size, stages, other_callbacks, dataset, dataloader)
        self.k = k
        self.num_workers = num_workers
        self.document_index = document_index
        self.precision = precision
        self.force_reindex = force_reindex
        self.retriever_dir = retriever_dir

    @torch.no_grad()
    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        datasets: DictConfig
        | BaseDataset
        | List[DictConfig]
        | List[BaseDataset]
        | None = None,
        dataloaders: DataLoader | List[DataLoader] | None = None,
        *args,
        **kwargs,
    ) -> dict:
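        """Index the passages and retrieve the top-k passages for every
        question in the given ``datasets``/``dataloaders``.

        Returns:
            dict: a mapping from dataloader index to a list of per-sample
            prediction dicts with ``gold``, ``predictions``, ``scores``,
            ``correct`` and ``wrong`` passages.
        """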
        stage = trainer.state.stage
        logger.info(f"Computing predictions for stage {stage.value}")
        if stage not in self.stages:
            raise ValueError(
                f"Stage `{stage}` not supported, only {self.stages} are supported"
            )

        self.datasets, self.dataloaders = self._get_datasets_and_dataloaders(
            datasets,
            dataloaders,
            trainer,
            dataloader_kwargs=dict(
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                pin_memory=True,
                shuffle=False,
            ),
        )

        # set the model to eval mode
        pl_module.eval()
        # get the retriever
        retriever: GoldenRetriever = pl_module.model

        # here we will store the samples with predictions for each dataloader
        dataloader_predictions = {}
        # compute the passage embeddings index for each dataloader
        for dataloader_idx, dataloader in enumerate(self.dataloaders):
            current_dataset: GoldenRetrieverDataset = self.datasets[dataloader_idx]
            logger.info(
                f"Computing passage embeddings for dataset {current_dataset.name}"
            )

            tokenizer = current_dataset.tokenizer

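            # tokenize raw passage strings into padded tensors for the
            # indexer; defined per-dataset to capture its tokenizer and
            # max passage length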
            def collate_fn(x):
                return ModelInputs(
                    tokenizer(
                        x,
                        truncation=True,
                        padding=True,
                        max_length=current_dataset.max_passage_length,
                        return_tensors="pt",
                    )
                )

            # skip reindexing when a retriever checkpoint directory is given
            # and we are either at the first epoch or at test time; in that
            # case the index (and the retriever) can be loaded from disk
            if (self.retriever_dir is not None and trainer.current_epoch == 0) or (
                self.retriever_dir is not None and stage == RunningStage.TESTING
            ):
                force_reindex = False
            else:
                force_reindex = self.force_reindex

            if (
                not force_reindex
                and self.retriever_dir is not None
                and stage == RunningStage.TESTING
            ):
                retriever = retriever.from_pretrained(self.retriever_dir)

            # you never know :)
            retriever.eval()

            retriever.index(
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                max_length=current_dataset.max_passage_length,
                collate_fn=collate_fn,
                precision=self.precision,
                compute_on_cpu=False,
                force_reindex=force_reindex,
            )

            # now compute the question embeddings and retrieve the top-k passages
            predictions = []
            start = time.time()
            for batch in tqdm(
                dataloader,
                desc=f"Computing predictions for dataset {current_dataset.name}",
            ):
                batch = batch.to(pl_module.device)
                # get the top-k indices
                retriever_output = retriever.retrieve(
                    **batch.questions, k=self.k, precision=self.precision
                )
                # collect the per-sample hits used downstream to compute recall@k
                for batch_idx, retrieved_samples in enumerate(retriever_output):
                    # get the positive passages
                    gold_passages = batch["positives"][batch_idx]
                    # get the index of the gold passages in the retrieved passages
                    gold_passage_indices = []
                    for passage in gold_passages:
                        try:
                            gold_passage_indices.append(
                                retriever.get_index_from_passage(passage)
                            )
                        except ValueError:
                            logger.warning(
                                f"Passage `{passage}` not found in the index. "
                                "We will skip it, but the results might not reflect the "
                                "actual performance."
                            )
                    retrieved_indices = [r.document.id for r in retrieved_samples if r]
                    retrieved_passages = [
                        retriever.get_passage_from_index(i) for i in retrieved_indices
                    ]
                    # filter falsy entries here too, so the scores stay
                    # aligned with the retrieved passages
                    retrieved_scores = [r.score for r in retrieved_samples if r]
                    # correct predictions are the passages that are in the top-k and are gold
                    correct_indices = set(gold_passage_indices) & set(retrieved_indices)
                    # wrong predictions are the passages that are in the top-k and are not gold
                    wrong_indices = set(retrieved_indices) - set(gold_passage_indices)
                    # add the predictions to the list
                    prediction_output = dict(
                        sample_idx=batch.sample_idx[batch_idx],
                        gold=gold_passages,
                        predictions=retrieved_passages,
                        scores=retrieved_scores,
                        correct=[
                            retriever.get_passage_from_index(i) for i in correct_indices
                        ],
                        wrong=[
                            retriever.get_passage_from_index(i) for i in wrong_indices
                        ],
                    )
                    predictions.append(prediction_output)
            end = time.time()
            logger.info(f"Time to retrieve: {str(end - start)}")

            dataloader_predictions[dataloader_idx] = predictions

            # if pl_module_original_device != pl_module.device:
            #     pl_module.to(pl_module_original_device)

        # return the predictions
        return dataloader_predictions
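

# A minimal usage sketch (illustrative only: `trainer`, `pl_module`, and
# `val_dataset` below are placeholders, not part of this module). The
# callback is normally built from the Hydra/OmegaConf config and invoked
# through the `PredictionCallback` hooks, but it can also be called directly:
#
#   callback = GoldenRetrieverPredictionCallback(
#       k=100,
#       batch_size=64,
#       precision=16,
#       stages={RunningStage.VALIDATING, RunningStage.TESTING},
#   )
#   predictions = callback(trainer, pl_module, datasets=val_dataset)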