dinhdat1110's picture
ok
dabac1b
import pytorch_lightning as pl
import torch
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from functools import partial
class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module that serves MNIST train/val/test loaders.

    The training archive is split with ``random_split`` into a training
    subset (``train_ratio`` of the samples), a validation subset
    (``VAL_RATIO`` of the samples) and an unused remainder. The test loader
    is backed by the MNIST ``train=False`` split.

    Args:
        data_dir: Directory passed to ``MNIST`` as ``root``; the dataset is
            downloaded there when missing (``download=True``).
        batch_size: Batch size used by every loader.
        num_workers: Number of ``DataLoader`` worker processes.
        seed: Seed of the generator that drives the train/val split.
        train_ratio: Fraction of the training archive used for training;
            clamped to at most 0.99.
        img_dim: Images are resized to ``(img_dim, img_dim)``.
    """

    # Fraction of the training archive reserved for validation.
    VAL_RATIO = 0.01
    # Upper bound on attempts to construct (and, if needed, download) MNIST.
    MAX_DOWNLOAD_ATTEMPTS = 5

    def __init__(
        self,
        data_dir: str = "./",
        batch_size: int = 32,
        num_workers: int = 0,
        seed: int = 42,
        train_ratio: float = 0.99,
        img_dim: int = 32
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Cap the training share so the validation share always fits.
        self.train_ratio = min(train_ratio, 0.99)
        self.seed = seed
        self.transform = transforms.Compose(
            [
                transforms.Resize((img_dim, img_dim)),
                transforms.ToTensor(),
                # One-element tuples: ``(0.5)`` is just the float 0.5.
                transforms.Normalize(mean=(0.5,), std=(0.5,)),
            ]
        )
        self.loader = partial(
            DataLoader,
            batch_size=self.batch_size,
            pin_memory=True,
            num_workers=self.num_workers,
            # DataLoader raises ValueError for persistent_workers=True when
            # num_workers == 0, so only enable it when workers exist.
            persistent_workers=self.num_workers > 0,
        )

    def _load_mnist(self, train: bool):
        """Build the requested MNIST split, retrying a bounded number of times.

        Args:
            train: Forwarded to ``MNIST(train=...)``.

        Returns:
            The constructed ``MNIST`` dataset.

        Raises:
            RuntimeError: If every one of ``MAX_DOWNLOAD_ATTEMPTS`` attempts
                fails; the last underlying error is chained as the cause.
        """
        last_error = None
        for _ in range(self.MAX_DOWNLOAD_ATTEMPTS):
            try:
                return MNIST(
                    root=self.data_dir,
                    transform=self.transform,
                    download=True,
                    train=train,
                )
            except Exception as err:
                # Best-effort retry for flaky downloads. ``Exception`` (not a
                # bare ``except``) lets KeyboardInterrupt/SystemExit through.
                last_error = err
        raise RuntimeError(
            f"Could not load MNIST (train={train}) from {self.data_dir!r} "
            f"after {self.MAX_DOWNLOAD_ATTEMPTS} attempts"
        ) from last_error

    def setup(self, stage: str):
        """Create the datasets needed for ``stage``.

        For ``"fit"`` and ``"validate"`` the training archive is loaded and
        split into ``self.mnist_train`` / ``self.mnist_val``; for any other
        stage the test split is loaded into ``self.mnist_test``.
        """
        if stage in ("fit", "validate"):
            mnist_full = self._load_mnist(train=True)
            # Fractional lengths: train share, validation share, and the
            # leftover share, which is discarded.
            self.mnist_train, self.mnist_val, _ = random_split(
                dataset=mnist_full,
                lengths=[
                    self.train_ratio,
                    self.VAL_RATIO,
                    1 - self.VAL_RATIO - self.train_ratio,
                ],
                generator=torch.Generator().manual_seed(self.seed),
            )
        else:
            self.mnist_test = self._load_mnist(train=False)

    def train_dataloader(self):
        """Return a loader over the training subset created by ``setup``."""
        return self.loader(dataset=self.mnist_train)

    def val_dataloader(self):
        """Return a loader over the validation subset created by ``setup``."""
        return self.loader(dataset=self.mnist_val)

    def test_dataloader(self):
        """Return a loader over the MNIST test split created by ``setup``."""
        return self.loader(dataset=self.mnist_test)