Spaces:
Sleeping
Sleeping
File size: 1,471 Bytes
bd65e34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import lightning as L
import numpy as np
from torch.utils.data import DataLoader, Subset, Dataset
from data_utils.splitter import chunk_splitter
class FrameDataModule(L.LightningDataModule):
    """Lightning data module that splits a single dataset into train/val subsets.

    The split is computed once, at construction time, from the mask returned by
    ``chunk_splitter(len(dataset), chunk_size=..., split=...)``. Items where the
    mask is non-zero go to validation; items where it is zero go to training.
    NOTE(review): presumably ``chunk_splitter`` assigns whole chunks of
    consecutive items to one side so neighbouring frames do not leak across the
    split — confirm against ``data_utils.splitter``.

    Args:
        dataset: Dataset to split. Must support ``len()``.
        batch_size: Batch size used by both dataloaders.
        chunk_size_for_splitting: ``chunk_size`` forwarded to ``chunk_splitter``.
            The default ``3 * 30`` presumably means 3 s at 30 fps — confirm.
        num_workers: ``num_workers`` forwarded to each ``DataLoader``.
        pin_memory: ``pin_memory`` forwarded to each ``DataLoader``.
        val_split: ``split`` fraction forwarded to ``chunk_splitter``. Defaults
            to ``0.15``, the value that was previously hard-coded.

    Attributes:
        ds_train: ``Subset`` of ``dataset`` at indices where the mask is zero.
        ds_val: ``Subset`` of ``dataset`` at indices where the mask is non-zero.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 32,
        chunk_size_for_splitting: int = 3 * 30,
        num_workers: int = 2,
        pin_memory: bool = False,
        val_split: float = 0.15,
    ):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.chunk_size_for_splitting = chunk_size_for_splitting
        self.val_split = val_split

        split = chunk_splitter(
            len(dataset),  # type: ignore
            chunk_size=self.chunk_size_for_splitting,
            split=self.val_split,
        )
        # Convert to an ndarray before building the masks: if chunk_splitter
        # returns a plain list, ``split == 0`` would be the scalar False and the
        # training subset would silently come out empty. Using one boolean mask
        # and its negation also guarantees the two subsets are complementary.
        is_val = np.asarray(split).astype(bool)
        val_indices = np.flatnonzero(is_val)
        train_indices = np.flatnonzero(~is_val)

        # Subset expects a sequence of ints; .tolist() yields plain Python ints.
        self.ds_train = Subset(self.dataset, train_indices.tolist())
        self.ds_val = Subset(self.dataset, val_indices.tolist())

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling ``DataLoader`` over the training subset."""
        return DataLoader(
            self.ds_train,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def val_dataloader(self) -> DataLoader:
        """Return a non-shuffling ``DataLoader`` over the validation subset."""
        return DataLoader(
            self.ds_val,
            shuffle=False,  # DataLoader default, stated explicitly for symmetry
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )
|