from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader

from dataset.data_helper import create_datasets


class DataModule(LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def prepare_data(self):
        """
        Use this method for things that might write to disk or that need to be
        done only from a single process in distributed settings, e.g.:

        - download
        - tokenize
        - etc.
        """

    def setup(self, stage: str):
        """
        There are also data operations you might want to perform on every GPU.
        Use setup to do things like:

        - count number of classes
        - build vocabulary
        - perform train/val/test splits
        - apply transforms (defined explicitly in your datamodule or assigned in init)
        - etc.

        :param stage: one of "fit", "validate", "test" or "predict"
        """
        train_dataset, dev_dataset, test_dataset = create_datasets(self.args)
        self.dataset = {
            "train": train_dataset,
            "validation": dev_dataset,
            "test": test_dataset,
        }

    def train_dataloader(self):
        """
        Use this method to generate the train dataloader. Usually you just wrap
        the dataset you defined in setup.

        :return: DataLoader over the training split
        """
        loader = DataLoader(
            self.dataset["train"],
            batch_size=self.args.batch_size,
            drop_last=True,
            pin_memory=True,
            num_workers=self.args.num_workers,
            prefetch_factor=self.args.prefetch_factor,
        )
        return loader

    def val_dataloader(self):
        """
        Use this method to generate the val dataloader. Usually you just wrap
        the dataset you defined in setup.

        :return: DataLoader over the validation split
        """
        loader = DataLoader(
            self.dataset["validation"],
            batch_size=self.args.val_batch_size,
            drop_last=False,
            pin_memory=True,
            num_workers=self.args.num_workers,
            prefetch_factor=self.args.prefetch_factor,
        )
        return loader

    def test_dataloader(self):
        loader = DataLoader(
            self.dataset["test"],
            batch_size=self.args.test_batch_size,
            drop_last=False,
            pin_memory=False,
            num_workers=self.args.num_workers,
            prefetch_factor=self.args.prefetch_factor,
        )
        return loader