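# Spaces: Sleeping / Sleeping
# NOTE(review): the line above appears to be page-header residue from where this
# file was copied; it is not part of the module and can likely be removed.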
import lightning as L
import numpy as np
from torch.utils.data import DataLoader, Subset, Dataset

from data_utils.splitter import chunk_splitter
class FrameDataModule(L.LightningDataModule):
    """Lightning data module serving train/validation loaders over one dataset.

    The dataset is partitioned once, at construction time, using the mask
    returned by ``chunk_splitter``: samples whose mask entry is non-zero go to
    the validation subset, samples whose entry is zero go to the training
    subset. Both subsets are exposed as ``ds_train`` and ``ds_val``.

    Args:
        dataset: The full dataset to split. ``len(dataset)`` is called, so it
            must be a sized (map-style) dataset.
        batch_size: Batch size used by both data loaders.
        chunk_size_for_splitting: Forwarded to ``chunk_splitter`` as
            ``chunk_size``.
            NOTE(review): presumably the number of consecutive samples kept
            together on one side of the split -- confirm against
            ``chunk_splitter``.
        num_workers: Number of worker processes for both data loaders.
        pin_memory: Forwarded to both data loaders as ``pin_memory``.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 32,
        chunk_size_for_splitting: int = 3 * 30,
        num_workers: int = 2,
        pin_memory: bool = False,
    ) -> None:
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.chunk_size_for_splitting = chunk_size_for_splitting

        # NOTE(review): ``split=0.15`` is presumably the validation fraction --
        # confirm against ``chunk_splitter``.
        # ``np.asarray`` guarantees the element-wise ``== 0`` comparison below
        # even if the helper returns a plain sequence rather than an ndarray.
        split = np.asarray(
            chunk_splitter(
                len(dataset),  # type: ignore
                chunk_size=self.chunk_size_for_splitting,
                split=0.15,
            )
        )

        # Non-zero mask entries select validation samples, zero entries select
        # training samples. Indices are converted to plain ``int`` lists so the
        # wrapped dataset's ``__getitem__`` never receives ``np.int64`` values.
        val_indices = np.where(split)[0].tolist()
        train_indices = np.where(split == 0)[0].tolist()

        self.ds_train = Subset(self.dataset, train_indices)
        self.ds_val = Subset(self.dataset, val_indices)

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling ``DataLoader`` over the training subset."""
        return DataLoader(
            self.ds_train,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def val_dataloader(self) -> DataLoader:
        """Return a non-shuffling ``DataLoader`` over the validation subset."""
        return DataLoader(
            self.ds_val,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )