Spaces:
Sleeping
Sleeping
File size: 1,214 Bytes
b1205cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
def get_data_loaders(data_dir, batch_size: int = 32, *,
                     image_size: int = 128, num_workers: int = 0):
    """Build the training and validation ``DataLoader``s for an image dataset.

    Images are read with ``torchvision.datasets.ImageFolder`` from two
    sub-directories of *data_dir*: ``training`` and ``validation``.

    Args:
        data_dir: Root directory containing the ``training`` and
            ``validation`` sub-directories.
        batch_size: Number of samples per batch for both loaders.
        image_size: Side length, in pixels, of the square images produced by
            both pipelines. Defaults to 128.
        num_workers: Number of worker processes passed to each ``DataLoader``.
            Defaults to 0, which is also the ``DataLoader`` default.

    Returns:
        A ``(train_loader, val_loader, classes)`` tuple, where ``classes`` is
        the class-name list of the *training* dataset.

    Warns:
        UserWarning: If the validation dataset's class list differs from the
            training dataset's. In that case the validation label indices do
            not line up with the returned ``classes``.
    """
    import warnings

    # Both pipelines use the same per-channel normalization; define it once
    # so the two cannot drift apart.
    channel_mean = [0.5, 0.5, 0.5]
    channel_std = [0.5, 0.5, 0.5]
    normalize = transforms.Normalize(channel_mean, channel_std)

    # Data augmentation + normalization for training
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    # Only resize + normalize for validation (no random augmentation).
    # An explicit (h, w) pair is passed so the output is always square.
    transform_val = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        normalize,
    ])

    train_dir = f"{data_dir}/training"
    val_dir = f"{data_dir}/validation"

    train_dataset = datasets.ImageFolder(train_dir, transform=transform_train)
    val_dataset = datasets.ImageFolder(val_dir, transform=transform_val)

    # Each ImageFolder derives its label indices from its own sub-folders.
    # Only the training class list is returned, so a mismatch would make the
    # validation labels refer to different classes without any error.
    if val_dataset.classes != train_dataset.classes:
        warnings.warn(
            f"Class folders differ between {train_dir!r} and {val_dir!r}; "
            "validation label indices may not match the returned class names.",
            stacklevel=2,
        )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    return train_loader, val_loader, train_dataset.classes
|