# Data pipeline: eagerly preprocessed grayscale ImageFolder dataset,
# runtime augmentations, and train/valid DataLoader construction.
from types import SimpleNamespace
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2
# Channel statistics for the grayscale dataset, used by v2.Normalize below.
# Presumably computed once over the training images — TODO confirm provenance.
DATASET_MEAN = 0.5077385902404785
DATASET_STD = 0.255077600479126
class PreprocessedImageFolder(ImageFolder):
    """ImageFolder variant that preprocesses the whole dataset up front.

    Every image is grayscaled, converted to a float tensor in [0, 1], and
    normalized with the dataset statistics at construction time; the resulting
    tensors are cached in memory so ``__getitem__`` only applies the optional
    runtime ``transform`` (e.g. augmentations).
    """

    def __init__(self, root, transform=None):
        super().__init__(root, transform=transform)
        # Fixed, deterministic preprocessing pipeline applied exactly once
        # per image below.
        self.preprocess = v2.Compose(
            [
                v2.Grayscale(),
                v2.PILToTensor(),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(mean=(DATASET_MEAN,), std=(DATASET_STD,)),
            ]
        )
        # Replace the (path, target) pairs with (tensor, target) pairs —
        # the entire dataset is loaded and preprocessed eagerly.
        self.samples = [
            (self.preprocess(self.loader(path)), target)
            for path, target in self.samples
        ]

    def __getitem__(self, index):
        # Samples are already preprocessed tensors; only the runtime
        # transform (if any) remains to be applied.
        sample, target = self.samples[index]
        if self.transform is None:
            return sample, target
        return self.transform(sample), target
# Runtime augmentation pipeline, intended to be passed as the dataset's
# `transform` — i.e. it runs on tensors that were already grayscaled and
# normalized by PreprocessedImageFolder.
# NOTE(review): ColorJitter's saturation jitter is a no-op on single-channel
# images, and jittering brightness/contrast AFTER normalization shifts the
# data away from the normalized statistics — confirm this ordering is intended.
augmentations = v2.Compose(
    [
        v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        v2.RandomHorizontalFlip(),
        v2.RandomResizedCrop(size=48, scale=(0.9, 1.1), antialias=True),
        v2.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1)),
        v2.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 1.0)),
    ]
)
def make_dls(train_ds, valid_ds, batch_size=64, num_workers=2):
    """Build training and validation DataLoaders.

    Args:
        train_ds: dataset for training (shuffled every epoch).
        valid_ds: dataset for validation (iterated in order).
        batch_size: batch size for both loaders.
        num_workers: worker processes per loader; 0 loads in the main process.

    Returns:
        SimpleNamespace with ``train`` and ``valid`` DataLoader attributes.
    """
    # persistent_workers=True is only valid when num_workers > 0; guard it
    # so callers may pass num_workers=0 without a ValueError.
    persistent = num_workers > 0
    train_dl = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,  # reshuffle training data each epoch
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=persistent,
    )
    valid_dl = DataLoader(
        valid_ds,
        batch_size=batch_size,
        shuffle=False,  # fix: validation must be deterministic, not shuffled
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=persistent,
    )
    return SimpleNamespace(train=train_dl, valid=valid_dl)