import re
from functools import partial
from numbers import Number
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Union, Literal

from lightning import LightningDataModule
import pandas as pd
from pandarallel import pandarallel
from rdkit import Chem
# import swifter
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader

from deepscreen.data.utils import label_transform, collate_fn, SafeBatchSampler
from deepscreen.utils import get_logger

log = get_logger(__name__)

pandarallel.initialize(progress_bar=True)

SMILES_PAT = r"[^A-Za-z0-9=#:+\-\[\]<>()/\\@%,.*]"
FASTA_PAT = r"[^A-Z*\-]"
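
# Illustrative sketch (not part of the original pipeline): both patterns above are negated
# character classes, i.e. whitelists. A match means the sequence contains a character
# outside the allowed set. The inputs below are made-up examples.
def _example_whitelist_patterns():
    # Every character of this SMILES is whitelisted, so nothing matches.
    assert re.findall(SMILES_PAT, 'CC(=O)Oc1ccccc1C(=O)O') == []
    # '?' and '!' are outside the SMILES whitelist and are therefore reported.
    assert set(re.findall(SMILES_PAT, 'CC(=O)O?!')) == {'?', '!'}
    # Lower-case letters are outside the FASTA whitelist (sequences are upper-cased upstream).
    assert set(re.findall(FASTA_PAT, 'MKTAYIAKQRx')) == {'x'}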
def validate_seq_str(seq, regex):
    """Return None if `seq` contains only allowed characters, a comma-separated string of
    offending characters if it does not, or 'Empty string' if `seq` is falsy."""
    if seq:
        err_charset = set(re.findall(regex, seq))
        if not err_charset:
            return None
        else:
            return ', '.join(err_charset)
    else:
        return 'Empty string'
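
# Illustrative sketch (made-up inputs): the three possible return values of
# `validate_seq_str` -- None for a clean sequence, the offending characters for an
# invalid one, and 'Empty string' for a falsy input.
def _example_validate_seq_str():
    assert validate_seq_str('MKTAYIAKQR', FASTA_PAT) is None
    assert validate_seq_str('', FASTA_PAT) == 'Empty string'
    # The error string is built from a set, so the character order may vary.
    assert set(validate_seq_str('MKTAYIAKQRxz', FASTA_PAT).split(', ')) == {'x', 'z'}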
# TODO: save a list of corrupted records
def rdkit_canonicalize(smiles):
    """Canonicalize a SMILES string with RDKit, returning the original string if parsing fails."""
    try:
        mol = Chem.MolFromSmiles(smiles)
        smiles = Chem.MolToSmiles(mol)
    except Exception as e:
        log.warning(f'Failed to canonicalize SMILES using RDKit due to {str(e)}. '
                    f'Returning original SMILES: {smiles}')
    return smiles
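
# Illustrative sketch (assumed SMILES, not from the original code): RDKit maps equivalent
# SMILES spellings to one canonical form, and the helper falls back to returning the
# input unchanged (with a logged warning) when parsing fails.
def _example_rdkit_canonicalize():
    # Two spellings of ethanol should canonicalize to the same string.
    assert rdkit_canonicalize('OCC') == rdkit_canonicalize('CCO')
    # An unparsable string is returned as-is.
    assert rdkit_canonicalize('not_a_smiles') == 'not_a_smiles'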
class DTIDataset(Dataset):
    def __init__(
            self,
            task: Literal['regression', 'binary', 'multiclass'],
            num_classes: Optional[int],
            data_path: str | Path,
            drug_featurizer: callable,
            protein_featurizer: callable,
            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
            discard_intermediate: Optional[bool] = False,
            query: Optional[str] = 'X2'
    ):
        # Read the whole data table
        df = pd.read_csv(
            data_path,
            engine='python',
            header=0,
            usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'],
            dtype={
                'X1': 'str',
                'ID1': 'str',
                'X2': 'str',
                'ID2': 'str',
                'Y': 'float32',
                'U': 'str',
            },
        )
        # if 'ID1' in df:
        #     self.x1_to_id1 = dict(zip(df['X1'], df['ID1']))
        # if 'ID2' in df:
        #     self.x2_to_id2 = dict(zip(df['X2'], df['ID2']))
        #     self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2']))))
        #     self.x2_to_indexes = dict(zip(df['X2'], range(len(df['X2']))))

        # # train and eval mode data processing (fully labelled)
        # if 'Y' in df.columns and df['Y'].notnull().all():
        log.info(f"Processing data file: {data_path}")

        # Forward-fill all non-label columns
        df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0)

        # Fill NAs in string cols with an empty string to prevent wrong type inference by the pytorch collator
        for col in df.columns:
            if df[col].dtype == 'object':
                df[col] = df[col].fillna('')
        # TODO: potentially allow running through the whole data validation process
        # error = False
        if 'Y' in df:
            log.info("Validating labels (`Y`)...")
            # TODO: check sklearn.utils.multiclass.check_classification_targets
            match task:
                case 'regression':
                    assert all(df['Y'].parallel_apply(lambda x: isinstance(x, Number))), \
                        f"""`Y` must be numeric for the `regression` task,
                        but it has {set(df['Y'].parallel_apply(type))}."""
                case 'binary':
                    if all(df['Y'].isin([0, 1])):
                        assert not thresholds, \
                            f"""`Y` is already 0 or 1 for the `binary` (classification) `task`,
                            but `thresholds` ({thresholds}) was still given.
                            Double-check your choices of `task` and `thresholds` and the records in the `Y` column."""
                    else:
                        assert thresholds, \
                            f"""`Y` must be 0 or 1 for the `binary` (classification) `task`,
                            but it has {pd.unique(df['Y'])}.
                            You may set `thresholds` to discretize continuous labels."""  # TODO: print err idx instead
                case 'multiclass':
                    assert num_classes >= 3, '`num_classes` must be at least 3 for `task=multiclass`.'
                    if all(df['Y'].parallel_apply(lambda x: x.is_integer() and x >= 0)):
                        assert not thresholds, \
                            f"""`Y` already contains non-negative integers for the
                            `multiclass` (classification) `task`, but `thresholds` ({thresholds}) was still given.
                            Double-check your choices of `task` and `thresholds` and the records in the `Y` column."""
                    else:
                        assert thresholds, \
                            f"""`Y` must be non-negative integers for the
                            `multiclass` (classification) `task`, but it has {pd.unique(df['Y'])}.
                            You must set `thresholds` to discretize continuous labels."""  # TODO: print err idx instead
            if 'U' in df.columns:
                units = df['U']
            else:
                units = None
                log.warning("Units ('U') not in the data table. "
                            "Assuming all labels are discrete or in p-scale (-log10[M]).")

            # Transform labels
            df['Y'] = label_transform(labels=df['Y'], units=units, thresholds=thresholds,
                                      discard_intermediate=discard_intermediate)

            # Filter out rows with a NaN in Y (missing values)
            df.dropna(subset=['Y'], inplace=True)
            match task:
                case 'regression':
                    df['Y'] = df['Y'].astype('float32')
                    assert all(df['Y'].parallel_apply(lambda x: isinstance(x, Number))), \
                        f"""`Y` must be numeric for the `regression` task,
                        but after transformation it still has {set(df['Y'].parallel_apply(type))}.
                        Double-check your choices of `task` and `thresholds` and the records in the `Y` and `U` columns."""
                    # TODO: print err idx instead
                case 'binary':
                    df['Y'] = df['Y'].astype('int')
                    assert all(df['Y'].isin([0, 1])), \
                        f"""`Y` must be 0 or 1 for `task=binary`,
                        but after transformation it still has {pd.unique(df['Y'])}.
                        Double-check your choices of `task` and `thresholds` and the records in the `Y` and `U` columns."""
                    # TODO: print err idx instead
                case 'multiclass':
                    df['Y'] = df['Y'].astype('int')
                    # Values are already integers after the cast above; only non-negativity still needs checking.
                    assert all(df['Y'] >= 0), \
                        f"""`Y` must be non-negative integers for `task=multiclass`,
                        but after transformation it still has {pd.unique(df['Y'])}.
                        Double-check your choices of `task` and `thresholds` and the records in the `Y` and `U` columns."""
                    # TODO: print err idx instead
                    target_n_unique = df['Y'].nunique()
                    assert target_n_unique == num_classes, \
                        f"""You have set `num_classes` for `task=multiclass` to {num_classes},
                        but after transformation `Y` still has {target_n_unique} unique labels.
                        Double-check your choices of `task` and `thresholds` and the records in the `Y` and `U` columns."""
log.info("Validating SMILES (`X1`)...") | |
df['X1_ERR'] = df['X1'].parallel_apply(validate_seq_str, regex=SMILES_PAT) | |
if not df['X1_ERR'].isna().all(): | |
raise Exception(f"Encountered invalid SMILES:\n{df[~df['X1_ERR'].isna()][['X1', 'X1_ERR']]}") | |
df['X1^'] = df['X1'].parallel_apply(rdkit_canonicalize) | |
log.info("Validating FASTA (`X2`)...") | |
df['X2'] = df['X2'].str.upper() | |
df['X2_ERR'] = df['X2'].parallel_apply(validate_seq_str, regex=FASTA_PAT) | |
if not df['X2_ERR'].isna().all(): | |
raise Exception(f"Encountered invalid FASTA:\n{df[~df['X2_ERR'].isna()][['X2', 'X2_ERR']]}") | |
# FASTA/SMILES indices as query for retrieval metrics like enrichment factor and hit rate | |
if query: | |
df['ID^'] = LabelEncoder().fit_transform(df[query]) | |
self.df = df | |
self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x) | |
self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x) | |
    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, i):
        sample = self.df.loc[i]
        sample_dict = {
            'N': i,
            'X1': sample['X1'],
            'X1^': self.drug_featurizer(sample['X1^']),
            # 'ID1': sample.get('ID1'),
            'X2': sample['X2'],
            'X2^': self.protein_featurizer(sample['X2']),
            # 'ID2': sample.get('ID2'),
            # 'Y': sample.get('Y'),
            # 'ID^': sample.get('ID^'),
        }
        optional_keys = ['ID1', 'ID2', 'ID^', 'Y']
        sample_dict.update({key: sample[key] for key in optional_keys if sample.get(key) is not None})
        return sample_dict
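
# Illustrative sketch of constructing the dataset (the CSV path below is a hypothetical
# placeholder; real runs use the project's configured featurizers rather than None):
def _example_dti_dataset():
    dataset = DTIDataset(
        task='binary',
        num_classes=None,
        data_path='data/example_dti.csv',  # hypothetical CSV with X1 (SMILES), X2 (FASTA) and Y columns
        drug_featurizer=None,       # None falls back to the identity function
        protein_featurizer=None,    # None falls back to the identity function
    )
    # Each sample is a dict with keys like 'N', 'X1', 'X1^', 'X2', 'X2^', plus 'Y'/'ID^' when present.
    return len(dataset), dataset[0]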
class DTIDataModule(LightningDataModule):
    """
    DTI DataModule.

    A DataModule implements 5 key methods:

        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc.
        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc.
        def train_dataloader(self):
            # return train dataloader
        def val_dataloader(self):
            # return validation dataloader
        def test_dataloader(self):
            # return test dataloader
        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
    """
    def __init__(
            self,
            task: Literal['regression', 'binary', 'multiclass'],
            num_classes: Optional[int],
            batch_size: int,
            # train: bool,
            drug_featurizer: callable,
            protein_featurizer: callable,
            collator: callable = collate_fn,
            data_dir: str = "data/",
            data_file: Optional[str] = None,
            train_val_test_split: Optional[Sequence[Number | str]] = None,
            split: Optional[callable] = None,
            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
            discard_intermediate: Optional[bool] = False,
            query: Optional[str] = 'X2',
            num_workers: int = 0,
            pin_memory: bool = False,
    ):
        super().__init__()

        self.train_data: Optional[Dataset] = None
        self.val_data: Optional[Dataset] = None
        self.test_data: Optional[Dataset] = None
        self.predict_data: Optional[Dataset] = None

        self.split = split
        self.collator = collator
        self.dataset = partial(
            DTIDataset,
            task=task,
            num_classes=num_classes,
            drug_featurizer=drug_featurizer,
            protein_featurizer=protein_featurizer,
            thresholds=thresholds,
            discard_intermediate=discard_intermediate,
            query=query
        )

        # Saving init params to `self.hparams` also ensures they are stored in checkpoints.
        self.save_hyperparameters(logger=False)  # ignore=['split']
    def prepare_data(self):
        """
        Download data if needed.
        Do not use it to assign state (e.g., self.x = x).
        """

    def setup(self, stage: Optional[str] = None, encoding: Optional[str] = None):
        """
        Load data. Set variables: `self.train_data`, `self.val_data`, `self.test_data`.
        This method is called by Lightning for both `trainer.fit()` and `trainer.test()`, so be
        careful not to execute data splitting twice.
        """
        # Load and split datasets only if they were not loaded during initialization
        if not any([self.train_data, self.test_data, self.val_data, self.predict_data]):
            if self.hparams.train_val_test_split:
                if len(self.hparams.train_val_test_split) != 3:
                    raise ValueError('Length of `train_val_test_split` must be 3. '
                                     'Set the second element to None for training without validation. '
                                     'Set the third element to None for training without testing.')
                self.train_data = self.hparams.train_val_test_split[0]
                self.val_data = self.hparams.train_val_test_split[1]
                self.test_data = self.hparams.train_val_test_split[2]

                if all([self.hparams.data_file, self.split]):
                    if all(isinstance(split, Number) or split is None
                           for split in self.hparams.train_val_test_split):
                        split_data = self.split(
                            dataset=self.dataset(data_path=Path(self.hparams.data_dir, self.hparams.data_file)),
                            lengths=[split for split in self.hparams.train_val_test_split if split is not None]
                        )
                        for dataset in ['train_data', 'val_data', 'test_data']:
                            if getattr(self, dataset) is not None:
                                setattr(self, dataset, split_data.pop(0))
                    else:
                        raise ValueError('`train_val_test_split` must be a sequence of numbers or None '
                                         '(float for fractions, int for sample counts) '
                                         'if both `data_file` and `split` have been specified.')
                elif (all(isinstance(split, str) or split is None
                          for split in self.hparams.train_val_test_split)
                      and not any([self.hparams.data_file, self.split])):
                    for dataset in ['train_data', 'val_data', 'test_data']:
                        if getattr(self, dataset) is not None:
                            data_path = Path(getattr(self, dataset))
                            if not data_path.is_absolute():
                                data_path = Path(self.hparams.data_dir, data_path)
                            setattr(self, dataset, self.dataset(data_path=data_path))
                else:
                    raise ValueError('For training, you must specify either all of `data_file`, `split` '
                                     'and `train_val_test_split` as a sequence of numbers, or '
                                     'solely `train_val_test_split` as a sequence of data file paths.')
            elif self.hparams.data_file and not any([self.split, self.hparams.train_val_test_split]):
                data_path = Path(self.hparams.data_file)
                if not data_path.is_absolute():
                    data_path = Path(self.hparams.data_dir, data_path)
                self.test_data = self.predict_data = self.dataset(data_path=data_path)
            else:
                raise ValueError("For training, you must specify `train_val_test_split`. "
                                 "For testing/predicting, you must specify only `data_file` without "
                                 "`train_val_test_split` or `split`.")
    def train_dataloader(self):
        return DataLoader(
            dataset=self.train_data,
            batch_sampler=SafeBatchSampler(
                data_source=self.train_data,
                batch_size=self.hparams.batch_size,
                # Dropping the last batch prevents problems caused by variable batch sizes in training (e.g.,
                # batch_size=1 breaks BatchNorm), and shuffling ensures the model is trained on all samples over epochs.
                drop_last=True,
                shuffle=True,
            ),
            # batch_size=self.hparams.batch_size,
            # shuffle=True,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=self.collator,
            persistent_workers=self.hparams.num_workers > 0
        )
    def val_dataloader(self):
        return DataLoader(
            dataset=self.val_data,
            batch_sampler=SafeBatchSampler(
                data_source=self.val_data,
                batch_size=self.hparams.batch_size,
                drop_last=False,
                shuffle=False
            ),
            # batch_size=self.hparams.batch_size,
            # shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=self.collator,
            persistent_workers=self.hparams.num_workers > 0
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.test_data,
            batch_sampler=SafeBatchSampler(
                data_source=self.test_data,
                batch_size=self.hparams.batch_size,
                drop_last=False,
                shuffle=False
            ),
            # batch_size=self.hparams.batch_size,
            # shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=self.collator,
            persistent_workers=self.hparams.num_workers > 0
        )

    def predict_dataloader(self):
        return DataLoader(
            dataset=self.predict_data,
            batch_sampler=SafeBatchSampler(
                data_source=self.predict_data,
                batch_size=self.hparams.batch_size,
                drop_last=False,
                shuffle=False
            ),
            # batch_size=self.hparams.batch_size,
            # shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=self.collator,
            persistent_workers=self.hparams.num_workers > 0
        )
    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass
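
# Illustrative sketch of wiring the DataModule (the CSV file names below are hypothetical
# placeholders; in the project these arguments normally come from Hydra configs):
def _example_dti_datamodule():
    datamodule = DTIDataModule(
        task='binary',
        num_classes=None,
        batch_size=32,
        drug_featurizer=None,       # None falls back to identity featurizers in DTIDataset
        protein_featurizer=None,
        data_dir='data/',
        # Pre-split CSVs: paths are resolved relative to `data_dir` unless absolute.
        train_val_test_split=('train.csv', 'val.csv', 'test.csv'),
        num_workers=0,
    )
    datamodule.setup(stage='fit')
    return datamodule.train_dataloader()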