Spaces:
Sleeping
Sleeping
import __main__
from typing import List

import rootutils
import torch
from torch_geometric.data import Dataset

# Locate the project root (marked by a ".project-root" file) and, because
# pythonpath=True, put it on the import path. The project-local ``src``
# import below presumably depends on this, so the two statements must stay
# in this order.
rootutils.setup_root(
    __file__,
    indicator=".project-root",
    pythonpath=True,
)

from src.data.components.prepare_data import CropPairedPDB  # noqa: E402

# Also expose CropPairedPDB as an attribute of the ``__main__`` module.
# NOTE(review): presumably this lets torch.load (full unpickling, see
# PinderDataset.get) resolve samples that were pickled from a script in which
# the class lived in ``__main__`` -- confirm against the preprocessing code.
setattr(__main__, "CropPairedPDB", CropPairedPDB)
class PinderDataset(Dataset):
    """Pinder dataset backed by a list of serialized sample files.

    Each path in ``file_paths`` corresponds to one dataset item. The file is
    read with :func:`torch.load` every time its index is requested; nothing
    is cached on the instance.

    Args:
        file_paths: Paths to the processed sample files, one per item.
    """

    def __init__(self, file_paths: List[str]) -> None:
        """Initialize the PinderDataset.

        Args:
            file_paths: List of paths to processed sample files.
        """
        super().__init__()
        self.file_paths = file_paths

    @property
    def processed_file_names(self) -> List[str]:
        """Return the processed file names.

        Declared as a property to match ``torch_geometric.data.Dataset``,
        where ``processed_file_names`` is a property. ``len`` and ``get``
        below also read it as an attribute, not as a call.

        Returns:
            List[str]: The file paths this dataset was constructed with.
        """
        return self.file_paths

    def len(self) -> int:
        """Return the length of the dataset.

        Returns:
            int: Number of samples, i.e. the number of file paths.
        """
        return len(self.processed_file_names)

    def get(self, idx: int) -> CropPairedPDB:
        """Load and return the sample at the given index.

        Args:
            idx: Position of the sample in ``processed_file_names``.

        Returns:
            CropPairedPDB: The object deserialized from the file at ``idx``.
        """
        # weights_only=False performs full pickle deserialization. This is
        # presumably needed because the files hold pickled CropPairedPDB
        # objects rather than plain tensors. It can execute arbitrary code,
        # so only load files produced by a trusted preprocessing step.
        return torch.load(self.processed_file_names[idx], weights_only=False)
if __name__ == "__main__":
    # Smoke test: build a one-item dataset and print its only sample.
    sample_paths = [
        "./data/processed/apo/test/1a19__A1_P11540--1a19__B1_P11540.pt"
    ]
    sample_dataset = PinderDataset(file_paths=sample_paths)
    first_sample = sample_dataset[0]
    print(first_sample)