# 57894-Pix2Pix / data/dataloader.py
# Muhammad Naufal Rizqullah
# first commit
# ae0af75
import lightning as L
import torchvision.transforms as T
import os
from torch.utils.data import DataLoader, Subset
from data.dataset import FaceToComicDataset
class FaceToComicDataModule(L.LightningDataModule):
    """LightningDataModule serving paired face/comic images.

    Wraps ``FaceToComicDataset`` (from ``data.dataset``) and exposes a
    shuffled training ``DataLoader``. Validation and test loaders are not
    implemented.

    Args:
        face_path: Location of the face images; forwarded to
            ``FaceToComicDataset`` as ``face_path``.
        comic_path: Location of the comic images; forwarded as ``comic_path``.
        image_size: Size passed to ``torchvision.transforms.Resize`` for both
            image domains.
        batch_size: Batch size of the training loader.
        max_samples: When truthy, only the first ``max_samples`` items of the
            dataset are used. ``None`` (and ``0``) mean "use everything".
        num_workers: Worker processes for the training loader. ``None``
            (default) uses ``os.cpu_count()``, falling back to ``0``
            (load in the main process) when the CPU count is unknown.
    """

    def __init__(
        self,
        face_path,
        comic_path,
        image_size=(128, 128),
        batch_size=32,
        max_samples=None,
        num_workers=None,
    ):
        super().__init__()
        self.face_dir = face_path
        self.comic_dir = comic_path
        self.image_size = image_size
        self.batch_size = batch_size
        self.max_samples = max_samples
        self.num_workers = num_workers

        # Both domains currently share the same preprocessing, but each keeps
        # its own Compose instance so they can be customised independently.
        self.transform_face = self._build_transform()
        self.transform_comic = self._build_transform()

        # Filled in by setup(); stays None until then.
        self.face2comic = None

    def _build_transform(self):
        """Return a resize -> tensor -> normalise pipeline for one domain."""
        return T.Compose([
            T.Resize(self.image_size),
            T.ToTensor(),
            # mean = std = 0.5 maps ToTensor's [0, 1] output (for 8-bit
            # images) onto [-1, 1].
            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])

    def prepare_data(self):
        """No-op: the image directories are expected to exist already."""
        # Nothing to download or preprocess.
        pass

    def setup(self, stage=None):
        """Build the training dataset for the ``"fit"`` stage (or ``None``).

        Any other stage leaves ``self.face2comic`` unchanged.
        """
        if stage not in ("fit", None):
            return

        dataset = FaceToComicDataset(
            face_path=self.face_dir,
            comic_path=self.comic_dir,
            transform_face=self.transform_face,
            transform_comic=self.transform_comic,
        )

        # Optionally keep only a prefix of the dataset (e.g. for quick runs).
        if self.max_samples:
            print(f"[INFO] Dataset is Limited to {self.max_samples} Samples")
            limit = min(len(dataset), self.max_samples)
            self.face2comic = Subset(dataset, range(limit))
        else:
            self.face2comic = dataset

    def train_dataloader(self):
        """Return a shuffled ``DataLoader`` over the training dataset.

        Raises:
            RuntimeError: If called before ``setup("fit")`` / ``setup()``.
        """
        if self.face2comic is None:
            raise RuntimeError(
                "FaceToComicDataModule.setup('fit') must be called "
                "before train_dataloader()."
            )

        num_workers = self.num_workers
        if num_workers is None:
            # os.cpu_count() returns None when the count is undetermined;
            # DataLoader needs an int, so fall back to main-process loading.
            num_workers = os.cpu_count() or 0

        return DataLoader(
            self.face2comic,
            batch_size=self.batch_size,
            num_workers=num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Validation loader: not implemented, returns ``None``.

        NOTE(review): because this method is overridden, Lightning may try to
        use its return value; confirm how the installed Lightning version
        treats a ``None`` loader before relying on it.
        """
        return None

    def test_dataloader(self):
        """Test loader: not implemented, returns ``None``.

        NOTE(review): same caveat as ``val_dataloader`` regarding how
        Lightning handles a ``None`` loader.
        """
        return None