|
from typing import Callable, List, Optional, Union |
|
|
|
import torch |
|
from pytorch_lightning import LightningDataModule |
|
from torch.utils.data import DataLoader, Dataset |
|
|
|
|
|
class SampleDataset(Dataset):
    """Map-style dataset that pairs each input sample with its target.

    Args:
        x: Input samples, indexable by integer position.
        y: Targets; ``y[i]`` is returned alongside ``x[i]``.
        transforms: Optional callable applied to each input sample when it is
            fetched. Targets are never transformed. When ``None``, samples
            are returned unchanged.
    """

    def __init__(self,
                 x: Union[List, torch.Tensor],
                 y: Union[List, torch.Tensor],
                 transforms: Optional[Callable] = None) -> None:
        super().__init__()
        self.x = x
        self.y = y

        # A lambda stored on the instance cannot be pickled, which breaks
        # DataLoader worker processes started with "spawn" as well as any
        # other serialisation of the dataset. The class-level identity
        # function is picklable by qualified name and keeps
        # ``self.transforms`` callable for existing users of the attribute.
        if transforms is None:
            self.transforms = SampleDataset._identity
        else:
            self.transforms = transforms

    @staticmethod
    def _identity(sample):
        """Return ``sample`` unchanged (default when no transform is given)."""
        return sample

    def __len__(self) -> int:
        """Return the number of input samples."""
        return len(self.x)

    def __getitem__(self, index: int):
        """Return ``(transforms(x[index]), y[index])``."""
        sample = self.transforms(self.x[index])
        target = self.y[index]
        return sample, target
|
|
|
|
|
class SampleDataModule(LightningDataModule):
    """Lightning data module serving in-memory samples with an ordered split.

    The first ``n - int(n * val_ratio)`` samples form the training set and the
    remainder forms the validation set. No shuffling is performed, neither
    when splitting nor in the returned data loaders.

    Args:
        x: Input samples, sliceable and indexable by position.
        y: Targets aligned with ``x``.
        transforms: Optional callable applied to each input sample on access;
            forwarded unchanged to both datasets.
        val_ratio: Fraction of samples held out for validation, in ``[0, 1)``.
        batch_size: Positive batch size used by both data loaders.

    Raises:
        ValueError: If ``val_ratio`` is outside ``[0, 1)`` or ``batch_size``
            is not positive.
        TypeError: If ``batch_size`` is not an ``int``.
    """

    def __init__(self,
                 x: Union[List, torch.Tensor],
                 y: Union[List, torch.Tensor],
                 transforms: Optional[Callable] = None,
                 val_ratio: float = 0,
                 batch_size: int = 32) -> None:
        super().__init__()
        # Explicit raises instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would silently disable the validation.
        if not 0 <= val_ratio < 1:
            raise ValueError(
                f"val_ratio must satisfy 0 <= val_ratio < 1, got {val_ratio!r}")
        if not isinstance(batch_size, int):
            raise TypeError(
                f"batch_size must be an int, got {type(batch_size).__name__}")
        if batch_size <= 0:
            # Fail here rather than later inside DataLoader construction.
            raise ValueError(
                f"batch_size must be positive, got {batch_size!r}")

        self.x = x
        self.y = y

        self.transforms = transforms
        self.val_ratio = val_ratio
        self.batch_size = batch_size

        # Build eagerly so ``train_dataset`` / ``val_dataset`` are available
        # right after construction, as they were before.
        self._build_datasets()

    def prepare_data(self) -> None:
        """Do nothing: the data is already in memory.

        Lightning invokes this hook on a single process only, so state must
        not be assigned here; the split is created in ``setup`` instead.
        """

    def setup(self, stage: Optional[str] = None) -> None:
        """(Re)build the train/validation datasets.

        Args:
            stage: Stage name supplied by the trainer; unused because the
                same split serves every stage.
        """
        self._build_datasets()

    def _build_datasets(self) -> None:
        """Split ``x`` / ``y`` by position into train and validation sets."""
        n_samples: int = len(self.x)
        train_size: int = n_samples - int(n_samples * self.val_ratio)

        self.train_dataset = SampleDataset(x=self.x[:train_size],
                                           y=self.y[:train_size],
                                           transforms=self.transforms)

        if train_size < n_samples:
            val_slice = slice(train_size, None)
        else:
            # NOTE: with no hold-out samples (val_ratio == 0, or a ratio so
            # small that it rounds to zero samples) the validation set reuses
            # the last ``batch_size`` training samples so that the validation
            # loader is never empty. These samples overlap with training data.
            val_slice = slice(-self.batch_size, None)

        self.val_dataset = SampleDataset(x=self.x[val_slice],
                                         y=self.y[val_slice],
                                         transforms=self.transforms)

    def train_dataloader(self) -> DataLoader:
        """Return a loader over the training split (no shuffling)."""
        return DataLoader(dataset=self.train_dataset,
                          batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        """Return a loader over the validation split."""
        return DataLoader(dataset=self.val_dataset,
                          batch_size=self.batch_size)
|
|