File size: 8,041 Bytes
2f044c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
import logging
import time
from pathlib import Path
from typing import List, Optional, Set
import lightning as pl
import torch
from lightning.pytorch.trainer.states import RunningStage
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from tqdm import tqdm
from relik.common.log import get_logger
from relik.retriever.callbacks.base import NLPTemplateCallback, PredictionCallback
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset
from relik.retriever.data.datasets import GoldenRetrieverDataset
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.pytorch_modules.model import GoldenRetriever
# Module-level logger for this file, created through the project's get_logger
# helper with level=logging.INFO.
logger = get_logger(__name__, level=logging.INFO)
class GoldenRetrieverPredictionCallback(PredictionCallback):
    """Run a ``GoldenRetriever`` over evaluation datasets and collect predictions.

    For each dataloader of the current stage the callback asks the retriever to
    (re)build its passage index, retrieves the top-``k`` passages for every
    question in the dataloader, and compares them with the sample's gold
    (positive) passages.

    Args:
        k: Number of passages to retrieve per question, forwarded unchanged to
            ``retriever.retrieve``.
        batch_size: Batch size used for passage indexing and for the
            prediction dataloaders.
        num_workers: Number of dataloader workers for indexing and prediction.
        document_index: Stored as ``self.document_index``. NOTE(review): it is
            not read by this class's ``__call__``; confirm whether it is used
            elsewhere (base class or subclasses).
        precision: Forwarded to ``retriever.index`` and ``retriever.retrieve``.
        force_reindex: Whether to ask the retriever to rebuild its index. It is
            overridden to ``False`` when ``retriever_dir`` is set and either
            the trainer is at epoch 0 or the stage is testing.
        retriever_dir: Optional directory of a saved retriever. During testing
            the retriever is loaded from it with ``from_pretrained``.
        stages: Stages in which the callback may run; forwarded to the base
            class.
        other_callbacks: Forwarded to the base class.
        dataset: Forwarded to the base class.
        dataloader: Forwarded to the base class.
        *args: Accepted for interface compatibility; ignored.
        **kwargs: Accepted for interface compatibility; ignored.
    """

    def __init__(
        self,
        k: int | None = None,
        batch_size: int = 32,
        num_workers: int = 8,
        document_index: BaseDocumentIndex | None = None,
        precision: str | int = 32,
        force_reindex: bool = True,
        retriever_dir: Optional[Path] = None,
        stages: Set[str | RunningStage] | None = None,
        other_callbacks: List[DictConfig] | List[NLPTemplateCallback] | None = None,
        dataset: DictConfig | BaseDataset | None = None,
        dataloader: DataLoader | None = None,
        *args,
        **kwargs,
    ):
        super().__init__(batch_size, stages, other_callbacks, dataset, dataloader)
        self.k = k
        self.num_workers = num_workers
        self.document_index = document_index
        self.precision = precision
        self.force_reindex = force_reindex
        self.retriever_dir = retriever_dir

    @torch.no_grad()
    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        datasets: DictConfig
        | BaseDataset
        | List[DictConfig]
        | List[BaseDataset]
        | None = None,
        dataloaders: DataLoader | List[DataLoader] | None = None,
        *args,
        **kwargs,
    ) -> dict:
        """Compute retrieval predictions for every dataloader of the current stage.

        Args:
            trainer: The Lightning trainer. Its ``state.stage`` and
                ``current_epoch`` drive the reindex / load-from-disk decisions.
            pl_module: The Lightning module. ``pl_module.model`` is used as the
                retriever, and batches are moved to ``pl_module.device``.
            datasets: Optional datasets passed to
                ``_get_datasets_and_dataloaders``.
            dataloaders: Optional dataloaders passed to
                ``_get_datasets_and_dataloaders``.

        Returns:
            A dict mapping each dataloader index to a list of per-sample dicts
            with keys ``sample_idx``, ``gold``, ``predictions``, ``scores``,
            ``correct`` and ``wrong``.

        Raises:
            ValueError: If the trainer's current stage is not in ``self.stages``.
        """
        stage = trainer.state.stage
        # Validate before touching ``stage.value``: a ``None`` stage must raise
        # the ValueError below, not an AttributeError from the log line.
        if stage not in self.stages:
            raise ValueError(
                f"Stage `{stage}` not supported, only {self.stages} are supported"
            )
        logger.info(f"Computing predictions for stage {stage.value}")

        self.datasets, self.dataloaders = self._get_datasets_and_dataloaders(
            datasets,
            dataloaders,
            trainer,
            dataloader_kwargs=dict(
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                pin_memory=True,
                shuffle=False,
            ),
        )

        # set the model to eval mode
        pl_module.eval()
        retriever: GoldenRetriever = pl_module.model

        # Both decisions depend only on the configuration, the epoch and the
        # stage, so they are identical for every dataloader.
        force_reindex = self._resolve_force_reindex(trainer, stage)
        # With a saved retriever, testing evaluates the model stored on disk
        # (``force_reindex`` is always False in that situation).
        load_from_disk = (
            self.retriever_dir is not None and stage == RunningStage.TESTING
        )

        # here we will store the samples with predictions for each dataloader
        dataloader_predictions = {}
        for dataloader_idx, dataloader in enumerate(self.dataloaders):
            current_dataset: GoldenRetrieverDataset = self.datasets[dataloader_idx]
            logger.info(
                f"Computing passage embeddings for dataset {current_dataset.name}"
            )
            tokenizer = current_dataset.tokenizer

            def collate_fn(x):
                # Tokenize raw passages for indexing. The closure reads this
                # iteration's ``tokenizer``/``current_dataset`` and is only
                # used by ``retriever.index`` within the same iteration.
                return ModelInputs(
                    tokenizer(
                        x,
                        truncation=True,
                        padding=True,
                        max_length=current_dataset.max_passage_length,
                        return_tensors="pt",
                    )
                )

            if load_from_disk:
                retriever = retriever.from_pretrained(self.retriever_dir)
            # you never know :)
            retriever.eval()

            retriever.index(
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                max_length=current_dataset.max_passage_length,
                collate_fn=collate_fn,
                precision=self.precision,
                compute_on_cpu=False,
                force_reindex=force_reindex,
            )

            # now compute the question embeddings and collect the predictions
            predictions = []
            start = time.perf_counter()
            for batch in tqdm(
                dataloader,
                desc=f"Computing predictions for dataset {current_dataset.name}",
            ):
                batch = batch.to(pl_module.device)
                # get the top-k passages for every question in the batch
                retriever_output = retriever.retrieve(
                    **batch.questions, k=self.k, precision=self.precision
                )
                for batch_idx, retrieved_samples in enumerate(retriever_output):
                    predictions.append(
                        self._sample_prediction(
                            retriever,
                            sample_idx=batch.sample_idx[batch_idx],
                            gold_passages=batch["positives"][batch_idx],
                            retrieved_samples=retrieved_samples,
                        )
                    )
            end = time.perf_counter()
            logger.info(f"Time to retrieve: {end - start}")
            dataloader_predictions[dataloader_idx] = predictions

        return dataloader_predictions

    def _resolve_force_reindex(self, trainer: pl.Trainer, stage: RunningStage) -> bool:
        """Return the ``force_reindex`` value to pass to ``retriever.index``.

        When a saved retriever directory is configured, the index is not
        rebuilt on epoch 0 nor during testing; otherwise the configured
        ``self.force_reindex`` is used.
        """
        if self.retriever_dir is not None and (
            trainer.current_epoch == 0 or stage == RunningStage.TESTING
        ):
            return False
        return self.force_reindex

    @staticmethod
    def _lookup_gold_indices(retriever: GoldenRetriever, gold_passages) -> list:
        """Return the index ids of ``gold_passages`` known to ``retriever``.

        Passages for which ``retriever.get_index_from_passage`` raises
        ``ValueError`` are skipped and a warning is logged for each of them.
        """
        gold_passage_indices = []
        for passage in gold_passages:
            try:
                gold_passage_indices.append(retriever.get_index_from_passage(passage))
            except ValueError:
                logger.warning(
                    f"Passage `{passage}` not found in the index. "
                    "We will skip it, but the results might not reflect the "
                    "actual performance."
                )
        return gold_passage_indices

    @classmethod
    def _sample_prediction(
        cls,
        retriever: GoldenRetriever,
        sample_idx,
        gold_passages,
        retrieved_samples,
    ) -> dict:
        """Build the prediction record for a single question.

        Args:
            retriever: Retriever used to map passages to and from index ids.
            sample_idx: Identifier of the sample, stored as-is.
            gold_passages: The sample's positive passages.
            retrieved_samples: The retriever's results for this question; falsy
                entries are ignored.

        Returns:
            A dict with ``sample_idx``, ``gold``, ``predictions`` (retrieved
            passages), ``scores`` (aligned with ``predictions``), and the
            retrieved passages split into ``correct`` (gold) and ``wrong``
            (not gold), deduplicated and in the order returned by the retriever.
        """
        gold_indices = set(cls._lookup_gold_indices(retriever, gold_passages))
        # Filter falsy results once so that predictions[i] and scores[i]
        # always describe the same retrieved entry.
        hits = [r for r in retrieved_samples if r]
        retrieved_indices = [r.document.id for r in hits]
        retrieved_passages = [
            retriever.get_passage_from_index(i) for i in retrieved_indices
        ]
        retrieved_scores = [r.score for r in hits]
        # index id -> passage: deduplicates ids and reuses the passages already
        # fetched above instead of looking them up again.
        passage_by_index = dict(zip(retrieved_indices, retrieved_passages))
        return dict(
            sample_idx=sample_idx,
            gold=gold_passages,
            predictions=retrieved_passages,
            scores=retrieved_scores,
            correct=[p for i, p in passage_by_index.items() if i in gold_indices],
            wrong=[p for i, p in passage_by_index.items() if i not in gold_indices],
        )
|