harry
feat: baseline model
1fff313
raw
history blame contribute delete
917 Bytes
import torch
from torchvision import datasets, transforms
class MNISTDataModule:
    """Build MNIST train/test DataLoaders with a fixed preprocessing pipeline.

    Each image is converted with ``transforms.ToTensor()`` and then normalized
    with mean 0.5 / std 0.5, i.e. ``(x - 0.5) / 0.5``, which rescales
    ToTensor's [0, 1] pixel range to [-1, 1].
    """

    def __init__(
        self,
        batch_size=64,
        val_batch_size=1000,
        *,
        data_dir="./data",
        num_workers=0,
    ):
        """Store loader configuration; no data is read until ``get_dataloaders``.

        Args:
            batch_size: Samples per batch for the (shuffled) training loader.
            val_batch_size: Samples per batch for the (unshuffled) test loader.
            data_dir: Directory that MNIST is downloaded to and read from.
                Defaults to ``"./data"``.
            num_workers: Worker subprocesses used by each DataLoader. ``0``
                (DataLoader's own default) loads data in the main process.
        """
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers

    def get_dataloaders(self):
        """Create training and test dataloaders.

        Both datasets are requested with ``download=True``, so torchvision
        fetches the MNIST files into ``self.data_dir`` if they are missing.

        Returns:
            tuple: ``(train_loader, test_loader)``. The training loader
            reshuffles every epoch; the test loader keeps a fixed order so
            evaluation is reproducible.
        """
        transform = transforms.Compose([
            transforms.ToTensor(),
            # Single-channel mean/std of 0.5: maps [0, 1] pixels to [-1, 1].
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = datasets.MNIST(
            root=self.data_dir, train=True, transform=transform, download=True
        )
        test_dataset = datasets.MNIST(
            root=self.data_dir, train=False, transform=transform, download=True
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
        return train_loader, test_loader