from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader

from dataset.data_helper import create_datasets


class DataModule(LightningDataModule):

    def __init__(self, args):
        super().__init__()
        self.args = args

    def prepare_data(self):
        """
        Use this method for work that may write to disk or that must be done from a
        single process in distributed settings, e.g. downloading or tokenizing data.
        """
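        # Illustrative sketch only (not part of the original code): a single-process
        # download step could look like the lines below. The attribute names
        # `data_url` and `raw_data_path` are hypothetical, not existing args.
        #
        #     import urllib.request
        #     urllib.request.urlretrieve(self.args.data_url, self.args.raw_data_path)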

    def setup(self, stage: str):
        """
        Data operations you might want to perform on every GPU. Use setup for things like:

        - counting the number of classes
        - building a vocabulary
        - performing train/val/test splits
        - applying transforms (defined explicitly in the datamodule or assigned in __init__)

        :param stage: one of "fit", "validate", "test", or "predict"
        """
        train_dataset, dev_dataset, test_dataset = create_datasets(self.args)
        self.dataset = {
            "train": train_dataset,
            "validation": dev_dataset,
            "test": test_dataset,
        }
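        # Optional refinement (a sketch, not in the original code): for expensive
        # datasets, construction could be gated on `stage`, e.g.
        #
        #     if stage == "fit":
        #         ...  # build only the train/validation splits
        #     elif stage == "test":
        #         ...  # build only the test split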

    def train_dataloader(self):
        """
        Generate the train dataloader. Usually you just wrap the dataset defined in setup().

        :return: DataLoader over the training split
        """
        # prefetch_factor can only be set when num_workers > 0.
        loader = DataLoader(self.dataset["train"], batch_size=self.args.batch_size, drop_last=True, pin_memory=True,
                            num_workers=self.args.num_workers, prefetch_factor=self.args.prefetch_factor)
        return loader

    def val_dataloader(self):
        """
        Generate the val dataloader. Usually you just wrap the dataset defined in setup().

        :return: DataLoader over the validation split
        """
        loader = DataLoader(self.dataset["validation"], batch_size=self.args.val_batch_size, drop_last=False,
                            pin_memory=True, num_workers=self.args.num_workers,
                            prefetch_factor=self.args.prefetch_factor)
        return loader

    def test_dataloader(self):
        """
        Generate the test dataloader. Usually you just wrap the dataset defined in setup().

        :return: DataLoader over the test split
        """
        loader = DataLoader(self.dataset["test"], batch_size=self.args.test_batch_size, drop_last=False,
                            pin_memory=False, num_workers=self.args.num_workers,
                            prefetch_factor=self.args.prefetch_factor)
        return loader
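

# Usage sketch (illustrative only, not part of the original module). It assumes
# `args` carries the attributes read above (batch_size, val_batch_size,
# test_batch_size, num_workers, prefetch_factor) plus whatever
# dataset.data_helper.create_datasets expects, and that `MyLightningModule`
# stands in for the project's actual LightningModule:
#
#     from argparse import Namespace
#     from lightning.pytorch import Trainer
#
#     args = Namespace(batch_size=32, val_batch_size=64, test_batch_size=64,
#                      num_workers=4, prefetch_factor=2)
#     datamodule = DataModule(args)
#     trainer = Trainer(max_epochs=10)
#     trainer.fit(MyLightningModule(), datamodule=datamodule)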