Spaces:
Sleeping
Sleeping
import pytorch_lightning as pl | |
from torch.utils.data import DataLoader | |
import torch | |
from .attribute_selector import AttributeSelector | |
from .similarity_vector_dataset import SimilarityVectorDataset | |
from typing import List | |
class MARCDataModule(pl.LightningDataModule):
    """Lightning data module exposing train/val/test ``SimilarityVectorDataset`` splits.

    Each split is built from its own pre-processed path and shares a single
    transform that wraps an ``AttributeSelector`` configured with ``attrs``.
    """

    def __init__(
        self,
        train_processed_path: str,
        val_processed_path: str,
        test_processed_path: str,
        attrs: List[str],
        batch_size: int,
    ):
        """Record the split locations and build the shared transform.

        Args:
            train_processed_path: Path passed to ``SimilarityVectorDataset``
                for the training split.
            val_processed_path: Path passed to ``SimilarityVectorDataset``
                for the validation split.
            test_processed_path: Path passed to ``SimilarityVectorDataset``
                for the test split.
            attrs: Attribute names forwarded to ``AttributeSelector``.
            batch_size: Batch size used by every dataloader.
        """
        super().__init__()
        self.train_processed_path = train_processed_path
        self.val_processed_path = val_processed_path
        self.test_processed_path = test_processed_path
        self.batch_size = batch_size
        self.transform = torch.nn.Sequential(AttributeSelector(attrs))

        # Filled in by setup(), or on demand by the *_dataloader() methods.
        self.train_set = None
        self.val_set = None
        self.test_set = None

    def _build_dataset(self, processed_path: str):
        """Create a ``SimilarityVectorDataset`` for ``processed_path`` with the shared transform."""
        return SimilarityVectorDataset(processed_path, transform=self.transform)

    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        ``"fit"`` builds the train and validation sets, ``"validate"`` only
        the validation set and ``"test"`` only the test set. Any other value,
        including ``None``, builds all three splits.

        Args:
            stage: Stage name supplied by the trainer, or ``None``.
        """
        # Unknown stages fall back to building everything so callers that
        # relied on "setup() prepares all splits" keep working.
        build_all = stage not in ("fit", "validate", "test")

        if build_all or stage == "fit":
            self.train_set = self._build_dataset(self.train_processed_path)
        if build_all or stage in ("fit", "validate"):
            self.val_set = self._build_dataset(self.val_processed_path)
        if build_all or stage == "test":
            self.test_set = self._build_dataset(self.test_processed_path)

    def train_dataloader(self):
        """Return a shuffled ``DataLoader`` over the training set."""
        # Build on demand instead of handing ``None`` to DataLoader when
        # setup() has not prepared this split.
        if self.train_set is None:
            self.train_set = self._build_dataset(self.train_processed_path)
        return DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return a ``DataLoader`` over the validation set (no shuffling)."""
        if self.val_set is None:
            self.val_set = self._build_dataset(self.val_processed_path)
        return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=0)

    def test_dataloader(self):
        """Return a ``DataLoader`` over the test set (no shuffling)."""
        if self.test_set is None:
            self.test_set = self._build_dataset(self.test_processed_path)
        return DataLoader(self.test_set, batch_size=self.batch_size, num_workers=0)