File size: 1,549 Bytes
e61c431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import lightning as L
import torchvision.transforms as T
import os

from config.core import config
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from utility.helper import PadToSquare

class ShoeSandalBoot(L.LightningDataModule):
    """Lightning data module that serves an ``ImageFolder`` dataset for training.

    Images are read from ``dataset_directory`` with
    ``torchvision.datasets.ImageFolder``, passed through ``PadToSquare``,
    converted to tensors and normalised per channel with mean 0.5 / std 0.5.
    Only a training dataloader is provided.

    Args:
        dataset_directory: Root directory handed to ``ImageFolder``.
        image_size: Value forwarded to ``PadToSquare``.
        batch_size: Batch size used by ``train_dataloader``.
        max_samples: If truthy, only the first ``max_samples`` items of the
            dataset are used; ``None`` (or ``0``) keeps the whole dataset.
    """

    def __init__(
        self,
        dataset_directory,
        image_size=config.image_size,
        # NOTE(review): the default batch size is read from
        # ``config.image_size`` -- this looks like a copy/paste slip, but no
        # ``config.batch_size`` is visible from here, so it is left unchanged.
        # TODO confirm against config.core.
        batch_size=config.image_size,
        max_samples=None
    ):
        super().__init__()

        self.data_dir = dataset_directory
        self.bs = batch_size
        self.max_samples = max_samples  # optional cap on the dataset size

        self.transforms = T.Compose([
            # PadToSquare is used in place of a plain
            # T.Resize(size=(image_size, image_size)); presumably it preserves
            # the aspect ratio -- confirm in utility.helper.
            PadToSquare(image_size),
            T.ToTensor(),
            # Per-channel (x - 0.5) / 0.5 on three channels.
            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        # Populated by ``setup``; stays ``None`` until then.
        self.ssb = None

    def prepare_data(self):
        """No-op: nothing is downloaded or pre-processed by this module."""
        pass

    def setup(self, stage=None):
        """Build the training dataset.

        Args:
            stage: Stage name supplied by the trainer. The dataset is built
                for ``"fit"`` and also for ``None``, so a manual
                ``dm.setup()`` call works; other stages are ignored.
        """
        if stage not in (None, "fit"):
            return

        dataset = ImageFolder(
            root=self.data_dir,
            transform=self.transforms
        )

        # To Limit Dataset
        if self.max_samples:
            print(f"[INFO] Dataset is Limited to {self.max_samples} Samples")
            # Never request more indices than the dataset actually holds.
            limit = min(len(dataset), self.max_samples)
            self.ssb = torch.utils.data.Subset(dataset, range(limit))
        else:
            self.ssb = dataset

    def train_dataloader(self):
        """Return a shuffled ``DataLoader`` over the dataset built in ``setup``.

        Raises:
            RuntimeError: If ``setup`` has not populated the dataset yet.
        """
        if self.ssb is None:
            raise RuntimeError(
                "Dataset is not initialised; call setup('fit') before "
                "train_dataloader()."
            )

        # os.cpu_count() may return None when the count cannot be determined;
        # DataLoader requires an int, so fall back to loading in the main
        # process (num_workers=0).
        worker_count = os.cpu_count() or 0
        return DataLoader(
            self.ssb,
            batch_size=self.bs,
            num_workers=worker_count,
            shuffle=True
        )