# NOTE: page-header residue captured during extraction ("Spaces: Running on CPU Upgrade"); not part of the module.
# from itertools import product
import os
from collections import namedtuple
from numbers import Number
from typing import Any, Dict, Optional, Sequence, Union, Literal

# import numpy as np
import pandas as pd
from lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader, random_split

from deepscreen.data.utils.label import label_transform
from deepscreen.data.utils.collator import collate_fn
from deepscreen.data.utils.sampler import SafeBatchSampler
class DTIDataset(Dataset): | |
def __init__( | |
self, | |
task: Literal['regression', 'binary', 'multiclass'], | |
n_classes: Optional[int], | |
data_dir: str, | |
dataset_name: str, | |
drug_featurizer: callable, | |
protein_featurizer: callable, | |
thresholds: Optional[Union[Number, Sequence[Number]]] = None, | |
discard_intermediate: Optional[bool] = False, | |
): | |
df = pd.read_csv( | |
f'{data_dir}{dataset_name}.csv', | |
header=0, sep=',', | |
usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'], | |
dtype={'X1': 'str', 'ID1': 'str', | |
'X2': 'str', 'ID2': 'str', | |
'Y': 'float32', 'U': 'str'} | |
) | |
# if 'ID1' in df: | |
# self.x1_to_id1 = dict(zip(df['X1'], df['ID1'])) | |
# if 'ID2' in df: | |
# self.x2_to_id2 = dict(zip(df['X2'], df['ID2'])) | |
# self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2'])))) | |
# self.x2_to_indexes = dict(zip(df['X2'], range(len(df['X2'])))) | |
# # train and eval mode data processing (fully labelled) | |
# if 'Y' in df.columns and df['Y'].notnull().all(): | |
# Forward-fill all non-label columns | |
df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0) | |
if 'Y' in df: | |
# Transform labels | |
df['Y'] = df['Y'].apply(label_transform, units=df.get('U', None), thresholds=thresholds, | |
discard_intermediate=discard_intermediate).astype('float32') | |
# Filter out rows with a NaN in Y (missing values) | |
df.dropna(subset=['Y'], inplace=True) | |
# Validate target labels for training/testing | |
# TODO: check sklearn.utils.multiclass.check_classification_targets | |
match task: | |
case 'regression': | |
assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \ | |
f"Y for task `regression` must be numeric; got {set(df['Y'].apply(type))}." | |
case 'binary': | |
assert all(df['Y'].isin([0, 1])), \ | |
f"Y for task `binary` (classification) must be 0 or 1, but Y got {pd.unique(df['Y'])}." \ | |
"\nYou may set `thresholds` to discretize continuous labels." | |
case 'multiclass': | |
assert n_classes >= 3, f'n_classes for task `multiclass` (classification) must be at least 3.' | |
assert all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)), \ | |
f"Y for task `multiclass` (classification) must be non-negative integers, " \ | |
f"but Y got {pd.unique(df['Y'])}." \ | |
"\nYou may set `thresholds` to discretize continuous labels." | |
target_n_unique = df['Y'].nunique() | |
assert target_n_unique == n_classes, \ | |
f"You have set n_classes for task `multiclass` (classification) task to {n_classes}, " \ | |
f"but Y has {target_n_unique} unique labels." | |
# # Predict mode data processing | |
# else: | |
# df = pd.DataFrame(product(df['X1'].dropna(), df['X2'].dropna()), columns=['X1', 'X2']) | |
# if hasattr(self, "x1_to_id1"): | |
# df['ID1'] = df['X1'].map(self.x1_to_id1) | |
# if hasattr(self, "x1_to_id2"): | |
# df['ID2'] = df['X2'].map(self.x2_to_id2) | |
# self.smiles = df['X1'] | |
# self.fasta = df['X2'] | |
# self.smiles_ids = df.get('ID1', df['X1']) | |
# self.fasta_ids = df.get('ID2', df['X2']) | |
# self.labels = df.get('Y', None) | |
self.df = df | |
self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x) | |
self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x) | |
self.n_classes = df['Y'].nunique() | |
# self.train = train | |
self.Data = namedtuple('Data', ['FT1', 'ID1', 'FT2', 'ID2', 'Y']) | |
def __len__(self): | |
return len(self.df.index) | |
def __getitem__(self, idx): | |
sample = self.df.loc[idx] | |
return self.Data( | |
FT1=self.drug_featurizer(sample['X1']), | |
ID1=sample.get('ID1', sample['X1']), | |
FT2=self.protein_featurizer(sample['X2']), | |
ID2=sample.get('ID2', sample['X2']), | |
Y=sample.get('Y') | |
) | |
# { | |
# 'FT1': self.drug_featurizer(sample['X1']), | |
# 'ID1': sample.get('ID1', sample['X1']), | |
# 'FT2': self.protein_featurizer(sample['X2']), | |
# 'ID2': sample.get('ID2', sample['X2']), | |
# 'Y': sample.get('Y') | |
# } | |
# if self.train: | |
# sample = self.drug_featurizer(self.smiles[idx]), self.protein_featurizer(self.fasta[idx]), self.labels[idx] | |
# sample = { | |
# 'FT1': self.drug_featurizer(self.smiles[idx]), | |
# 'FT2': self.protein_featurizer(self.fasta[idx]), | |
# 'ID2': self.smiles_ids[idx], | |
# } | |
# else: | |
# # sample = self.drug_featurizer(self.smiles[idx]), self.protein_featurizer(self.fasta[idx]) | |
# sample = { | |
# 'FT1': self.drug_featurizer(self.smiles[idx]), | |
# 'FT2': self.protein_featurizer(self.fasta[idx]), | |
# } | |
# | |
# if all([True if n is not None else False for n in sample.values()]): | |
# return sample # | { | |
# # 'ID1': self.smiles_ids[idx], | |
# # 'X1': self.drug_featurizer(self.smiles[idx]), | |
# # 'ID2': self.fasta_ids[idx], | |
# # 'X2': self.protein_featurizer(self.fasta[idx]), | |
# # } | |
# else: | |
# return self.__getitem__(np.random.randint(0, self.size)) | |
class DTIdatamodule(LightningDataModule):
    """
    DTI DataModule

    A DataModule implements 5 key methods:

        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc.
        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc.
        def train_dataloader(self):
            # return train dataloader
        def val_dataloader(self):
            # return validation dataloader
        def test_dataloader(self):
            # return test dataloader
        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html

    Args:
        task: Forwarded to ``DTIDataset``.
        n_classes: Forwarded to ``DTIDataset``.
        train: If true, ``setup`` splits the dataset into train/val/test;
            otherwise the whole dataset is used for both test and predict.
        drug_featurizer: Forwarded to ``DTIDataset``.
        protein_featurizer: Forwarded to ``DTIDataset``.
        batch_size: Batch size given to every ``SafeBatchSampler``.
        train_val_test_split: ``lengths`` argument passed to ``split``.
        num_workers: Number of ``DataLoader`` worker processes.
        thresholds: Forwarded to ``DTIDataset``.
        pin_memory: Forwarded to every ``DataLoader``.
        data_dir: Forwarded to ``DTIDataset``.
        dataset_name: Forwarded to ``DTIDataset``.
        split: Callable invoked as ``split(dataset=..., lengths=...)`` that
            returns the train, validation and test datasets.
    """

    def __init__(
            self,
            task: Literal['regression', 'binary', 'multiclass'],
            n_classes: Optional[int],
            train: bool,
            drug_featurizer: callable,
            protein_featurizer: callable,
            batch_size: int,
            train_val_test_split: Optional[Sequence[Number]],
            num_workers: int = 0,
            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
            pin_memory: bool = False,
            data_dir: str = "data/",
            dataset_name: Optional[str] = None,
            split: Optional[callable] = random_split,
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        # data processing
        self.data_split = split

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None
        self.data_predict: Optional[Dataset] = None

    def prepare_data(self):
        """
        Download data if needed.
        Do not use it to assign state (e.g., self.x = x).
        """

    def setup(self, stage: Optional[str] = None, encoding: str = None):
        """
        Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
        This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
        careful not to execute data splitting twice.

        `stage` and `encoding` are accepted but not used by this implementation.
        """
        # TODO test SafeBatchSampler (which skips samples with any None without introducing variable batch size)

        # Load and split only once. Compare against None explicitly: truth-testing
        # a dataset goes through __len__, so an empty but already loaded dataset
        # would otherwise look "not loaded" and trigger a second load.
        already_loaded = any(
            ds is not None
            for ds in (self.data_train, self.data_val, self.data_test, self.data_predict)
        )
        if already_loaded:
            return

        dataset = DTIDataset(
            task=self.hparams.task,
            n_classes=self.hparams.n_classes,
            data_dir=self.hparams.data_dir,
            drug_featurizer=self.hparams.drug_featurizer,
            protein_featurizer=self.hparams.protein_featurizer,
            dataset_name=self.hparams.dataset_name,
            thresholds=self.hparams.thresholds,
        )

        if self.hparams.train:
            self.data_train, self.data_val, self.data_test = self.data_split(
                dataset=dataset,
                lengths=self.hparams.train_val_test_split
            )
        else:
            self.data_test = self.data_predict = dataset

    def _make_dataloader(self, dataset, *, shuffle: bool, drop_last: bool):
        """Build a ``DataLoader`` over ``dataset`` using the shared settings."""
        return DataLoader(
            dataset=dataset,
            batch_sampler=SafeBatchSampler(
                data_source=dataset,
                batch_size=self.hparams.batch_size,
                drop_last=drop_last,
                shuffle=shuffle,
            ),
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            # DataLoader only accepts persistent_workers=True with num_workers > 0.
            persistent_workers=self.hparams.num_workers > 0
        )

    def train_dataloader(self):
        """Return the training loader (shuffled, incomplete last batch dropped)."""
        return self._make_dataloader(self.data_train, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Return the validation loader (original order, all batches kept)."""
        return self._make_dataloader(self.data_val, shuffle=False, drop_last=False)

    def test_dataloader(self):
        """Return the test loader (original order, all batches kept)."""
        return self._make_dataloader(self.data_test, shuffle=False, drop_last=False)

    def predict_dataloader(self):
        """Return the prediction loader (original order, all batches kept)."""
        return self._make_dataloader(self.data_predict, shuffle=False, drop_last=False)

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass