from typing import Any, Dict, Optional

from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from src.data.components.hrcwhu import HRCWHU


class HRCWHUDataModule(LightningDataModule):
    """`LightningDataModule` wrapping the HRCWHU dataset."""

    def __init__(
        self,
        root: str,
        train_pipeline: Dict[str, Any],
        val_pipeline: Dict[str, Any],
        test_pipeline: Dict[str, Any],
        seed: int = 42,
        batch_size: int = 1,
        num_workers: int = 0,
        pin_memory: bool = False,
        persistent_workers: bool = False,
    ) -> None:
        """Initialize a `HRCWHUDataModule`.

        :param root: Root directory of the HRCWHU dataset.
        :param train_pipeline: Keyword arguments forwarded to `HRCWHU` for the train split.
        :param val_pipeline: Keyword arguments intended for the validation split.
        :param test_pipeline: Keyword arguments forwarded to `HRCWHU` for the test split.
        :param seed: Random seed passed to the dataset. Defaults to `42`.
        :param batch_size: Total batch size, divided across devices in `setup()`. Defaults to `1`.
        :param num_workers: Number of dataloader workers. Defaults to `0`.
        :param pin_memory: Whether to pin memory in the dataloaders. Defaults to `False`.
        :param persistent_workers: Whether to keep dataloader workers alive between epochs. Defaults to `False`.
        """
        super().__init__()

        # this line allows accessing init params with the 'self.hparams' attribute
        # and also ensures init params will be stored in the ckpt
        self.save_hyperparameters(logger=False)

        self.train_dataset: Optional[Dataset] = None
        self.val_dataset: Optional[Dataset] = None
        self.test_dataset: Optional[Dataset] = None

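        # Effective per-device batch size; recomputed in `setup()` once the
        # trainer's world size is known.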
        self.batch_size_per_device = batch_size

    @property
    def num_classes(self) -> int:
        """Get the number of classes.

        :return: The number of HRCWHU classes.
        """
        return len(HRCWHU.METAINFO["classes"])

    def prepare_data(self) -> None:
        """Download data if needed. Lightning ensures that `self.prepare_data()` is called only
        within a single process on CPU, so you can safely add your downloading logic within. In
        case of multi-node training, the execution of this hook depends upon
        `self.prepare_data_per_node`.

        Do not use it to assign state (self.x = y).
        """
        # train
        HRCWHU(
            root=self.hparams.root,
            phase="train",
            **self.hparams.train_pipeline,
            seed=self.hparams.seed,
        )
        # val or test
        HRCWHU(
            root=self.hparams.root,
            phase="test",
            **self.hparams.test_pipeline,
            seed=self.hparams.seed,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Load data. Set variables: `self.train_dataset`, `self.val_dataset`, `self.test_dataset`.

        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
        `self.setup()` once the data is prepared and available for use.

        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
        """
        # Divide batch size by the number of devices.
        if self.trainer is not None:
            if self.hparams.batch_size % self.trainer.world_size != 0:
                raise RuntimeError(
                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
                )
            self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size
        # Load the datasets only if they have not been loaded already.
        if self.train_dataset is None and self.val_dataset is None and self.test_dataset is None:
            self.train_dataset = HRCWHU(
                root=self.hparams.root,
                phase="train",
                **self.hparams.train_pipeline,
                seed=self.hparams.seed,
            )
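            # The test split doubles as the validation set (no separate
            # validation split is loaded), so `val_pipeline` is unused as written.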
            self.val_dataset = self.test_dataset = HRCWHU(
                root=self.hparams.root,
                phase="test",
                **self.hparams.test_pipeline,
                seed=self.hparams.seed,
            )

    def train_dataloader(self) -> DataLoader[Any]:
        """Create and return the train dataloader.

        :return: The train dataloader.
        """
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            persistent_workers=self.hparams.persistent_workers,
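            # Lightning wraps this loader with a DistributedSampler in
            # multi-device runs (by default), so plain shuffling is safe here.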
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Create and return the validation dataloader.

        :return: The validation dataloader.
        """
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            persistent_workers=self.hparams.persistent_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Create and return the test dataloader.

        :return: The test dataloader.
        """
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            persistent_workers=self.hparams.persistent_workers,
            shuffle=False,
        )

    def teardown(self, stage: Optional[str] = None) -> None:
        """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
        `trainer.test()`, and `trainer.predict()`.

        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
            Defaults to ``None``.
        """
        pass

    def state_dict(self) -> Dict[Any, Any]:
        """Called when saving a checkpoint. Implement to generate and save the datamodule state.

        :return: A dictionary containing the datamodule state that you want to save.
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint. Implement to reload datamodule state given datamodule
        `state_dict()`.

        :param state_dict: The datamodule state returned by `self.state_dict()`.
        """
        pass


if __name__ == "__main__":
    # Minimal smoke test. The pipeline arguments are required, so a bare
    # `HRCWHUDataModule()` would raise a TypeError; the values below are
    # hypothetical placeholders (in practice they come from the project's
    # data configuration).
    _ = HRCWHUDataModule(
        root="data/hrcwhu",
        train_pipeline={},
        val_pipeline={},
        test_pipeline={},
    )
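    # Typical lifecycle when driven by a Lightning `Trainer` (sketch; assumes
    # the HRCWHU files are available under `root`):
    #   datamodule.prepare_data()         # download/extract once per node
    #   datamodule.setup(stage="fit")     # build the train/val datasets
    #   loader = datamodule.train_dataloader()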