import torch from torch.utils.data import DataLoader from torchvision import datasets def create_datasets(train_dir, test_dir, data_transform): train_data = datasets.ImageFolder(root=train_dir, transform=data_transform, target_transform=None) test_data = datasets.ImageFolder(root=test_dir, transform=data_transform) return train_data, test_data def create_dataloaders(train_dataset, test_dataset, batch_size, num_workers): train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True) test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) return train_dataloader, test_dataloader def data_setup(train_dir, test_dir, data_transform, batch_size, num_workers): train_dataset, test_dataset = create_datasets(train_dir=train_dir, test_dir=test_dir, data_transform=data_transform) class_names = train_dataset.classes train_dataloader, test_dataloader = create_dataloaders(train_dataset=train_dataset, test_dataset=test_dataset, batch_size=batch_size, num_workers=num_workers) return train_dataloader, test_dataloader, class_names