File size: 917 Bytes
1fff313 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import torch
from torchvision import datasets, transforms
class MNISTDataModule:
    """Build MNIST train/test ``DataLoader``s that share one preprocessing pipeline.

    Args:
        batch_size: Mini-batch size for the (shuffled) training loader.
        val_batch_size: Mini-batch size for the (unshuffled) test loader.
        data_dir: Directory passed to torchvision as the dataset ``root``.
            Defaults to ``'./data'``, the value that was previously hard-coded.
        download: Forwarded to ``datasets.MNIST``; when true, torchvision
            fetches the files if they are not already under ``data_dir``.
        num_workers: Worker subprocesses per ``DataLoader``. ``0`` (the
            ``DataLoader`` default) loads batches in the main process.
    """

    def __init__(self, batch_size=64, val_batch_size=1000, data_dir='./data',
                 download=True, num_workers=0):
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.data_dir = data_dir
        self.download = download
        self.num_workers = num_workers

    def get_dataloaders(self):
        """Create training and test dataloaders.

        Returns:
            A ``(train_loader, test_loader)`` tuple. The training loader uses
            ``shuffle=True``; the test loader uses ``shuffle=False`` so the
            evaluation order is fixed.
        """
        # ToTensor yields floats in [0, 1]; Normalize with mean=0.5, std=0.5
        # rescales them to [-1, 1].
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = datasets.MNIST(
            root=self.data_dir, train=True, transform=transform,
            download=self.download,
        )
        test_dataset = datasets.MNIST(
            root=self.data_dir, train=False, transform=transform,
            download=self.download,
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True,
            num_workers=self.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=self.val_batch_size, shuffle=False,
            num_workers=self.num_workers,
        )
        return train_loader, test_loader