import os

from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def get_data_loaders(
    data_dir,
    batch_size=32,
    *,
    image_size=128,
    num_workers=0,
    pin_memory=False,
):
    """Build training and validation DataLoaders from an ImageFolder layout.

    ``data_dir`` must contain ``training`` and ``validation`` subdirectories,
    each laid out as ``<class_name>/<image files>`` as required by
    :class:`torchvision.datasets.ImageFolder`.

    Args:
        data_dir: Root directory (str or path-like) holding the ``training``
            and ``validation`` subdirectories.
        batch_size: Number of samples per batch for both loaders.
        image_size: Side length in pixels of the square images produced by
            both the training and validation pipelines.
        num_workers: Worker processes per DataLoader (0, the DataLoader
            default, loads data in the main process).
        pin_memory: Forwarded unchanged to both DataLoaders.

    Returns:
        A tuple ``(train_loader, val_loader, classes)``. The training loader
        shuffles; the validation loader does not. ``classes`` is the
        training dataset's ``ImageFolder.classes`` list.

    Raises:
        ValueError: If the training and validation directories do not yield
            the same class-to-index mapping.
    """
    # Per-channel mean/std of 0.5 maps ToTensor's [0, 1] range to [-1, 1].
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]

    # Training: random crop/flip/rotation/colour jitter for augmentation.
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Validation: deterministic resize + normalize only, no augmentation.
    transform_val = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_dir = os.path.join(data_dir, "training")
    val_dir = os.path.join(data_dir, "validation")

    train_dataset = datasets.ImageFolder(train_dir, transform=transform_train)
    val_dataset = datasets.ImageFolder(val_dir, transform=transform_val)

    # ImageFolder derives label indices independently for each root from its
    # own sorted subdirectory names; if the splits differ, the same integer
    # label would mean different classes in training and validation.
    if val_dataset.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            f"class mismatch between {train_dir!r} "
            f"({sorted(train_dataset.class_to_idx)}) and {val_dir!r} "
            f"({sorted(val_dataset.class_to_idx)})"
        )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, val_loader, train_dataset.classes