Spaces:
Sleeping
Sleeping
File size: 1,635 Bytes
fbf7e95 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torch
from .attribute_selector import AttributeSelector
from .similarity_vector_dataset import SimilarityVectorDataset
from typing import List
class MARCDataModule(pl.LightningDataModule):
    """Lightning data module serving train/val/test ``SimilarityVectorDataset`` splits.

    Each split is built from its own pre-processed path and shares a single
    transform: ``AttributeSelector(attrs)`` wrapped in ``torch.nn.Sequential``.

    Args:
        train_processed_path: Path handed to ``SimilarityVectorDataset`` for
            the training split.
        val_processed_path: Path handed to ``SimilarityVectorDataset`` for
            the validation split.
        test_processed_path: Path handed to ``SimilarityVectorDataset`` for
            the test split.
        attrs: Attribute names forwarded to ``AttributeSelector``.
        batch_size: Batch size used by every ``DataLoader``.
        num_workers: Worker process count used by every ``DataLoader``.
            Defaults to 0 (load in the main process), which was the
            previously hard-coded value.
    """

    def __init__(
        self,
        train_processed_path: str,
        val_processed_path: str,
        test_processed_path: str,
        attrs: List[str],
        batch_size: int,
        num_workers: int = 0,
    ):
        super().__init__()
        self.train_processed_path = train_processed_path
        self.val_processed_path = val_processed_path
        self.test_processed_path = test_processed_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = torch.nn.Sequential(AttributeSelector(attrs))
        # Datasets are created in setup() (or on first dataloader request).
        self.train_set = None
        self.val_set = None
        self.test_set = None

    def _build_dataset(self, processed_path: str):
        """Create a ``SimilarityVectorDataset`` for *processed_path* with the shared transform."""
        return SimilarityVectorDataset(processed_path, transform=self.transform)

    def setup(self, stage=None):
        """Build the datasets required for *stage*.

        Args:
            stage: ``"fit"`` builds train and val, ``"validate"`` builds val,
                ``"test"`` builds test, and ``None`` builds all three. Any
                other value builds nothing here; the dataloader methods
                create a missing split on demand.
        """
        if stage in (None, "fit"):
            self.train_set = self._build_dataset(self.train_processed_path)
        if stage in (None, "fit", "validate"):
            self.val_set = self._build_dataset(self.val_processed_path)
        if stage in (None, "test"):
            self.test_set = self._build_dataset(self.test_processed_path)

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over the training split."""
        # Build lazily so a caller that skipped setup() (or ran it for a
        # different stage) never hands DataLoader a None dataset.
        if self.train_set is None:
            self.train_set = self._build_dataset(self.train_processed_path)
        return DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over the validation split."""
        if self.val_set is None:
            self.val_set = self._build_dataset(self.val_processed_path)
        return DataLoader(
            self.val_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over the test split."""
        if self.test_set is None:
            self.test_set = self._build_dataset(self.test_processed_path)
        return DataLoader(
            self.test_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
|