File size: 2,028 Bytes
b066d77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from types import SimpleNamespace

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2

# Single-channel mean and standard deviation handed to v2.Normalize inside
# PreprocessedImageFolder's preprocessing pipeline.
# NOTE(review): presumably measured on the grayscale training images after
# scaling to [0, 1] -- confirm against the script that produced them.
DATASET_MEAN = 0.5077385902404785
DATASET_STD = 0.255077600479126


class PreprocessedImageFolder(ImageFolder):
    """ImageFolder variant that preprocesses every image once, at construction.

    Each file listed by the parent class is loaded, converted to grayscale,
    turned into a float32 tensor (``scale=True``) and normalized with
    ``DATASET_MEAN`` / ``DATASET_STD``. The resulting tensors replace the
    ``(path, target)`` entries in ``self.samples``, so the whole dataset is
    held in memory and ``__getitem__`` never touches the disk.

    Args:
        root: Dataset root directory, forwarded to ``ImageFolder``.
        transform: Optional callable applied to the cached tensor on every
            ``__getitem__`` call (i.e. after preprocessing).
    """

    def __init__(self, root, transform=None):
        super().__init__(root, transform=transform)

        # Deterministic pipeline, run exactly once per image below.
        pipeline_steps = [
            v2.Grayscale(),
            v2.PILToTensor(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=(DATASET_MEAN,), std=(DATASET_STD,)),
        ]
        self.preprocess = v2.Compose(pipeline_steps)

        # Swap the (path, target) pairs for (preprocessed tensor, target) pairs.
        self.samples = [
            (self.preprocess(self.loader(image_path)), label)
            for image_path, label in self.samples
        ]

    def __getitem__(self, index):
        """Return ``(sample, target)`` for ``index``.

        The sample is the cached preprocessed tensor, passed through
        ``self.transform`` when one was supplied.
        """
        cached, label = self.samples[index]
        if self.transform is None:
            return cached, label
        return self.transform(cached), label


# Random augmentation pipeline; the steps run in the order listed.
# NOTE(review): as far as I can tell, ColorJitter's brightness/contrast kernels
# clamp float images to [0, 1]. If this pipeline is used as the `transform` of
# PreprocessedImageFolder (which normalizes before `transform` runs), values
# outside that range would be clipped -- confirm the intended ordering.
augmentations = v2.Compose([
    # Photometric jitter.
    v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    # Random left/right mirror (library-default probability).
    v2.RandomHorizontalFlip(),
    # Random crop resized to a 48-pixel output, antialiased.
    v2.RandomResizedCrop(size=48, scale=(0.9, 1.1), antialias=True),
    # Small random rotation, translation and zoom.
    v2.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1)),
    # Blur with a 3x3 kernel and a sigma drawn from the given range.
    v2.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 1.0)),
])


def make_dls(train_ds, valid_ds, batch_size=64, num_workers=2):
    """Build the training and validation ``DataLoader`` objects.

    Args:
        train_ds: Dataset used for training; iterated in shuffled order.
        valid_ds: Dataset used for validation; iterated in its stored order so
            evaluation runs are reproducible.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker processes per loader. ``0`` loads data
            in the calling process.

    Returns:
        A ``SimpleNamespace`` with attributes ``train`` and ``valid`` holding
        the two loaders.
    """
    shared_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": True,
        # DataLoader raises ValueError when persistent_workers=True is combined
        # with num_workers=0, so only keep workers alive when there are any.
        "persistent_workers": num_workers > 0,
    }

    train_dl = DataLoader(train_ds, shuffle=True, **shared_kwargs)
    # Validation gains nothing from shuffling; a fixed order keeps evaluation
    # deterministic and per-sample outputs comparable between runs.
    valid_dl = DataLoader(valid_ds, shuffle=False, **shared_kwargs)

    return SimpleNamespace(train=train_dl, valid=valid_dl)