# from itertools import product
from numbers import Number
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Sequence, Union, Literal

# import numpy as np
import pandas as pd
from lightning import LightningDataModule
from sklearn.base import TransformerMixin
from torch.utils.data import Dataset, DataLoader, random_split

from deepscreen.data.utils.dataset import SingleEntitySingleTargetDataset, BaseEntityDataset
from deepscreen.data.utils.label import label_transform
from deepscreen.data.utils.collator import collate_fn
from deepscreen.data.utils.sampler import SafeBatchSampler


class EntityDataModule(LightningDataModule):
    """
    DTI DataModule

    A DataModule implements 6 key methods:

        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc.
        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc.
        def train_dataloader(self):
            # return train dataloader
        def val_dataloader(self):
            # return validation dataloader
        def test_dataloader(self):
            # return test dataloader
        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
    """

    def __init__(
            self,
            dataset: type[BaseEntityDataset],
            task: Literal['regression', 'binary', 'multiclass'],
            n_classes: Optional[int],
            train: bool,
            batch_size: int,
            num_workers: int = 0,
            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
            transformer: Optional[TransformerMixin] = None,
            featurizer: Optional[Callable] = None,
            pin_memory: bool = False,
            data_dir: str = "data/",
            data_file: Optional[str] = None,
            train_val_test_split: Optional[Union[Sequence[Number], Sequence[str]]] = None,
            split: Optional[Callable] = random_split,
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        # data processing
        self.split = split
        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None
        self.data_predict: Optional[Dataset] = None

        if train:
            if all([data_file, split]):
                if not (train_val_test_split
                        and all(isinstance(n, Number) for n in train_val_test_split)):
                    raise ValueError('`train_val_test_split` must be a sequence of 3 numbers '
                                     '(float for percentages and int for sample numbers) if '
                                     '`data_file` and `split` have been specified.')
            elif (train_val_test_split
                  and all(isinstance(f, str) for f in train_val_test_split)
                  and not any([data_file, split])):
                # pre-split data files: load each one directly
                self.data_train = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[0]))
                self.data_val = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[1]))
                self.data_test = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[2]))
            else:
                raise ValueError('For training (train=True), you must specify either '
                                 '`data_file` and `split` with `train_val_test_split` of 3 numbers or '
                                 'solely `train_val_test_split` of 3 data file names.')
        else:
            if data_file and not any([split, train_val_test_split]):
                self.data_test = self.data_predict = dataset(dataset_path=str(Path(data_dir) / data_file))
            else:
                raise ValueError('For testing/predicting (train=False), you must specify only `data_file` '
                                 'without `train_val_test_split` or `split`.')

    def prepare_data(self):
        """
        Download data if needed.
        Do not use it to assign state (e.g., self.x = x).
        """

    def setup(self, stage: Optional[str] = None):
        """
        Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by Lightning with both `trainer.fit()` and `trainer.test()`,
        so be careful not to execute data splitting twice.
        """
        # load and split datasets only if they were not loaded in initialization
        if not any([self.data_train, self.data_val, self.data_test, self.data_predict]):
            dataset = SingleEntitySingleTargetDataset(
                task=self.hparams.task,
                n_classes=self.hparams.n_classes,
                dataset_path=Path(self.hparams.data_dir) / self.hparams.data_file,
                transformer=self.hparams.transformer,
                featurizer=self.hparams.featurizer,
                thresholds=self.hparams.thresholds,
            )
            if self.hparams.train:
                self.data_train, self.data_val, self.data_test = self.split(
                    dataset=dataset,
                    lengths=self.hparams.train_val_test_split
                )
            else:
                self.data_test = self.data_predict = dataset

    def _dataloader(self, dataset: Dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader with the settings shared by all stages."""
        return DataLoader(
            dataset=dataset,
            # SafeBatchSampler takes over batching, so DataLoader's own
            # `batch_size` and `shuffle` arguments must not be passed
            # (they are mutually exclusive with `batch_sampler`)
            batch_sampler=SafeBatchSampler(
                data_source=dataset,
                batch_size=self.hparams.batch_size,
                shuffle=shuffle
            ),
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
        )

    def train_dataloader(self):
        return self._dataloader(self.data_train, shuffle=True)

    def val_dataloader(self):
        return self._dataloader(self.data_val, shuffle=False)

    def test_dataloader(self):
        return self._dataloader(self.data_test, shuffle=False)

    def predict_dataloader(self):
        return self._dataloader(self.data_predict, shuffle=False)

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass
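

# Usage sketch (illustrative; not part of the original module). It wires the
# DataModule up for training with a random 70/15/15 split over a single data
# file, exercising the `data_file` + `split` branch of `__init__` and the
# splitting path in `setup`. The file name 'dti_train.csv' is a hypothetical
# placeholder; everything else uses only names defined above.
if __name__ == "__main__":
    datamodule = EntityDataModule(
        dataset=SingleEntitySingleTargetDataset,
        task='binary',
        n_classes=2,
        train=True,
        batch_size=32,
        num_workers=0,
        data_dir='data/',
        data_file='dti_train.csv',  # hypothetical file name
        train_val_test_split=(0.7, 0.15, 0.15),
        split=random_split,
    )
    datamodule.setup()  # loads the file and performs the random split once
    train_loader = datamodule.train_dataloader()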