from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import hydra
import lightning as pl
import torch
from lightning.pytorch.trainer.states import RunningStage
from omegaconf import DictConfig
from torch.utils.data import DataLoader, Dataset

from relik.common.log import get_logger
from relik.retriever.data.base.datasets import BaseDataset

logger = get_logger(__name__)
STAGES_COMPATIBILITY_MAP = {
    "train": RunningStage.TRAINING,
    "val": RunningStage.VALIDATING,
    "test": RunningStage.TESTING,
}

DEFAULT_STAGES = {
    RunningStage.VALIDATING,
    RunningStage.TESTING,
    RunningStage.SANITY_CHECKING,
    RunningStage.PREDICTING,
}
class PredictionCallback(pl.Callback):
    """
    Base callback that runs the model over the evaluation data at the end of
    validation/test epochs and forwards the resulting predictions to a list of
    `NLPTemplateCallback`s. Subclasses must implement `__call__`.
    """

    def __init__(
        self,
        batch_size: int = 32,
        stages: Optional[Set[Union[str, RunningStage]]] = None,
        other_callbacks: Optional[
            Union[List[DictConfig], List["NLPTemplateCallback"]]
        ] = None,
        datasets: Optional[Union[DictConfig, BaseDataset]] = None,
        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        *args,
        **kwargs,
    ):
        super().__init__()
        # parameters
        self.batch_size = batch_size
        self.datasets = datasets
        self.dataloaders = dataloaders

        # callback initialization
        if stages is None:
            stages = DEFAULT_STAGES
        # compatibility: stages may be given as strings ("train", "val", "test")
        # or as `RunningStage` members; normalize everything to `RunningStage`
        stages = {STAGES_COMPATIBILITY_MAP.get(stage, stage) for stage in stages}
        self.stages = [RunningStage(stage) for stage in stages]

        # `other_callbacks` may contain Hydra configs; instantiate them here
        self.other_callbacks = other_callbacks or []
        for i, callback in enumerate(self.other_callbacks):
            if isinstance(callback, DictConfig):
                self.other_callbacks[i] = hydra.utils.instantiate(
                    callback, _recursive_=False
                )
    @torch.no_grad()
    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        *args,
        **kwargs,
    ) -> Any:
        # subclasses must implement this; it should return the predictions
        raise NotImplementedError
    def on_validation_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ):
        predictions = self(trainer, pl_module)
        for callback in self.other_callbacks:
            callback(
                trainer=trainer,
                pl_module=pl_module,
                callback=self,
                predictions=predictions,
            )

    def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        predictions = self(trainer, pl_module)
        for callback in self.other_callbacks:
            callback(
                trainer=trainer,
                pl_module=pl_module,
                callback=self,
                predictions=predictions,
            )
    @staticmethod
    def _get_datasets_and_dataloaders(
        dataset: Optional[Union[Dataset, List[Dataset], DictConfig]],
        dataloader: Optional[Union[DataLoader, List[DataLoader]]],
        trainer: pl.Trainer,
        dataloader_kwargs: Optional[Dict[str, Any]] = None,
        collate_fn: Optional[Callable] = None,
        collate_fn_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[Dataset], List[DataLoader]]:
        """
        Get the datasets and dataloaders from the datamodule or from the dataset provided.

        Args:
            dataset (`Optional[Union[Dataset, List[Dataset], DictConfig]]`):
                The dataset(s) to use. If `None`, the datamodule is used.
            dataloader (`Optional[Union[DataLoader, List[DataLoader]]]`):
                The dataloader(s) to use. If `None`, a dataloader is built from
                the dataset, or taken from the datamodule.
            trainer (`pl.Trainer`):
                The trainer that contains the datamodule.
            dataloader_kwargs (`Optional[Dict[str, Any]]`):
                The kwargs to pass to the dataloader.
            collate_fn (`Optional[Callable]`):
                The collate function to use.
            collate_fn_kwargs (`Optional[Dict[str, Any]]`):
                The kwargs to pass to the collate function.

        Returns:
            `Tuple[List[Dataset], List[DataLoader]]`: The datasets and dataloaders.
        """
        # if a dataset is provided, use it
        if dataset is not None:
            dataloader_kwargs = dataloader_kwargs or {}
            # instantiate the dataset if it is a Hydra config
            if isinstance(dataset, DictConfig):
                dataset = hydra.utils.instantiate(dataset, _recursive_=False)
            datasets = [dataset] if not isinstance(dataset, list) else dataset
            if dataloader is not None:
                dataloaders = (
                    [dataloader] if isinstance(dataloader, DataLoader) else dataloader
                )
            else:
                # build a dataloader around the dataset, using its own collate_fn
                collate_fn = collate_fn or partial(
                    datasets[0].collate_fn, **(collate_fn_kwargs or {})
                )
                dataloader_kwargs["collate_fn"] = collate_fn
                dataloaders = [DataLoader(datasets[0], **dataloader_kwargs)]
        else:
            # get the dataloaders and datasets from the datamodule
            datasets = (
                trainer.datamodule.test_datasets
                if trainer.state.stage == RunningStage.TESTING
                else trainer.datamodule.val_datasets
            )
            dataloaders = (
                trainer.test_dataloaders
                if trainer.state.stage == RunningStage.TESTING
                else trainer.val_dataloaders
            )
        return datasets, dataloaders
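

# Illustrative sketch (not part of the library): how a concrete subclass might
# resolve its evaluation data with the helper above. The function name
# `_example_resolve_eval_data` is hypothetical. With `dataset=None` the helper
# falls back to the trainer's datamodule; otherwise it builds a `DataLoader`
# around the provided dataset using the dataset's own `collate_fn`.
def _example_resolve_eval_data(
    callback: "PredictionCallback", trainer: pl.Trainer
) -> Tuple[List[Dataset], List[DataLoader]]:
    return PredictionCallback._get_datasets_and_dataloaders(
        dataset=callback.datasets,
        dataloader=callback.dataloaders,
        trainer=trainer,
        dataloader_kwargs={"batch_size": callback.batch_size},
    )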
class NLPTemplateCallback:
    """
    Interface for callbacks that consume the predictions produced by a
    `PredictionCallback` (e.g., to compute metrics or write them to disk).
    """

    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        callback: PredictionCallback,
        predictions: Dict[str, Any],
        *args,
        **kwargs,
    ) -> Any:
        raise NotImplementedError
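

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not shipped with the library): a minimal
# `PredictionCallback` subclass paired with a matching `NLPTemplateCallback`.
# The class names and the assumption that the LightningModule can be called
# with the batch dict unpacked (`pl_module(**batch)`) are hypothetical.
# ---------------------------------------------------------------------------


class SimplePredictionCallback(PredictionCallback):
    @torch.no_grad()
    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        # resolve datasets/dataloaders, falling back to the datamodule
        _, dataloaders = self._get_datasets_and_dataloaders(
            self.datasets,
            self.dataloaders,
            trainer,
            dataloader_kwargs={"batch_size": self.batch_size},
        )
        predictions: Dict[str, Any] = {}
        for dataloader_idx, dataloader in enumerate(dataloaders):
            outputs = []
            for batch in dataloader:
                # assumption: batches are dicts matching the model's forward signature
                outputs.append(pl_module(**batch))
            predictions[str(dataloader_idx)] = outputs
        return predictions


class LoggingTemplateCallback(NLPTemplateCallback):
    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        callback: PredictionCallback,
        predictions: Dict[str, Any],
        *args,
        **kwargs,
    ) -> None:
        # log how many batches of predictions each dataloader produced
        for dataloader_idx, outputs in predictions.items():
            logger.info(
                f"dataloader {dataloader_idx}: {len(outputs)} prediction batches"
            )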