Spaces:
Sleeping
Sleeping
File size: 2,108 Bytes
ae0af75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import lightning as L
import torchvision.transforms as T
import os
from torch.utils.data import DataLoader, Subset
from data.dataset import FaceToComicDataset
class FaceToComicDataModule(L.LightningDataModule):
    """Lightning data module that serves a ``FaceToComicDataset`` for training.

    Only a training dataloader is provided; the validation and test hooks
    are placeholders that return ``None``.
    """

    def __init__(
        self,
        face_path,
        comic_path,
        image_size=(128, 128),
        batch_size=32,
        max_samples=None,
    ):
        """Store the configuration and build the image transforms.

        Args:
            face_path: Location of the face images, forwarded unchanged to
                ``FaceToComicDataset`` as ``face_path``.
            comic_path: Location of the comic images, forwarded unchanged to
                ``FaceToComicDataset`` as ``comic_path``.
            image_size: Target size handed to ``T.Resize``.
            batch_size: Batch size used by the training dataloader.
            max_samples: Optional cap on the number of samples. Any falsy
                value (``None`` or ``0``) means "use the whole dataset".
        """
        super().__init__()
        self.face_dir = face_path
        self.comic_dir = comic_path
        self.image_size = image_size
        self.batch_size = batch_size
        self.max_samples = max_samples

        # Both domains use the same pipeline, but each gets its own
        # Compose instance so they can be customised independently.
        self.transform_face = self._build_transform()
        self.transform_comic = self._build_transform()

        # Populated by setup(); stays None until then.
        self.face2comic = None

    def _build_transform(self):
        """Return the resize -> tensor -> normalize pipeline for one domain."""
        return T.Compose([
            T.Resize(self.image_size),
            T.ToTensor(),
            # (x - 0.5) / 0.5 per channel.
            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])

    def prepare_data(self):
        """Do nothing: the data is expected to already exist on disk."""
        # No need to download or prepare data, as it's already present in the directories
        pass

    def setup(self, stage=None):
        """Build the training dataset when ``stage`` is ``"fit"`` or ``None``.

        For any other stage this is a no-op and ``self.face2comic`` is left
        as it was.
        """
        if stage == "fit" or stage is None:
            dataset = FaceToComicDataset(
                face_path=self.face_dir,
                comic_path=self.comic_dir,
                transform_face=self.transform_face,
                transform_comic=self.transform_comic,
            )

            # To Limit Dataset
            if self.max_samples:
                print(f"[INFO] Dataset is Limited to {self.max_samples} Samples")
                # min() keeps the index range valid when the cap exceeds
                # the number of available samples.
                self.face2comic = Subset(dataset, range(min(len(dataset), self.max_samples)))
            else:
                self.face2comic = dataset

    def train_dataloader(self):
        """Return a shuffled ``DataLoader`` over the dataset built by ``setup``."""
        # os.cpu_count() returns None when the CPU count cannot be
        # determined; DataLoader needs an int, so fall back to 0
        # (load batches in the main process).
        worker_count = os.cpu_count() or 0
        return DataLoader(
            self.face2comic,
            batch_size=self.batch_size,
            num_workers=worker_count,
            shuffle=True,
        )

    def val_dataloader(self):
        """Placeholder: no validation dataloader is provided (returns ``None``)."""
        # Implement if you need validation during training
        return None

    def test_dataloader(self):
        """Placeholder: no test dataloader is provided (returns ``None``)."""
        # Implement if you need testing after training
        return None
|