import os

import lightning as L
import torchvision.transforms as T

from torch.utils.data import DataLoader, Subset
from data.dataset import FaceToComicDataset

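class FaceToComicDataModule(L.LightningDataModule):
    """LightningDataModule that serves paired face/comic images for training.

    Both image sets are resized to `image_size` and normalized to [-1, 1].
    If `max_samples` is set, only the first `max_samples` pairs are used.
    """
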
    def __init__(
        self, 
        face_path, 
        comic_path, 
        image_size=(128, 128), 
        batch_size=32, 
        max_samples=None
    ):
        super().__init__()

        self.face_dir = face_path
        self.comic_dir = comic_path
        self.image_size = image_size
        self.batch_size = batch_size
        self.max_samples = max_samples

        self.transform_face = T.Compose([
            T.Resize(self.image_size),
            T.ToTensor(),
            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        self.transform_comic = T.Compose([
            T.Resize(self.image_size),
            T.ToTensor(),
            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        self.face2comic = None

    def prepare_data(self):
        # No need to download or prepare data, as it's already present in the directories
        pass

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            dataset = FaceToComicDataset(
                face_path=self.face_dir,
                comic_path=self.comic_dir,
                transform_face=self.transform_face,
                transform_comic=self.transform_comic
            )
            
            # Optionally cap the dataset at the first `max_samples` items
            if self.max_samples:
                print(f"[INFO] Dataset is limited to {self.max_samples} samples")
                self.face2comic = Subset(dataset, range(min(len(dataset), self.max_samples)))
            else:
                self.face2comic = dataset

    def train_dataloader(self):
        return DataLoader(
            self.face2comic,
            batch_size=self.batch_size,
            # os.cpu_count() can return None; fall back to loading in the main process
            num_workers=os.cpu_count() or 0,
            shuffle=True,
        )

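    # val_dataloader() and test_dataloader() are intentionally not overridden.
    # Without a val_dataloader override, Lightning skips validation during fit;
    # an override that returns None may be rejected, depending on the Lightning version.
    # Define these hooks (each returning a DataLoader) once held-out data is available.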
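

# A minimal sketch, not part of the original module: the Normalize(mean=0.5, std=0.5)
# transforms above map pixel values from [0, 1] to [-1, 1]. The hypothetical helper
# below inverts that mapping, which is usually needed before viewing or saving a batch.
# The name `denormalize` and its defaults are assumptions made for illustration.
def denormalize(x, mean=0.5, std=0.5):
    """Undo Normalize(mean, std): map values from [-1, 1] back to [0, 1].

    Works on plain floats as well as on tensors, since it only uses
    element-wise multiplication and addition with scalars.
    """
    return x * std + mean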