""" Contains functionality for creating Pytorch DataLoaders for image classification data. """ import os from torchvision import datasets, transforms from torch.utils.data import DataLoader NUM_WORKERS = os.cpu_count() def create_dataloaders( train_dir: str, test_dir: str, transform: transforms.Compose, batch_size: int, num_workers: int=NUM_WORKERS ): """ Creates training and testing DataLoaders """ # Use ImageFolder to create Datasets train_data = datasets.ImageFolder(train_dir, transform=transform) test_data = datasets.ImageFolder(test_dir, transform=transform) # Get class names class_names = train_data.classes # Turn images into data loaders train_dataloader = DataLoader( train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True ) test_dataloader = DataLoader( test_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True ) return train_dataloader, test_dataloader, class_names