File size: 2,376 Bytes
6551065
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader
from dataset.data_helper import create_datasets



class DataModule(LightningDataModule):
    """Lightning data module wrapping the train/validation/test datasets
    produced by ``create_datasets`` and exposing one ``DataLoader`` per split.

    All loader hyper-parameters (batch sizes, worker count, prefetch factor)
    are read from the ``args`` namespace supplied at construction time.
    """

    def __init__(self, args):
        super().__init__()
        # Expected attributes on args: batch_size, val_batch_size,
        # test_batch_size, num_workers, prefetch_factor.
        self.args = args

    def prepare_data(self):
        """Run once, on a single process, in distributed settings.

        Reserved for disk-writing work (downloading, tokenizing, ...).
        Intentionally empty: dataset construction happens in :meth:`setup`.
        """

    def setup(self, stage: str):
        """Build the datasets for every split.

        Called on every process/GPU. ``stage`` ("fit"/"validate"/"test")
        is deliberately ignored: all three splits are created
        unconditionally, which is cheap enough here.

        :param stage: Lightning stage identifier (unused).
        """
        train_dataset, dev_dataset, test_dataset = create_datasets(self.args)
        self.dataset = {
            "train": train_dataset,
            "validation": dev_dataset,
            "test": test_dataset,
        }

    def train_dataloader(self):
        """Return the training DataLoader.

        ``shuffle=True`` so the sample order differs every epoch — the
        original omitted it, serving the training data in a fixed order,
        which degrades SGD convergence. ``drop_last=True`` keeps every
        batch a full ``batch_size``.
        """
        # NOTE(review): prefetch_factor requires num_workers > 0 on some
        # torch versions — confirm args.num_workers is never 0.
        return DataLoader(
            self.dataset["train"],
            batch_size=self.args.batch_size,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
            num_workers=self.args.num_workers,
            prefetch_factor=self.args.prefetch_factor,
        )

    def val_dataloader(self):
        """Return the validation DataLoader (fixed order; keeps the final
        partial batch so every example is evaluated)."""
        return DataLoader(
            self.dataset["validation"],
            batch_size=self.args.val_batch_size,
            drop_last=False,
            pin_memory=True,
            num_workers=self.args.num_workers,
            prefetch_factor=self.args.prefetch_factor,
        )

    def test_dataloader(self):
        """Return the test DataLoader (fixed order; ``pin_memory=False``
        as in the original — presumably test runs are not throughput
        critical; confirm before changing)."""
        return DataLoader(
            self.dataset["test"],
            batch_size=self.args.test_batch_size,
            drop_last=False,
            pin_memory=False,
            num_workers=self.args.num_workers,
            prefetch_factor=self.args.prefetch_factor,
        )