File size: 2,296 Bytes
9457143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dabac1b
 
9457143
 
 
 
 
 
 
 
 
dabac1b
9457143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import logging
import time
from functools import partial

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule serving MNIST train/val/test splits.

    The MNIST training set is randomly split (reproducibly, via ``seed``) into
    a training subset and a validation subset; samples left over after both
    are taken are discarded. The MNIST test set is served unchanged. Images
    are resized to ``img_dim x img_dim`` and scaled to ``[-1, 1]``.

    Args:
        data_dir: Directory MNIST is downloaded to / loaded from.
        batch_size: Samples per batch for every dataloader. It is read each
            time a dataloader is created, so it may be changed after
            construction.
        num_workers: DataLoader worker processes (0 = load in main process).
        seed: Seed for the train/val random split.
        train_ratio: Fraction of the training set used for training; clamped
            to at most ``1 - val_ratio``.
        img_dim: Side length images are resized to.
        val_ratio: Fraction of the training set held out for validation.
        max_retries: Attempts made to download/load MNIST before re-raising.
        retry_delay: Seconds to sleep between failed attempts.

    Raises:
        ValueError: If ``train_ratio``, ``val_ratio`` or ``max_retries`` is
            out of range.
    """

    def __init__(
        self,
        data_dir: str = "./",
        batch_size: int = 32,
        num_workers: int = 0,
        seed: int = 42,
        train_ratio: float = 0.99,
        img_dim: int = 32,
        val_ratio: float = 0.01,
        max_retries: int = 5,
        retry_delay: float = 1.0,
    ):
        super().__init__()
        if not 0.0 <= val_ratio < 1.0:
            raise ValueError(f"val_ratio must be in [0, 1), got {val_ratio}")
        if train_ratio <= 0.0:
            raise ValueError(f"train_ratio must be positive, got {train_ratio}")
        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.val_ratio = val_ratio
        # Train and val subsets must fit inside the full training set together.
        self.train_ratio = min(train_ratio, 1.0 - val_ratio)
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.transform = transforms.Compose(
            [
                transforms.Resize((img_dim, img_dim)),
                transforms.ToTensor(),
                # Single-channel stats: maps ToTensor's [0, 1] range to [-1, 1].
                transforms.Normalize(mean=(0.5,), std=(0.5,)),
            ]
        )

    @property
    def loader(self):
        """Return a ``DataLoader`` factory bound to the *current* settings.

        Built on each access, so changes to ``batch_size`` / ``num_workers``
        after construction are respected.
        """
        return partial(
            DataLoader,
            batch_size=self.batch_size,
            pin_memory=True,
            num_workers=self.num_workers,
            # DataLoader raises ValueError for persistent_workers=True when
            # num_workers == 0, so only enable it when workers exist.
            persistent_workers=self.num_workers > 0,
        )

    def _load_mnist(self, train: bool) -> MNIST:
        """Build one MNIST split, retrying transient download/read failures.

        Args:
            train: ``True`` for the training split, ``False`` for the test split.

        Returns:
            The ``MNIST`` dataset with ``self.transform`` attached.

        Raises:
            OSError, RuntimeError, EOFError: The last failure, re-raised once
                ``max_retries`` attempts have been exhausted.
        """
        for attempt in range(1, self.max_retries + 1):
            try:
                return MNIST(
                    root=self.data_dir,
                    train=train,
                    transform=self.transform,
                    download=True,
                )
            except (OSError, RuntimeError, EOFError) as err:
                if attempt == self.max_retries:
                    raise
                logging.getLogger(__name__).warning(
                    "Loading MNIST (train=%s) failed on attempt %d/%d: %s",
                    train, attempt, self.max_retries, err,
                )
                time.sleep(self.retry_delay)
        # Unreachable: max_retries >= 1 is enforced in __init__.
        raise AssertionError("unreachable")

    def prepare_data(self):
        """Download both MNIST splits ahead of ``setup``.

        Downloading here rather than only in ``setup`` avoids several
        processes downloading into ``data_dir`` at the same time.
        """
        self._load_mnist(train=True)
        self._load_mnist(train=False)

    def setup(self, stage: str):
        """Create the datasets needed for ``stage``.

        ``"fit"`` and ``"validate"`` build the seeded train/val subsets from
        the MNIST training split; any other stage loads the MNIST test split.
        """
        if stage in ("fit", "validate"):
            mnist_full = self._load_mnist(train=True)
            n_total = len(mnist_full)
            # Integer lengths sidestep random_split's float-sum check; the
            # leftover third subset is discarded.
            n_train = int(n_total * self.train_ratio)
            n_val = int(n_total * self.val_ratio)
            self.mnist_train, self.mnist_val, _ = random_split(
                dataset=mnist_full,
                lengths=[n_train, n_val, n_total - n_train - n_val],
                generator=torch.Generator().manual_seed(self.seed),
            )
        else:
            self.mnist_test = self._load_mnist(train=False)

    def train_dataloader(self):
        """Return the training loader; reshuffled every epoch."""
        return self.loader(dataset=self.mnist_train, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return self.loader(dataset=self.mnist_val)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return self.loader(dataset=self.mnist_test)