# from itertools import product
from collections import namedtuple
from numbers import Number
from typing import Any, Dict, Optional, Sequence, Union, Literal
# import numpy as np
import pandas as pd
from lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader, random_split
from deepscreen.data.utils.label import label_transform
from deepscreen.data.utils.collator import collate_fn
from deepscreen.data.utils.sampler import SafeBatchSampler
class DTIDataset(Dataset):
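    """Drug-target interaction dataset backed by a single CSV file.

    Each row pairs a drug (string in column ``X1``, e.g. SMILES, with optional
    identifier ``ID1``) and a protein (string in column ``X2``, e.g. a FASTA
    sequence, with optional identifier ``ID2``), plus an optional interaction
    label ``Y`` and label unit ``U``.
    """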
def __init__(
self,
task: Literal['regression', 'binary', 'multiclass'],
n_classes: Optional[int],
data_dir: str,
dataset_name: str,
drug_featurizer: callable,
protein_featurizer: callable,
thresholds: Optional[Union[Number, Sequence[Number]]] = None,
discard_intermediate: Optional[bool] = False,
):
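        """
        Args:
            task: Learning task; determines how the labels in ``Y`` are validated.
            n_classes: Number of classes; only checked for the ``multiclass`` task.
            data_dir: Directory containing the dataset CSV.
            dataset_name: CSV file name without the ``.csv`` extension.
            drug_featurizer: Callable applied to each drug string (``X1``); identity if None.
            protein_featurizer: Callable applied to each protein string (``X2``); identity if None.
            thresholds: Optional threshold(s) passed to ``label_transform`` to discretize continuous labels.
            discard_intermediate: Passed to ``label_transform`` together with ``thresholds``.
        """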
df = pd.read_csv(
f'{data_dir}{dataset_name}.csv',
header=0, sep=',',
usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'],
dtype={'X1': 'str', 'ID1': 'str',
'X2': 'str', 'ID2': 'str',
'Y': 'float32', 'U': 'str'}
)
# if 'ID1' in df:
# self.x1_to_id1 = dict(zip(df['X1'], df['ID1']))
# if 'ID2' in df:
# self.x2_to_id2 = dict(zip(df['X2'], df['ID2']))
# self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2']))))
# self.x2_to_indexes = dict(zip(df['X2'], range(len(df['X2']))))
# # train and eval mode data processing (fully labelled)
# if 'Y' in df.columns and df['Y'].notnull().all():
# Forward-fill all non-label columns
df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0)
if 'Y' in df:
# Transform labels
df['Y'] = df['Y'].apply(label_transform, units=df.get('U', None), thresholds=thresholds,
discard_intermediate=discard_intermediate).astype('float32')
# Filter out rows with a NaN in Y (missing values)
df.dropna(subset=['Y'], inplace=True)
# Validate target labels for training/testing
# TODO: check sklearn.utils.multiclass.check_classification_targets
match task:
case 'regression':
assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \
f"Y for task `regression` must be numeric; got {set(df['Y'].apply(type))}."
case 'binary':
assert all(df['Y'].isin([0, 1])), \
f"Y for task `binary` (classification) must be 0 or 1, but Y got {pd.unique(df['Y'])}." \
"\nYou may set `thresholds` to discretize continuous labels."
case 'multiclass':
                    assert n_classes is not None and n_classes >= 3, \
                        f"n_classes for task `multiclass` (classification) must be at least 3; got {n_classes}."
assert all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)), \
f"Y for task `multiclass` (classification) must be non-negative integers, " \
f"but Y got {pd.unique(df['Y'])}." \
"\nYou may set `thresholds` to discretize continuous labels."
target_n_unique = df['Y'].nunique()
assert target_n_unique == n_classes, \
f"You have set n_classes for task `multiclass` (classification) task to {n_classes}, " \
f"but Y has {target_n_unique} unique labels."
# # Predict mode data processing
# else:
# df = pd.DataFrame(product(df['X1'].dropna(), df['X2'].dropna()), columns=['X1', 'X2'])
# if hasattr(self, "x1_to_id1"):
# df['ID1'] = df['X1'].map(self.x1_to_id1)
# if hasattr(self, "x1_to_id2"):
# df['ID2'] = df['X2'].map(self.x2_to_id2)
# self.smiles = df['X1']
# self.fasta = df['X2']
# self.smiles_ids = df.get('ID1', df['X1'])
# self.fasta_ids = df.get('ID2', df['X2'])
# self.labels = df.get('Y', None)
self.df = df
self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x)
self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x)
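        # Number of distinct labels observed in the data; falls back to the `n_classes` argument
        # when the data is unlabelled (e.g. predict mode).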
        self.n_classes = df['Y'].nunique() if 'Y' in df else n_classes
# self.train = train
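        # Lightweight, immutable container for a single (drug, protein, label) sample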
self.Data = namedtuple('Data', ['FT1', 'ID1', 'FT2', 'ID2', 'Y'])
def __len__(self):
return len(self.df.index)
def __getitem__(self, idx):
        sample = self.df.iloc[idx]  # positional indexing: row labels may be non-contiguous after dropna
return self.Data(
FT1=self.drug_featurizer(sample['X1']),
ID1=sample.get('ID1', sample['X1']),
FT2=self.protein_featurizer(sample['X2']),
ID2=sample.get('ID2', sample['X2']),
Y=sample.get('Y')
)
# {
# 'FT1': self.drug_featurizer(sample['X1']),
# 'ID1': sample.get('ID1', sample['X1']),
# 'FT2': self.protein_featurizer(sample['X2']),
# 'ID2': sample.get('ID2', sample['X2']),
# 'Y': sample.get('Y')
# }
# if self.train:
# sample = self.drug_featurizer(self.smiles[idx]), self.protein_featurizer(self.fasta[idx]), self.labels[idx]
# sample = {
# 'FT1': self.drug_featurizer(self.smiles[idx]),
# 'FT2': self.protein_featurizer(self.fasta[idx]),
# 'ID2': self.smiles_ids[idx],
# }
# else:
# # sample = self.drug_featurizer(self.smiles[idx]), self.protein_featurizer(self.fasta[idx])
# sample = {
# 'FT1': self.drug_featurizer(self.smiles[idx]),
# 'FT2': self.protein_featurizer(self.fasta[idx]),
# }
#
# if all([True if n is not None else False for n in sample.values()]):
# return sample # | {
# # 'ID1': self.smiles_ids[idx],
# # 'X1': self.drug_featurizer(self.smiles[idx]),
# # 'ID2': self.fasta_ids[idx],
# # 'X2': self.protein_featurizer(self.fasta[idx]),
# # }
# else:
# return self.__getitem__(np.random.randint(0, self.size))
class DTIdatamodule(LightningDataModule):
"""
DTI DataModule
A DataModule implements 5 key methods:
def prepare_data(self):
# things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
# download data, pre-process, split, save to disk, etc.
def setup(self, stage):
# things to do on every process in DDP
# load data, set variables, etc.
def train_dataloader(self):
# return train dataloader
def val_dataloader(self):
# return validation dataloader
def test_dataloader(self):
# return test dataloader
def teardown(self):
# called on every process in DDP
# clean up after fit or test
This allows you to share a full dataset without explaining how to download,
split, transform and process the data.
Read the docs:
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
"""
def __init__(
self,
task: Literal['regression', 'binary', 'multiclass'],
n_classes: Optional[int],
train: bool,
drug_featurizer: callable,
protein_featurizer: callable,
batch_size: int,
train_val_test_split: Optional[Sequence[Number]],
num_workers: int = 0,
thresholds: Optional[Union[Number, Sequence[Number]]] = None,
pin_memory: bool = False,
data_dir: str = "data/",
dataset_name: Optional[str] = None,
split: Optional[callable] = random_split,
):
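        """
        Args:
            task: Learning task; passed through to ``DTIDataset``.
            n_classes: Number of classes for the ``multiclass`` task.
            train: If True, split the dataset into train/val/test sets; otherwise use it for test/predict.
            drug_featurizer: Callable applied to each drug string.
            protein_featurizer: Callable applied to each protein string.
            batch_size: Number of samples per batch.
            train_val_test_split: Lengths passed to the ``split`` callable.
            num_workers: Number of dataloader worker processes.
            thresholds: Optional threshold(s) used to discretize continuous labels.
            pin_memory: Whether dataloaders should pin host memory.
            data_dir: Directory containing the dataset CSV.
            dataset_name: CSV file name without the ``.csv`` extension.
            split: Callable used to split the dataset (defaults to ``torch.utils.data.random_split``).
        """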
super().__init__()
        # this line allows accessing init params with the `self.hparams` attribute
        # and also ensures init params will be stored in ckpt
self.save_hyperparameters(logger=False)
# data processing
self.data_split = split
self.data_train: Optional[Dataset] = None
self.data_val: Optional[Dataset] = None
self.data_test: Optional[Dataset] = None
self.data_predict: Optional[Dataset] = None
def prepare_data(self):
"""
Download data if needed.
Do not use it to assign state (e.g., self.x = x).
"""
    def setup(self, stage: Optional[str] = None, encoding: Optional[str] = None):
"""
Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
        This method is called by Lightning with both `trainer.fit()` and `trainer.test()`, so be
        careful not to execute data splitting twice.
"""
# TODO test SafeBatchSampler (which skips samples with any None without introducing variable batch size)
# load and split datasets only if not loaded in initialization
if not any([self.data_train, self.data_val, self.data_test, self.data_predict]):
dataset = DTIDataset(
task=self.hparams.task,
n_classes=self.hparams.n_classes,
data_dir=self.hparams.data_dir,
drug_featurizer=self.hparams.drug_featurizer,
protein_featurizer=self.hparams.protein_featurizer,
dataset_name=self.hparams.dataset_name,
thresholds=self.hparams.thresholds,
)
if self.hparams.train:
self.data_train, self.data_val, self.data_test = self.data_split(
dataset=dataset,
lengths=self.hparams.train_val_test_split
)
else:
self.data_test = self.data_predict = dataset
def train_dataloader(self):
return DataLoader(
dataset=self.data_train,
batch_sampler=SafeBatchSampler(
data_source=self.data_train,
batch_size=self.hparams.batch_size,
drop_last=True,
shuffle=True,
),
# batch_size=self.hparams.batch_size,
# shuffle=True,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
)
def val_dataloader(self):
return DataLoader(
dataset=self.data_val,
batch_sampler=SafeBatchSampler(
data_source=self.data_val,
batch_size=self.hparams.batch_size,
drop_last=False,
shuffle=False,
),
# batch_size=self.hparams.batch_size,
# shuffle=False,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
)
def test_dataloader(self):
return DataLoader(
dataset=self.data_test,
batch_sampler=SafeBatchSampler(
data_source=self.data_test,
batch_size=self.hparams.batch_size,
drop_last=False,
shuffle=False,
),
# batch_size=self.hparams.batch_size,
# shuffle=False,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
)
def predict_dataloader(self):
return DataLoader(
dataset=self.data_predict,
batch_sampler=SafeBatchSampler(
data_source=self.data_predict,
batch_size=self.hparams.batch_size,
drop_last=False,
shuffle=False,
),
# batch_size=self.hparams.batch_size,
# shuffle=False,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
collate_fn=collate_fn,
            persistent_workers=self.hparams.num_workers > 0
)
def teardown(self, stage: Optional[str] = None):
"""Clean up after fit or test."""
pass
def state_dict(self):
"""Extra things to save to checkpoint."""
return {}
def load_state_dict(self, state_dict: Dict[str, Any]):
"""Things to do when loading checkpoint."""
pass
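

# Example usage (a minimal sketch, not executed by this module). It assumes a CSV at
# `data/example_dti.csv` with columns X1, X2 and Y, relies on the identity fallback for
# featurizers, and uses the default `random_split`; real runs should pass proper
# drug/protein featurizer callables.
#
#     dm = DTIdatamodule(
#         task='binary',
#         n_classes=2,
#         train=True,
#         drug_featurizer=None,            # DTIDataset falls back to identity
#         protein_featurizer=None,         # DTIDataset falls back to identity
#         batch_size=32,
#         train_val_test_split=(0.8, 0.1, 0.1),
#         thresholds=7.0,                  # example threshold for binarizing continuous labels
#         data_dir='data/',
#         dataset_name='example_dti',      # hypothetical file: data/example_dti.csv
#     )
#     dm.setup('fit')
#     batch = next(iter(dm.train_dataloader()))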