File size: 2,183 Bytes
9457143
 
 
 
 
 
 
 
 
 
 
 
 
dabac1b
9457143
 
 
 
 
dabac1b
9457143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dabac1b
 
9457143
 
 
 
 
 
dabac1b
9457143
 
 
 
 
 
 
 
 
 
 
 
dabac1b
9457143
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import pytorch_lightning as pl
import torch
import os
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from functools import partial


class CelebADataset(Dataset):
    """Flat-directory image dataset used for CelebA.

    Every regular file directly inside ``data_dir`` is one sample (no
    sub-directories are traversed). Each sample is loaded with PIL, converted
    to 3-channel RGB, resized to ``(img_dim, img_dim)``, turned into a tensor
    with ``ToTensor`` and normalized with mean/std 0.5 per channel (for 8-bit
    images this maps ``[0, 1]`` to ``[-1, 1]``).

    Args:
        data_dir: Directory containing the image files.
        img_dim: Side length, in pixels, of the square output images.
    """

    def __init__(
            self,
            data_dir: str,
            img_dim: int = 64
    ):
        # Sort so the index -> file mapping (and therefore any seeded
        # random_split over this dataset) does not depend on the arbitrary
        # order os.listdir returns. Sub-directories are skipped because
        # PIL cannot open them.
        self.list_path = sorted(
            name for name in os.listdir(data_dir)
            if os.path.isfile(os.path.join(data_dir, name))
        )
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [
                transforms.Resize((img_dim, img_dim)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]
        )

    def __len__(self):
        """Return the number of image files found in ``data_dir``."""
        return len(self.list_path)

    def __getitem__(self, index):
        """Load the image at position ``index`` and return the transformed tensor."""
        path = os.path.join(self.data_dir, self.list_path[index])
        # The context manager closes the file handle promptly. convert("RGB")
        # guarantees three channels, so the 3-channel Normalize never fails on
        # grayscale, RGBA or palette images.
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))


class CelebADataModule(pl.LightningDataModule):
    """LightningDataModule that splits a CelebA image folder into train/val.

    Args:
        data_dir: Directory of image files handed to ``CelebADataset``.
        batch_size: Batch size for both dataloaders.
        num_workers: Number of DataLoader worker processes; ``0`` loads in
            the main process.
        seed: Seed for the generator that drives ``random_split``, so the
            split is repeatable for the same file listing.
        train_ratio: Fraction of samples used for training; clamped to
            ``1 - val_ratio`` so train and val always fit together.
        img_dim: Side length of the square images produced by the dataset.
        val_ratio: Fraction of samples used for validation (defaults to the
            previously hard-coded 0.01). Whatever remains after train and val
            is discarded.

    Raises:
        ValueError: If ``val_ratio`` is not strictly between 0 and 1.
    """

    def __init__(
        self,
        data_dir: str = "./",
        batch_size: int = 32,
        num_workers: int = 0,
        seed: int = 42,
        train_ratio: float = 0.99,
        img_dim: int = 64,
        val_ratio: float = 0.01,
    ):
        super().__init__()
        if not 0.0 < val_ratio < 1.0:
            raise ValueError(f"val_ratio must be in (0, 1), got {val_ratio}")
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_ratio = val_ratio
        self.train_ratio = min(train_ratio, 1.0 - val_ratio)
        self.img_dim = img_dim
        self.seed = seed

        self.loader = partial(
            DataLoader,
            batch_size=self.batch_size,
            pin_memory=True,
            num_workers=self.num_workers,
            # DataLoader raises ValueError for persistent_workers=True when
            # num_workers == 0 (this class's default), so enable it only
            # when worker processes actually exist.
            persistent_workers=self.num_workers > 0,
        )

    def setup(self, stage: str):
        """Build the train/val subsets for the ``fit`` and ``validate`` stages.

        ``validate`` is included so a standalone ``trainer.validate()`` run
        has ``CelebA_val`` available. Other stages have no dataloaders in
        this module and are ignored.
        """
        if stage not in ("fit", "validate"):
            return
        dataset = CelebADataset(self.data_dir, self.img_dim)
        # The third, discarded subset makes the fractional lengths sum to 1,
        # as random_split requires. Keeping three subsets (rather than two)
        # preserves the exact split produced for a given seed. max() guards
        # against a tiny negative value caused by float rounding.
        rest = max(0.0, 1.0 - self.val_ratio - self.train_ratio)
        self.CelebA_train, self.CelebA_val, _ = random_split(
            dataset=dataset,
            lengths=[self.train_ratio, self.val_ratio, rest],
            generator=torch.Generator().manual_seed(self.seed),
        )

    def train_dataloader(self):
        """Return the training loader, reshuffled every epoch."""
        return self.loader(dataset=self.CelebA_train, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader in fixed order."""
        return self.loader(dataset=self.CelebA_val)