|
import torch |
|
from torchvision import datasets, transforms |
|
|
|
class MNISTDataModule:
    """Build PyTorch DataLoaders for the MNIST train and test splits.

    Args:
        batch_size: Batch size for the training DataLoader.
        val_batch_size: Batch size for the test/validation DataLoader.
        data_dir: Directory where torchvision stores (and, if needed,
            downloads) the MNIST files. Defaults to ``'./data'``, the path
            previously hard-coded in ``get_dataloaders``.
        num_workers: Number of DataLoader worker processes for both loaders.
            ``0`` (the DataLoader default) loads data in the main process.
    """

    def __init__(self, batch_size=64, val_batch_size=1000, data_dir='./data',
                 num_workers=0):
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers

    def get_dataloaders(self):
        """Create training and test dataloaders.

        Both splits are constructed with ``download=True``, so torchvision
        fetches them into ``self.data_dir`` if they are not already there.

        Returns:
            A ``(train_loader, test_loader)`` tuple. The training loader
            shuffles each epoch. The test loader iterates in fixed order.
        """
        transform = transforms.Compose([
            transforms.ToTensor(),
            # ToTensor yields values in [0, 1]. (x - 0.5) / 0.5 rescales them
            # to [-1, 1]. These are not the MNIST dataset mean/std.
            transforms.Normalize((0.5,), (0.5,)),
        ])

        train_dataset = datasets.MNIST(
            root=self.data_dir, train=True, transform=transform, download=True
        )
        test_dataset = datasets.MNIST(
            root=self.data_dir, train=False, transform=transform, download=True
        )

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

        return train_loader, test_loader