# from itertools import product
from collections import namedtuple
from numbers import Number
from typing import Any, Dict, Optional, Sequence, Union, Literal

# import numpy as np
import pandas as pd
from lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader, random_split

from deepscreen.data.utils.label import label_transform
from deepscreen.data.utils.collator import collate_fn
from deepscreen.data.utils.sampler import SafeBatchSampler


class DTIDataset(Dataset):
    def __init__(
            self,
            task: Literal['regression', 'binary', 'multiclass'],
            n_classes: Optional[int],
            data_dir: str,
            dataset_name: str,
            drug_featurizer: callable,
            protein_featurizer: callable,
            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
            discard_intermediate: Optional[bool] = False,
    ):
        df = pd.read_csv(
            f'{data_dir}{dataset_name}.csv',
            header=0,
            sep=',',
            usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'],
            dtype={'X1': 'str', 'ID1': 'str', 'X2': 'str', 'ID2': 'str', 'Y': 'float32', 'U': 'str'}
        )
        # if 'ID1' in df:
        #     self.x1_to_id1 = dict(zip(df['X1'], df['ID1']))
        # if 'ID2' in df:
        #     self.x2_to_id2 = dict(zip(df['X2'], df['ID2']))
        #     self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2']))))
        # self.x2_to_indexes = dict(zip(df['X2'], range(len(df['X2']))))

        # # train and eval mode data processing (fully labelled)
        # if 'Y' in df.columns and df['Y'].notnull().all():

        # Forward-fill all non-label columns
        df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0)

        if 'Y' in df:
            # Transform labels
            df['Y'] = df['Y'].apply(
                label_transform,
                units=df.get('U', None),
                thresholds=thresholds,
                discard_intermediate=discard_intermediate
            ).astype('float32')

            # Filter out rows with a NaN in Y (missing or discarded labels) and reset the
            # index so that row lookups in __getitem__ remain valid after dropping rows
            df.dropna(subset=['Y'], inplace=True)
            df.reset_index(drop=True, inplace=True)

            # Validate target labels for training/testing
            # TODO: check sklearn.utils.multiclass.check_classification_targets
            match task:
                case 'regression':
                    assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \
                        f"Y for task `regression` must be numeric; got {set(df['Y'].apply(type))}."
                case 'binary':
                    assert all(df['Y'].isin([0, 1])), \
                        f"Y for task `binary` (classification) must be 0 or 1, but Y got {pd.unique(df['Y'])}." \
                        "\nYou may set `thresholds` to discretize continuous labels."
                case 'multiclass':
                    assert n_classes is not None and n_classes >= 3, \
                        "n_classes for task `multiclass` (classification) must be at least 3."
                    assert all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)), \
                        f"Y for task `multiclass` (classification) must be non-negative integers, " \
                        f"but Y got {pd.unique(df['Y'])}." \
                        "\nYou may set `thresholds` to discretize continuous labels."
                    target_n_unique = df['Y'].nunique()
                    assert target_n_unique == n_classes, \
                        f"You have set n_classes for task `multiclass` (classification) to {n_classes}, " \
                        f"but Y has {target_n_unique} unique labels."
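            # Note (assumed behaviour, not verified against label_transform): with task='binary'
            # and a single numeric threshold, e.g. thresholds=7.0, a continuous activity column
            # such as Y = [5.2, 7.9, 6.4] should be discretized to [0.0, 1.0, 0.0], which is
            # what the `binary` assertion above checks.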
        # # Predict mode data processing
        # else:
        #     df = pd.DataFrame(product(df['X1'].dropna(), df['X2'].dropna()), columns=['X1', 'X2'])
        #     if hasattr(self, "x1_to_id1"):
        #         df['ID1'] = df['X1'].map(self.x1_to_id1)
        #     if hasattr(self, "x1_to_id2"):
        #         df['ID2'] = df['X2'].map(self.x2_to_id2)

        # self.smiles = df['X1']
        # self.fasta = df['X2']
        # self.smiles_ids = df.get('ID1', df['X1'])
        # self.fasta_ids = df.get('ID2', df['X2'])
        # self.labels = df.get('Y', None)

        self.df = df
        # Fall back to identity featurizers so raw SMILES/FASTA strings pass through unchanged
        self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x)
        self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x)
        # In predict mode there may be no Y column, so fall back to the declared n_classes
        self.n_classes = df['Y'].nunique() if 'Y' in df else n_classes
        # self.train = train
        self.Data = namedtuple('Data', ['FT1', 'ID1', 'FT2', 'ID2', 'Y'])

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        sample = self.df.loc[idx]
        return self.Data(
            FT1=self.drug_featurizer(sample['X1']),
            ID1=sample.get('ID1', sample['X1']),
            FT2=self.protein_featurizer(sample['X2']),
            ID2=sample.get('ID2', sample['X2']),
            Y=sample.get('Y')
        )
        # {
        #     'FT1': self.drug_featurizer(sample['X1']),
        #     'ID1': sample.get('ID1', sample['X1']),
        #     'FT2': self.protein_featurizer(sample['X2']),
        #     'ID2': sample.get('ID2', sample['X2']),
        #     'Y': sample.get('Y')
        # }
        # if self.train:
        #     sample = self.drug_featurizer(self.smiles[idx]), self.protein_featurizer(self.fasta[idx]), self.labels[idx]
        #     sample = {
        #         'FT1': self.drug_featurizer(self.smiles[idx]),
        #         'FT2': self.protein_featurizer(self.fasta[idx]),
        #         'ID2': self.smiles_ids[idx],
        #     }
        # else:
        #     # sample = self.drug_featurizer(self.smiles[idx]), self.protein_featurizer(self.fasta[idx])
        #     sample = {
        #         'FT1': self.drug_featurizer(self.smiles[idx]),
        #         'FT2': self.protein_featurizer(self.fasta[idx]),
        #     }
        #
        # if all([True if n is not None else False for n in sample.values()]):
        #     return sample  # | {
        #     #     'ID1': self.smiles_ids[idx],
        #     #     'X1': self.drug_featurizer(self.smiles[idx]),
        #     #     'ID2': self.fasta_ids[idx],
        #     #     'X2': self.protein_featurizer(self.fasta[idx]),
        #     # }
        # else:
        #     return self.__getitem__(np.random.randint(0, self.size))


class DTIdatamodule(LightningDataModule):
    """
    DTI DataModule

    A DataModule implements 6 key methods:

        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc.

        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc.

        def train_dataloader(self):
            # return train dataloader

        def val_dataloader(self):
            # return validation dataloader

        def test_dataloader(self):
            # return test dataloader

        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.
    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html

    A minimal usage sketch is included under the ``if __name__ == "__main__"`` guard at the
    bottom of this module.
    """

    def __init__(
            self,
            task: Literal['regression', 'binary', 'multiclass'],
            n_classes: Optional[int],
            train: bool,
            drug_featurizer: callable,
            protein_featurizer: callable,
            batch_size: int,
            train_val_test_split: Optional[Sequence[Number]],
            num_workers: int = 0,
            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
            pin_memory: bool = False,
            data_dir: str = "data/",
            dataset_name: Optional[str] = None,
            split: Optional[callable] = random_split,
    ):
        super().__init__()

        # this line allows access to init params with the 'self.hparams' attribute
        # and also ensures init params will be stored in the checkpoint
        self.save_hyperparameters(logger=False)

        # data processing
        self.data_split = split

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None
        self.data_predict: Optional[Dataset] = None

    def prepare_data(self):
        """
        Download data if needed.
        Do not use it to assign state (e.g., self.x = x).
        """

    def setup(self, stage: Optional[str] = None, encoding: str = None):
        """
        Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
        This method is called by Lightning for both `trainer.fit()` and `trainer.test()`,
        so be careful not to execute data splitting twice.
        """
        # TODO: test SafeBatchSampler (which skips samples containing any None
        #  without introducing a variable batch size)

        # load and split datasets only if not loaded in initialization
        if not any([self.data_train, self.data_val, self.data_test, self.data_predict]):
            dataset = DTIDataset(
                task=self.hparams.task,
                n_classes=self.hparams.n_classes,
                data_dir=self.hparams.data_dir,
                drug_featurizer=self.hparams.drug_featurizer,
                protein_featurizer=self.hparams.protein_featurizer,
                dataset_name=self.hparams.dataset_name,
                thresholds=self.hparams.thresholds,
            )
            if self.hparams.train:
                self.data_train, self.data_val, self.data_test = self.data_split(
                    dataset=dataset,
                    lengths=self.hparams.train_val_test_split
                )
            else:
                self.data_test = self.data_predict = dataset

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
            batch_sampler=SafeBatchSampler(
                data_source=self.data_train,
                batch_size=self.hparams.batch_size,
                drop_last=True,
                shuffle=True,
            ),
            # batch_size=self.hparams.batch_size,
            # shuffle=True,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_val,
            batch_sampler=SafeBatchSampler(
                data_source=self.data_val,
                batch_size=self.hparams.batch_size,
                drop_last=False,
                shuffle=False,
            ),
            # batch_size=self.hparams.batch_size,
            # shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_test,
            batch_sampler=SafeBatchSampler(
                data_source=self.data_test,
                batch_size=self.hparams.batch_size,
                drop_last=False,
                shuffle=False,
            ),
            # batch_size=self.hparams.batch_size,
            # shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
        )

    def predict_dataloader(self):
        return DataLoader(
            dataset=self.data_predict,
            batch_sampler=SafeBatchSampler(
                data_source=self.data_predict,
                batch_size=self.hparams.batch_size,
                drop_last=False,
                shuffle=False,
            ),
            # batch_size=self.hparams.batch_size,
            # shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
        )

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass
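
# Minimal usage sketch (illustrative only, not part of the module's API): the dataset name,
# CSV location, threshold and split fractions below are hypothetical placeholders, and passing
# None featurizers assumes collate_fn can batch raw SMILES/FASTA strings; in practice, pass the
# project's drug and protein featurizers instead.
if __name__ == "__main__":
    dm = DTIdatamodule(
        task='binary',
        n_classes=None,
        train=True,
        drug_featurizer=None,        # falls back to an identity featurizer in DTIDataset
        protein_featurizer=None,     # falls back to an identity featurizer in DTIDataset
        batch_size=32,
        train_val_test_split=(0.8, 0.1, 0.1),  # fractional lengths require torch >= 1.13
        num_workers=0,
        thresholds=7.0,              # hypothetical cutoff to binarize continuous labels
        data_dir='data/',
        dataset_name='example_dti',  # hypothetical file: data/example_dti.csv
    )
    dm.setup(stage='fit')
    batch = next(iter(dm.train_dataloader()))
    print(batch)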