# Source page header: "Spaces:" (status at capture time: Runtime error).
import logging
import math
import time
from functools import partial

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
_log = logging.getLogger(__name__)


class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule serving MNIST resized to ``img_dim`` x ``img_dim``.

    The MNIST training split is divided with a seeded generator into a
    training subset (``train_ratio``), a validation subset (``val_ratio``) and
    an unused remainder. The MNIST test split is served as-is. Every image is
    resized, converted to a tensor and normalized with mean 0.5 / std 0.5.

    Args:
        data_dir: Root directory MNIST is downloaded to / loaded from.
        batch_size: Batch size of every DataLoader. Read each time a loader is
            built, so changing it after construction takes effect.
        num_workers: DataLoader worker processes. Workers are kept persistent
            only when this is > 0.
        seed: Seed of the generator used for the train/val split.
        train_ratio: Fraction of the training split used for training; clamped
            to at most ``1 - val_ratio``.
        img_dim: Side length images are resized to.
        val_ratio: Fraction of the training split used for validation.
        download_attempts: How many times to try loading/downloading a split
            before re-raising the last error.

    Raises:
        ValueError: If ``train_ratio`` is negative, ``val_ratio`` is outside
            [0, 1], or ``download_attempts`` is less than 1.
    """

    def __init__(
        self,
        data_dir: str = "./",
        batch_size: int = 32,
        num_workers: int = 0,
        seed: int = 42,
        train_ratio: float = 0.99,
        img_dim: int = 32,
        val_ratio: float = 0.01,
        download_attempts: int = 5,
    ):
        super().__init__()
        if not 0.0 <= val_ratio <= 1.0:
            raise ValueError(f"val_ratio must be in [0, 1], got {val_ratio}")
        if train_ratio < 0.0:
            raise ValueError(f"train_ratio must be >= 0, got {train_ratio}")
        if download_attempts < 1:
            raise ValueError(
                f"download_attempts must be >= 1, got {download_attempts}"
            )
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.val_ratio = val_ratio
        # Leave room for the validation subset; any leftover data is unused.
        self.train_ratio = min(train_ratio, 1 - val_ratio)
        self.download_attempts = download_attempts
        self.transform = transforms.Compose(
            [
                transforms.Resize((img_dim, img_dim)),
                transforms.ToTensor(),
                # Per-channel sequences: ``(0.5)`` would be a bare float.
                transforms.Normalize(mean=(0.5,), std=(0.5,)),
            ]
        )

    @property
    def loader(self):
        """Return a ``DataLoader`` factory bound to the current loader settings.

        Rebuilt on every access so later changes to ``batch_size`` or
        ``num_workers`` (e.g. by Lightning's batch-size tuner) are honoured.
        """
        workers = self.num_workers
        return partial(
            DataLoader,
            batch_size=self.batch_size,
            # Pinned memory only benefits host->CUDA copies; without CUDA it
            # merely triggers a warning.
            pin_memory=torch.cuda.is_available(),
            num_workers=workers,
            # DataLoader raises ValueError for persistent_workers=True when
            # num_workers == 0.
            persistent_workers=workers > 0,
        )

    def prepare_data(self):
        """Download both MNIST splits before ``setup`` runs.

        Lightning runs this hook on a single process (per node by default),
        so parallel processes do not race to write the same files.
        """
        for train in (True, False):
            self._load_with_retry(train=train)

    def setup(self, stage: str):
        """Build the datasets needed for ``stage``.

        Every stage except ``"test"``/``"predict"`` builds the seeded
        train/val split; every stage except ``"fit"``/``"validate"`` loads the
        test split (so ``None`` builds both).
        """
        if stage not in ("test", "predict"):
            mnist_full = self._load_with_retry(train=True)
            self.mnist_train, self.mnist_val, _ = random_split(
                dataset=mnist_full,
                lengths=self._split_lengths(len(mnist_full)),
                generator=torch.Generator().manual_seed(self.seed),
            )
        if stage not in ("fit", "validate"):
            self.mnist_test = self._load_with_retry(train=False)

    def _split_lengths(self, n):
        """Return integer ``[train, val, unused]`` sizes for ``n`` items.

        Mirrors ``random_split``'s fractional mode (floor each share, then
        hand the remainder out round-robin starting with train). Passing
        integers means float rounding in the three fractions can never make
        ``random_split`` reject them.
        """
        rest = max(0.0, 1 - self.val_ratio - self.train_ratio)
        sizes = [
            math.floor(n * frac)
            for frac in (self.train_ratio, self.val_ratio, rest)
        ]
        for i in range(n - sum(sizes)):
            sizes[i % len(sizes)] += 1
        return sizes

    def _load_with_retry(self, train):
        """Return ``MNIST(train=train)``, downloading it if it is missing.

        OSError / RuntimeError / EOFError (presumably how network and
        truncated-download failures surface — confirm against torchvision)
        are retried up to ``self.download_attempts`` times with exponential
        backoff; the last error is re-raised instead of looping forever.
        """
        delay = 1.0
        for attempt in range(1, self.download_attempts + 1):
            try:
                return MNIST(
                    root=self.data_dir,
                    train=train,
                    transform=self.transform,
                    download=True,
                )
            except (OSError, RuntimeError, EOFError) as err:
                if attempt == self.download_attempts:
                    raise
                _log.warning(
                    "Loading MNIST (train=%s) failed on attempt %d/%d: %s; "
                    "retrying in %.0fs",
                    train,
                    attempt,
                    self.download_attempts,
                    err,
                    delay,
                )
                time.sleep(delay)
                delay = min(delay * 2, 30.0)

    def train_dataloader(self):
        """Loader over the training subset, reshuffled every epoch."""
        return self.loader(dataset=self.mnist_train, shuffle=True)

    def val_dataloader(self):
        """Loader over the validation subset, in fixed order."""
        return self.loader(dataset=self.mnist_val)

    def test_dataloader(self):
        """Loader over the MNIST test split, in fixed order."""
        return self.loader(dataset=self.mnist_test)