from numbers import Number
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Sequence, Type, Union

from lightning import LightningDataModule
from sklearn.base import TransformerMixin
from torch.utils.data import Dataset, DataLoader

from deepscreen.data.utils import collate_fn, SafeBatchSampler
from deepscreen.data.utils.dataset import BaseEntityDataset
class EntityDataModule(LightningDataModule):
    """
    A `LightningDataModule` that wraps a `BaseEntityDataset` with the standard Lightning hooks:

    def prepare_data(self):
        # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
        # download data, pre-process, split, save to disk, etc.

    def setup(self, stage):
        # things to do on every process in DDP
        # load data, set variables, etc.

    def train_dataloader(self):
        # return train dataloader

    def val_dataloader(self):
        # return validation dataloader

    def test_dataloader(self):
        # return test dataloader

    def teardown(self):
        # called on every process in DDP
        # clean up after fit or test
    """
    def __init__(
            self,
            dataset: Type[BaseEntityDataset],
            transformer: Type[TransformerMixin],
            train: bool,
            batch_size: int,
            data_dir: str = "data/",
            data_file: Optional[str] = None,
            train_val_test_split: Optional[Union[Sequence[Number], Sequence[str]]] = None,
            split: Optional[Callable] = None,
            num_workers: int = 0,
            pin_memory: bool = False,
    ):
        super().__init__()

        # dataset placeholders; `setup` checks these to avoid splitting twice
        self.train_data = self.val_data = self.test_data = self.predict_data = None

        # data processing
        self.split = split
        if train:
            if all([data_file, split]):
                if train_val_test_split and all(isinstance(n, Number) for n in train_val_test_split):
                    pass
                else:
                    raise ValueError('`train_val_test_split` must be a sequence of 3 numbers '
                                     '(float for percentages and int for sample numbers) if '
                                     '`data_file` and `split` have been specified.')
            elif (train_val_test_split and all(isinstance(name, str) for name in train_val_test_split)
                  and not any([data_file, split])):
                self.train_data = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[0]))
                self.val_data = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[1]))
                self.test_data = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[2]))
            else:
                raise ValueError('For training (train=True), you must specify either '
                                 '`data_file` and `split` with `train_val_test_split` of 3 numbers or '
                                 'solely `train_val_test_split` of 3 data file names.')
        else:
            if data_file and not any([split, train_val_test_split]):
                self.test_data = self.predict_data = dataset(dataset_path=str(Path(data_dir) / data_file))
            else:
                raise ValueError("For testing/predicting (train=False), you must specify only `data_file` without "
                                 "`train_val_test_split` or `split`.")

        # this line allows access to init params via the `self.hparams` attribute
        # and ensures init params are stored in the checkpoint
        self.save_hyperparameters(logger=False)
    def prepare_data(self):
        """
        Download data if needed.
        Do not use it to assign state (e.g., self.x = x).
        """
    def setup(self, stage: Optional[str] = None, encoding: Optional[str] = None):
        """
        Load data. Set variables: `self.train_data`, `self.val_data`, `self.test_data`.
        This method is called by Lightning with both `trainer.fit()` and `trainer.test()`, so be
        careful not to execute data splitting twice.
        """
        # TODO: test SafeBatchSampler (which skips samples with any None without introducing variable batch size)
        # TODO: find a way to apply transformer.fit_transform only to train and transformer.transform only to val/test
        # load and split datasets only if they were not already loaded during initialization
        if not any([self.train_data, self.val_data, self.test_data, self.predict_data]):
            self.train_data, self.val_data, self.test_data = self.split(
                dataset=self.hparams.dataset(
                    dataset_path=str(Path(self.hparams.data_dir) / self.hparams.data_file)),
                lengths=self.hparams.train_val_test_split
            )
    def train_dataloader(self):
        return DataLoader(
            dataset=self.train_data,
            # SafeBatchSampler takes over `batch_size` and `shuffle`, so they are not
            # passed to the DataLoader directly
            batch_sampler=SafeBatchSampler(
                data_source=self.train_data,
                batch_size=self.hparams.batch_size,
                shuffle=True),
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.val_data,
            batch_sampler=SafeBatchSampler(
                data_source=self.val_data,
                batch_size=self.hparams.batch_size,
                shuffle=False),
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.test_data,
            batch_sampler=SafeBatchSampler(
                data_source=self.test_data,
                batch_size=self.hparams.batch_size,
                shuffle=False),
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
        )

    def predict_dataloader(self):
        return DataLoader(
            dataset=self.predict_data,
            batch_sampler=SafeBatchSampler(
                data_source=self.predict_data,
                batch_size=self.hparams.batch_size,
                shuffle=False),
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
        )
    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass
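

# --- Usage sketch (illustrative only) ---
# The dataset class, transformer, and file names below are hypothetical placeholders;
# substitute the concrete `BaseEntityDataset` subclass and sklearn transformer used in
# your project. `torch.utils.data.random_split` matches the `split(dataset=..., lengths=...)`
# call made in `setup`.
#
# from torch.utils.data import random_split
# from sklearn.preprocessing import FunctionTransformer
#
# datamodule = EntityDataModule(
#     dataset=MyEntityDataset,           # hypothetical BaseEntityDataset subclass
#     transformer=FunctionTransformer,   # any sklearn TransformerMixin subclass
#     train=True,
#     batch_size=64,
#     data_dir="data/",
#     data_file="entities.csv",          # hypothetical data file name
#     train_val_test_split=(0.7, 0.1, 0.2),
#     split=random_split,
#     num_workers=4,
#     pin_memory=True,
# )
# datamodule.setup()
# train_loader = datamodule.train_dataloader()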