# Spaces: Runtime error
import os
from functools import partial
from typing import Optional

import pytorch_lightning as pl
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
class CelebADataset(Dataset):
    """Flat-directory image dataset that yields normalized image tensors.

    Every regular file directly inside ``data_dir`` is one sample (no
    sub-folders are traversed). Each image is converted to RGB, resized to
    ``img_dim x img_dim``, turned into a tensor and normalized with mean and
    std 0.5 on each of the three channels.

    Args:
        data_dir: Directory containing the image files.
        img_dim: Side length, in pixels, of the square output images.
    """

    def __init__(
        self,
        data_dir: str,
        img_dim: int = 64
    ):
        # Sorted so the sample order -- and therefore any seeded random_split
        # over this dataset -- does not depend on the filesystem's listing
        # order. Sub-directories are skipped because Image.open cannot read them.
        self.list_path = sorted(
            name for name in os.listdir(data_dir)
            if os.path.isfile(os.path.join(data_dir, name))
        )
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [
                transforms.Resize((img_dim, img_dim)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]
        )

    def __len__(self):
        """Return the number of image files found in ``data_dir``."""
        return len(self.list_path)

    def __getitem__(self, index):
        """Load the image at position ``index`` and return the transformed tensor."""
        path = os.path.join(self.data_dir, self.list_path[index])
        # convert() reads the pixel data, so the file can be closed right away;
        # forcing RGB guarantees the 3 channels Normalize expects even for
        # grayscale, palette or RGBA files.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)
class CelebADataModule(pl.LightningDataModule):
    """LightningDataModule that splits one image folder into train/val subsets.

    Args:
        data_dir: Directory handed to ``CelebADataset``.
        batch_size: Batch size for both dataloaders.
        num_workers: Number of DataLoader worker processes (0 = load in the
            main process).
        seed: Seed for the generator passed to ``random_split`` so the split
            is repeatable.
        train_ratio: Fraction of the dataset used for training; clamped to at
            most ``1 - val_ratio``. Whatever is left over is not used.
        img_dim: Side length of the square images produced by the dataset.
        val_ratio: Fraction of the dataset used for validation (defaults to
            the previously hard-coded 0.01).
    """

    def __init__(
        self,
        data_dir: str = "./",
        batch_size: int = 32,
        num_workers: int = 0,
        seed: int = 42,
        train_ratio: float = 0.99,
        img_dim: int = 64,
        val_ratio: float = 0.01
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_ratio = val_ratio
        # Train and val fractions together must not exceed 1.
        self.train_ratio = min(train_ratio, 1 - val_ratio)
        self.img_dim = img_dim
        self.seed = seed
        # Shared DataLoader factory; callers supply the dataset (and shuffle).
        self.loader = partial(
            DataLoader,
            batch_size=self.batch_size,
            pin_memory=True,
            num_workers=self.num_workers,
            # DataLoader raises ValueError for persistent_workers=True with
            # num_workers=0 (this class's default), so only enable it when
            # there are worker processes to keep alive.
            persistent_workers=self.num_workers > 0
        )

    def setup(self, stage: Optional[str] = None):
        """Create ``CelebA_train`` and ``CelebA_val`` for the stages that use them.

        Both ``fit`` and a standalone ``validate`` run need the validation
        subset; ``None`` (a manual ``setup()`` call) sets everything up.
        Other stages need nothing from this module.
        """
        if stage not in (None, "fit", "validate"):
            return
        dataset = CelebADataset(self.data_dir, self.img_dim)
        # Clamp to 0.0 so float rounding can never produce a tiny negative
        # fraction, which random_split rejects.
        rest = max(1 - self.val_ratio - self.train_ratio, 0.0)
        self.CelebA_train, self.CelebA_val, _ = random_split(
            dataset=dataset,
            lengths=[self.train_ratio, self.val_ratio, rest],
            generator=torch.Generator().manual_seed(self.seed)
        )

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        return self.loader(dataset=self.CelebA_train, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader in a fixed order."""
        return self.loader(dataset=self.CelebA_val)