Spaces:
Runtime error
Runtime error
File size: 2,296 Bytes
9457143 dabac1b 9457143 dabac1b 9457143 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import pytorch_lightning as pl
import torch
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from functools import partial
class MNISTDataModule(pl.LightningDataModule):
def __init__(
self,
data_dir: str = "./",
batch_size: int = 32,
num_workers: int = 0,
seed: int = 42,
train_ratio: float = 0.99,
img_dim: int = 32
):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.train_ratio = min(train_ratio, 0.99)
self.seed = seed
self.transform = transforms.Compose(
[
transforms.Resize((img_dim, img_dim)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5), std=(0.5))
]
)
self.loader = partial(
DataLoader,
batch_size=self.batch_size,
pin_memory=True,
num_workers=self.num_workers,
persistent_workers=True
)
def setup(self, stage: str):
mnist_partial = partial(
MNIST,
root=self.data_dir, transform=self.transform, download=True
)
if stage == "fit":
retrying = True
while retrying:
try:
mnist_full = mnist_partial(train=True)
retrying = False
except:
pass
self.mnist_train, self.mnist_val, _ = random_split(
dataset=mnist_full,
lengths=[self.train_ratio, 0.01, 1 - 0.01 - self.train_ratio],
generator=torch.Generator().manual_seed(self.seed)
)
else:
retrying = True
while retrying:
try:
self.mnist_test = mnist_partial(train=False)
retrying = False
except:
pass
def train_dataloader(self):
return self.loader(dataset=self.mnist_train)
def val_dataloader(self):
return self.loader(dataset=self.mnist_val)
def test_dataloader(self):
return self.loader(dataset=self.mnist_test)
|