from typing import List

import __main__
import rootutils
import torch
from torch_geometric.data import Dataset

# Locate the project root via the ".project-root" marker file and add it to the
# Python path; this has to run before the `src...` import below can resolve.
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from src.data.components.prepare_data import CropPairedPDB  # noqa: E402

# Expose the class as an attribute of the `__main__` module.
# NOTE(review): presumably needed so that files pickled with a reference to
# `__main__.CropPairedPDB` can be deserialized by `torch.load` -- confirm.
setattr(__main__, "CropPairedPDB", CropPairedPDB)


class PinderDataset(Dataset):
    """PyTorch Geometric dataset that reads Pinder samples from saved files.

    Every entry of ``file_paths`` is treated as one sample; the file is
    deserialized with ``torch.load`` each time the sample is requested.
    """

    def __init__(self, file_paths: List[str]) -> None:
        """Create the dataset from a list of sample files.

        Args:
            file_paths: Paths of the serialized samples, one file per sample.
        """
        # The base class is initialized first, exactly as before, and only
        # then is the path list attached to the instance.
        super().__init__()
        self.file_paths = file_paths

    @property
    def processed_file_names(self) -> List[str]:
        """Return the processed file names.

        Returns:
            List[str]: The ``file_paths`` list handed to the constructor.
        """
        return self.file_paths

    def len(self) -> int:
        """Return the number of samples in the dataset.

        Returns:
            int: Number of entries in ``processed_file_names``.
        """
        sample_files = self.processed_file_names
        return len(sample_files)

    def get(self, idx) -> CropPairedPDB:
        """Load the sample stored at position ``idx``.

        Args:
            idx: Position of the sample within ``processed_file_names``.

        Returns:
            CropPairedPDB: The object deserialized from the selected file.
        """
        sample_path = self.processed_file_names[idx]
        # NOTE: weights_only=False performs full unpickling, which can execute
        # arbitrary code -- only point this dataset at trusted files.
        return torch.load(sample_path, weights_only=False)


if __name__ == "__main__":
    # Smoke test: load a single processed sample and print it.
    demo_dataset = PinderDataset(
        file_paths=["./data/processed/apo/test/1a19__A1_P11540--1a19__B1_P11540.pt"]
    )
    print(demo_dataset[0])