from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader

from dataset.data_helper import create_datasets


class DataModule(LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def prepare_data(self):
        """
        Use this method for work that writes to disk or that must run in a
        single process in distributed settings: download, tokenize, etc.
        """

    def setup(self, stage: str):
        """
        Data operations you might want to perform on every GPU. Use setup to
        do things like: count the number of classes, build a vocabulary,
        perform train/val/test splits, or apply transforms (defined explicitly
        in your datamodule or assigned in __init__).

        :param stage: current stage ("fit", "validate", "test", or "predict")
        """
        train_dataset, dev_dataset, test_dataset = create_datasets(self.args)
        self.dataset = {
            "train": train_dataset,
            "validation": dev_dataset,
            "test": test_dataset,
        }

    def train_dataloader(self):
        """
        Build the train dataloader, usually by wrapping the dataset defined in setup.
        """
        return DataLoader(
            self.dataset["train"],
            batch_size=self.args.batch_size,
            drop_last=True,
            pin_memory=True,
            num_workers=self.args.num_workers,
            prefetch_factor=self.args.prefetch_factor,
        )

    def val_dataloader(self):
        """
        Build the validation dataloader, usually by wrapping the dataset defined in setup.
        """
        return DataLoader(
            self.dataset["validation"],
            batch_size=self.args.val_batch_size,
            drop_last=False,
            pin_memory=True,
            num_workers=self.args.num_workers,
            prefetch_factor=self.args.prefetch_factor,
        )

    def test_dataloader(self):
        """
        Build the test dataloader, usually by wrapping the dataset defined in setup.
        """
        return DataLoader(
            self.dataset["test"],
            batch_size=self.args.test_batch_size,
            drop_last=False,
            pin_memory=False,
            num_workers=self.args.num_workers,
            prefetch_factor=self.args.prefetch_factor,
        )
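
# A minimal usage sketch, not part of the module above. It assumes an
# argparse-style `args` namespace exposing the attributes this DataModule
# reads (batch_size, val_batch_size, test_batch_size, num_workers,
# prefetch_factor) plus whatever create_datasets expects, and a hypothetical
# LightningModule named `MyModel` used purely for illustration:
#
#     from argparse import Namespace
#     from lightning.pytorch import Trainer
#
#     args = Namespace(batch_size=32, val_batch_size=64, test_batch_size=64,
#                      num_workers=4, prefetch_factor=2)
#     datamodule = DataModule(args)
#     trainer = Trainer(max_epochs=10)
#     trainer.fit(MyModel(), datamodule=datamodule)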