# File size: 917 Bytes
# 1fff313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from torchvision import datasets, transforms

class MNISTDataModule:
    """Build MNIST training and test dataloaders with shared preprocessing.

    Args:
        batch_size: Number of samples per batch in the training loader.
        val_batch_size: Number of samples per batch in the test loader.
        root: Directory passed to ``datasets.MNIST`` as the dataset
            location. Defaults to ``'./data'``.
        download: Forwarded to ``datasets.MNIST``; set to ``False`` when the
            files are already present under ``root`` (e.g. offline runs).
    """

    def __init__(self, batch_size=64, val_batch_size=1000, *,
                 root='./data', download=True):
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.root = root
        self.download = download

    def get_dataloaders(self):
        """Create training and test dataloaders.

        Returns:
            A ``(train_loader, test_loader)`` tuple. The training loader
            shuffles and uses ``batch_size``; the test loader keeps the
            dataset order and uses ``val_batch_size``.
        """
        # Both splits share one pipeline: convert to a tensor, then
        # normalize the single channel with mean 0.5 and std 0.5.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

        train_dataset = datasets.MNIST(
            root=self.root, train=True, transform=transform,
            download=self.download,
        )
        test_dataset = datasets.MNIST(
            root=self.root, train=False, transform=transform,
            download=self.download,
        )

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True,
        )
        # Evaluation does not benefit from shuffling, so order is preserved.
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=self.val_batch_size, shuffle=False,
        )

        return train_loader, test_loader