from typing import List

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from .attribute_selector import AttributeSelector
from .similarity_vector_dataset import SimilarityVectorDataset


class MARCDataModule(pl.LightningDataModule):
    """LightningDataModule serving train/val/test ``SimilarityVectorDataset`` splits.

    Every split is built with the same transform: an ``AttributeSelector``
    configured with ``attrs``, wrapped in ``torch.nn.Sequential``.
    """

    def __init__(
        self,
        train_processed_path: str,
        val_processed_path: str,
        test_processed_path: str,
        attrs: List[str],
        batch_size: int,
        num_workers: int = 0,
    ):
        """Store split locations and loader settings; no data is loaded here.

        Args:
            train_processed_path: Path passed to ``SimilarityVectorDataset``
                for the training split.
            val_processed_path: Path for the validation split.
            test_processed_path: Path for the test split.
            attrs: Attribute names handed to ``AttributeSelector``, which is
                applied as the transform of every split.
            batch_size: Batch size used by all three dataloaders.
            num_workers: Worker processes per ``DataLoader``. The default of
                0 loads batches in the main process, matching the previously
                hard-coded value.
        """
        super().__init__()
        self.train_processed_path = train_processed_path
        self.val_processed_path = val_processed_path
        self.test_processed_path = test_processed_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = torch.nn.Sequential(AttributeSelector(attrs))
        # Populated lazily by ``setup``.
        self.train_set = None
        self.val_set = None
        self.test_set = None

    def setup(self, stage=None):
        """Build every dataset split that has not been built yet.

        ``stage`` is accepted for Lightning API compatibility but ignored.
        All three splits are made available regardless of stage, so any
        dataloader can be requested after any ``setup`` call. Lightning
        invokes ``setup`` again on each ``fit``/``validate``/``test`` entry.
        Splits built by an earlier call are therefore reused rather than
        reconstructed (construction presumably reads the processed files;
        verify against ``SimilarityVectorDataset``).
        """
        if self.train_set is None:
            self.train_set = self._make_dataset(self.train_processed_path)
        if self.val_set is None:
            self.val_set = self._make_dataset(self.val_processed_path)
        if self.test_set is None:
            self.test_set = self._make_dataset(self.test_processed_path)

    def _make_dataset(self, path):
        """Create a ``SimilarityVectorDataset`` for ``path`` with the shared transform."""
        return SimilarityVectorDataset(path, transform=self.transform)

    def _make_loader(self, dataset, shuffle=False):
        """Wrap ``dataset`` in a ``DataLoader`` using the module's batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return a shuffled ``DataLoader`` over the training split."""
        return self._make_loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled ``DataLoader`` over the validation split."""
        return self._make_loader(self.val_set)

    def test_dataloader(self):
        """Return an unshuffled ``DataLoader`` over the test split."""
        return self._make_loader(self.test_set)