"""In-memory grayscale image dataset, training augmentations and DataLoader helpers."""

from types import SimpleNamespace

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2

# Single-channel mean / std used to standardise images once they have been
# converted to float32 and scaled to [0, 1].
DATASET_MEAN = 0.5077385902404785
DATASET_STD = 0.255077600479126


class PreprocessedImageFolder(ImageFolder):
    """``ImageFolder`` that decodes every image once and serves it from memory.

    At construction time each file found under ``root`` is loaded, converted
    to a single-channel ``float32`` tensor scaled to [0, 1] and stored in
    ``self.samples`` as ``(tensor, target)`` pairs (replacing the
    ``(path, target)`` pairs built by ``ImageFolder``).

    On access the optional ``transform`` is applied to that [0, 1] tensor and
    the result is normalised with ``DATASET_MEAN`` / ``DATASET_STD``.
    Normalisation is deliberately the *last* step: colour transforms such as
    ``v2.ColorJitter`` clamp float images to [0, 1], so running them on
    already-normalised data would discard every negative value.

    Args:
        root: Directory laid out as ``root/<class_name>/<image>``.
        transform: Optional callable applied to the cached [0, 1] tensor on
            every ``__getitem__`` call, before normalisation.
    """

    def __init__(self, root, transform=None):
        super().__init__(root, transform=transform)

        # Deterministic stage, executed once per image and cached.
        self.preprocess = v2.Compose(
            [
                v2.Grayscale(),
                v2.PILToTensor(),
                v2.ToDtype(torch.float32, scale=True),
            ]
        )
        # Applied per access, after any augmentation (see class docstring).
        self.normalize = v2.Normalize(mean=(DATASET_MEAN,), std=(DATASET_STD,))

        # The whole dataset is held in memory from here on.
        self.samples = [
            (self.preprocess(self.loader(path)), target)
            for path, target in self.samples
        ]

    def __getitem__(self, index):
        """Return ``(normalised_tensor, target)`` for the sample at ``index``."""
        sample, target = self.samples[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return self.normalize(sample), target


# Training-time augmentations; intended to be passed as ``transform`` to
# ``PreprocessedImageFolder`` so they operate on [0, 1] image tensors.
augmentations = v2.Compose(
    [
        v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        v2.RandomHorizontalFlip(),
        # NOTE(review): ``scale`` is a fraction of the source image area, so
        # the 1.1 upper bound asks for a crop larger than the image - confirm
        # whether (0.9, 1.0) was intended.
        v2.RandomResizedCrop(size=48, scale=(0.9, 1.1), antialias=True),
        v2.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1)),
        v2.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 1.0)),
    ]
)


def make_dls(train_ds, valid_ds, batch_size=64, num_workers=2):
    """Build the training and validation ``DataLoader`` pair.

    Args:
        train_ds: Dataset used for training; batches are reshuffled each epoch.
        valid_ds: Dataset used for validation; iterated in dataset order.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Worker processes per loader; ``0`` loads in the main
            process.

    Returns:
        A ``SimpleNamespace`` with attributes ``train`` and ``valid`` holding
        the respective loaders.
    """
    shared = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": True,
        # DataLoader raises ValueError if persistent_workers is requested
        # while num_workers == 0, so only enable it when workers exist.
        "persistent_workers": num_workers > 0,
    }
    train_dl = DataLoader(train_ds, shuffle=True, **shared)
    # Shuffling brings nothing to evaluation and makes its order
    # non-reproducible, so the validation loader keeps dataset order.
    valid_dl = DataLoader(valid_ds, shuffle=False, **shared)
    return SimpleNamespace(train=train_dl, valid=valid_dl)