# from itertools import product
from numbers import Number
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Sequence, Union, Literal
# import numpy as np
import pandas as pd
from lightning import LightningDataModule
from sklearn.base import TransformerMixin
from torch.utils.data import Dataset, DataLoader, random_split
from deepscreen.data.utils.dataset import SingleEntitySingleTargetDataset, BaseEntityDataset
from deepscreen.data.utils.label import label_transform
from deepscreen.data.utils.collator import collate_fn
from deepscreen.data.utils.sampler import SafeBatchSampler
class EntityDataModule(LightningDataModule):
"""
DTI DataModule
A DataModule implements 5 key methods:
def prepare_data(self):
# things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
# download data, pre-process, split, save to disk, etc.
def setup(self, stage):
# things to do on every process in DDP
# load data, set variables, etc.
def train_dataloader(self):
# return train dataloader
def val_dataloader(self):
# return validation dataloader
def test_dataloader(self):
# return test dataloader
def teardown(self):
# called on every process in DDP
# clean up after fit or test
This allows you to share a full dataset without explaining how to download,
split, transform and process the data.
Read the docs:
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
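
    Example (an illustrative sketch, not a prescribed workflow; the CSV file names are
    placeholders): either pass three pre-split data files and disable the splitter, or
    pass a single `data_file` together with a numeric `train_val_test_split`.

        datamodule = EntityDataModule(
            dataset=SingleEntitySingleTargetDataset,
            task='binary',
            n_classes=2,
            train=True,
            batch_size=32,
            data_dir='data/',
            train_val_test_split=('train.csv', 'val.csv', 'test.csv'),
            split=None,
        )
        datamodule.setup()
        train_loader = datamodule.train_dataloader()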
"""
def __init__(
self,
dataset: type[BaseEntityDataset],
task: Literal['regression', 'binary', 'multiclass'],
n_classes: Optional[int],
train: bool,
batch_size: int,
num_workers: int = 0,
thresholds: Optional[Union[Number, Sequence[Number]]] = None,
pin_memory: bool = False,
data_dir: str = "data/",
data_file: Optional[str] = None,
            train_val_test_split: Optional[Union[Sequence[Number], Sequence[str]]] = None,
            split: Optional[Callable] = random_split,
):
super().__init__()
        # this line allows access to init params with the `self.hparams` attribute
        # and also ensures init params will be stored in the checkpoint
        self.save_hyperparameters(logger=False)

        # data processing
        self.split = split
        # placeholders populated either here (pre-split data files) or in `setup()`
        self.data_train = self.data_val = self.data_test = self.data_predict = None
        if train:
            if data_file and split:
                # mode 1: a single data file to be split into train/val/test in `setup()`
                if not train_val_test_split or not all(isinstance(length, Number) for length in train_val_test_split):
                    raise ValueError('`train_val_test_split` must be a sequence of 3 numbers '
                                     '(floats for fractions or ints for sample counts) if '
                                     '`data_file` and `split` have been specified.')
            elif (train_val_test_split and not data_file and not split
                  and all(isinstance(name, str) for name in train_val_test_split)):
                # mode 2: three pre-split data files, one per subset
                self.data_train = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[0]))
                self.data_val = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[1]))
                self.data_test = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[2]))
            else:
                raise ValueError('For training (train=True), you must specify either '
                                 '`data_file` and `split` with a `train_val_test_split` of 3 numbers, or '
                                 'solely a `train_val_test_split` of 3 data file names.')
        else:
            if data_file and not split and not train_val_test_split:
                self.data_test = self.data_predict = dataset(dataset_path=str(Path(data_dir) / data_file))
            else:
                raise ValueError('For testing/predicting (train=False), you must specify only `data_file`, '
                                 'without `train_val_test_split` or `split`.')
def prepare_data(self):
"""
Download data if needed.
Do not use it to assign state (e.g., self.x = x).
"""
    def setup(self, stage: Optional[str] = None, encoding: Optional[str] = None):
"""
Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
careful not to execute data splitting twice.
"""
# load and split datasets only if not loaded in initialization
if not any([self.data_train, self.data_val, self.data_test, self.data_predict]):
            dataset = SingleEntitySingleTargetDataset(
                task=self.hparams.task,
                n_classes=self.hparams.n_classes,
                dataset_path=str(Path(self.hparams.data_dir) / self.hparams.data_file),
                # `transformer`/`featurizer` are not __init__ params; fall back to None if absent
                transformer=self.hparams.get('transformer'),
                featurizer=self.hparams.get('featurizer'),
                thresholds=self.hparams.thresholds,
            )
if self.hparams.train:
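                # split the single dataset into train/val/test subsets according to
                # `train_val_test_split` (fractions or absolute sample counts)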
self.data_train, self.data_val, self.data_test = self.split(
dataset=dataset,
lengths=self.hparams.train_val_test_split
)
else:
self.data_test = self.data_predict = dataset
def train_dataloader(self):
return DataLoader(
dataset=self.data_train,
batch_sampler=SafeBatchSampler(
data_source=self.data_train,
batch_size=self.hparams.batch_size,
shuffle=True),
            # batch_size and shuffle are handled by the batch_sampler above
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
)
def val_dataloader(self):
return DataLoader(
dataset=self.data_val,
batch_sampler=SafeBatchSampler(
data_source=self.data_val,
batch_size=self.hparams.batch_size,
shuffle=False),
            # batch_size and shuffle are handled by the batch_sampler above
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
)
def test_dataloader(self):
return DataLoader(
dataset=self.data_test,
batch_sampler=SafeBatchSampler(
data_source=self.data_test,
batch_size=self.hparams.batch_size,
shuffle=False),
            # batch_size and shuffle are handled by the batch_sampler above
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
)
def predict_dataloader(self):
return DataLoader(
dataset=self.data_predict,
batch_sampler=SafeBatchSampler(
data_source=self.data_predict,
batch_size=self.hparams.batch_size,
shuffle=False),
            # batch_size and shuffle are handled by the batch_sampler above
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
)
def teardown(self, stage: Optional[str] = None):
"""Clean up after fit or test."""
pass
def state_dict(self):
"""Extra things to save to checkpoint."""
return {}
def load_state_dict(self, state_dict: Dict[str, Any]):
"""Things to do when loading checkpoint."""
pass
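

# A minimal smoke-test sketch of the inference-only configuration (train=False).
# The data directory and file name below are placeholders, not files shipped with this module.
if __name__ == "__main__":
    dm = EntityDataModule(
        dataset=SingleEntitySingleTargetDataset,
        task='binary',
        n_classes=2,
        train=False,
        batch_size=16,
        data_dir='data/',
        data_file='predict.csv',
        train_val_test_split=None,
        split=None,
    )
    dm.setup()
    predict_loader = dm.predict_dataloader()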