Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
from functools import partial | |
from numbers import Number | |
from pathlib import Path | |
from typing import Any, Dict, Optional, Sequence, Union, Literal | |
from lightning import LightningDataModule | |
import pandas as pd | |
from sklearn.preprocessing import LabelEncoder | |
from torch.utils.data import Dataset, DataLoader | |
from deepscreen.data.utils import label_transform, collate_fn, SafeBatchSampler | |
from deepscreen.utils import get_logger | |
log = get_logger(__name__) | |
# TODO: save a list of corrupted records | |
class DTIDataset(Dataset): | |
def __init__( | |
self, | |
task: Literal['regression', 'binary', 'multiclass'], | |
n_class: Optional[int], | |
data_path: str | Path, | |
drug_featurizer: callable, | |
protein_featurizer: callable, | |
thresholds: Optional[Union[Number, Sequence[Number]]] = None, | |
discard_intermediate: Optional[bool] = False, | |
): | |
df = pd.read_csv( | |
data_path, | |
engine='python', | |
header=0, | |
usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'], | |
dtype={ | |
'X1': 'str', | |
'ID1': 'str', | |
'X2': 'str', | |
'ID2': 'str', | |
'Y': 'float32', | |
'U': 'str', | |
}, | |
) | |
# Read the whole data table | |
# if 'ID1' in df: | |
# self.x1_to_id1 = dict(zip(df['X1'], df['ID1'])) | |
# if 'ID2' in df: | |
# self.x2_to_id2 = dict(zip(df['X2'], df['ID2'])) | |
# self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2'])))) | |
# self.x2_to_indexes = dict(zip(df['X2'], range(len(df['X2'])))) | |
# # train and eval mode data processing (fully labelled) | |
# if 'Y' in df.columns and df['Y'].notnull().all(): | |
log.info(f"Processing data file: {data_path}") | |
# Forward-fill all non-label columns | |
df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0) | |
if 'Y' in df: | |
log.info(f"Performing pre-transformation target validation.") | |
# TODO: check sklearn.utils.multiclass.check_classification_targets | |
match task: | |
case 'regression': | |
assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \ | |
f"""`Y` must be numeric for `regression` task, | |
but it has {set(df['Y'].apply(type))}.""" | |
case 'binary': | |
if all(df['Y'].isin([0, 1])): | |
assert not thresholds, \ | |
f"""`Y` is already 0 or 1 for `binary` (classification) `task`, | |
but still got `thresholds` {thresholds}. | |
Double check your choices of `task` and `thresholds` and records in the `Y` column.""" | |
else: | |
assert thresholds, \ | |
f"""`Y` must be 0 or 1 for `binary` (classification) `task`, | |
but it has {pd.unique(df['Y'])}. | |
You must set `thresholds` to discretize continuous labels.""" | |
case 'multiclass': | |
assert n_class >= 3, f'`n_class` for `multiclass` (classification) `task` must be at least 3.' | |
if all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)): | |
assert not thresholds, \ | |
f"""`Y` is already non-negative integers for | |
`multiclass` (classification) `task`, but still got `thresholds` {thresholds}. | |
Double check your choice of `task`, `thresholds` and records in the `Y` column.""" | |
else: | |
assert thresholds, \ | |
f"""`Y` must be non-negative integers for | |
`multiclass` (classification) 'task',but it has {pd.unique(df['Y'])}. | |
You must set `thresholds` to discretize continuous labels.""" | |
if 'U' in df.columns: | |
units = df['U'] | |
else: | |
units = None | |
log.warning("Units ('U') not in the data table. " | |
"Assuming all labels to be discrete or in p-scale (-log10[M]).") | |
# Transform labels | |
df['Y'] = label_transform(labels=df['Y'], units=units, thresholds=thresholds, | |
discard_intermediate=discard_intermediate) | |
# Filter out rows with a NaN in Y (missing values) | |
df.dropna(subset=['Y'], inplace=True) | |
log.info(f"Performing post-transformation target validation.") | |
match task: | |
case 'regression': | |
df['Y'] = df['Y'].astype('float32') | |
assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \ | |
f"""`Y` must be numeric for `regression` task, | |
but after transformation it still has {set(df['Y'].apply(type))}. | |
Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" | |
case 'binary': | |
df['Y'] = df['Y'].astype('int') | |
assert all(df['Y'].isin([0, 1])), \ | |
f"""`Y` must be 0 or 1 for `binary` (classification) `task`, " | |
but after transformation it still has {pd.unique(df['Y'])}. | |
Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" | |
case 'multiclass': | |
df['Y'] = df['Y'].astype('int') | |
assert all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)), \ | |
f"""Y must be non-negative integers for task `multiclass` (classification) | |
but after transformation it still has {pd.unique(df['Y'])}. | |
Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" | |
target_n_unique = df['Y'].nunique() | |
assert target_n_unique == n_class, \ | |
f"""You have set `n_class` for `multiclass` (classification) `task` to {n_class}, | |
but after transformation Y still has {target_n_unique} unique labels. | |
Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" | |
# Indexed protein/FASTA for retrieval metrics | |
df['IDX'] = LabelEncoder().fit_transform(df['X2']) | |
self.df = df | |
self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x) | |
self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x) | |
def __len__(self): | |
return len(self.df.index) | |
def __getitem__(self, i): | |
sample = self.df.loc[i] | |
return { | |
'N': i, | |
'X1': self.drug_featurizer(sample['X1']), | |
'ID1': sample.get('ID1', sample['X1']), | |
'X2': self.protein_featurizer(sample['X2']), | |
'ID2': sample.get('ID2', sample['X2']), | |
'Y': sample.get('Y'), | |
'IDX': sample['IDX'], | |
} | |
class DTIDataModule(LightningDataModule): | |
""" | |
DTI DataModule | |
A DataModule implements 5 key methods: | |
def prepare_data(self): | |
# things to do on 1 GPU/TPU (not on every GPU/TPU in DDP) | |
# download data, pre-process, split, save to disk, etc. | |
def setup(self, stage): | |
# things to do on every process in DDP | |
# load data, set variables, etc. | |
def train_dataloader(self): | |
# return train dataloader | |
def val_dataloader(self): | |
# return validation dataloader | |
def test_dataloader(self): | |
# return test dataloader | |
def teardown(self): | |
# called on every process in DDP | |
# clean up after fit or test | |
This allows you to share a full dataset without explaining how to download, | |
split, transform and process the data. | |
Read the docs: | |
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html | |
""" | |
def __init__( | |
self, | |
task: Literal['regression', 'binary', 'multiclass'], | |
n_class: Optional[int], | |
batch_size: int, | |
# train: bool, | |
drug_featurizer: callable, | |
protein_featurizer: callable, | |
collator: callable = collate_fn, | |
data_dir: str = "data/", | |
data_file: Optional[str] = None, | |
train_val_test_split: Optional[Union[Sequence[Number | str]]] = None, | |
split: Optional[callable] = None, | |
thresholds: Optional[Union[Number, Sequence[Number]]] = None, | |
discard_intermediate: Optional[bool] = False, | |
num_workers: int = 0, | |
pin_memory: bool = False, | |
): | |
super().__init__() | |
self.train_data: Optional[Dataset] = None | |
self.val_data: Optional[Dataset] = None | |
self.test_data: Optional[Dataset] = None | |
self.predict_data: Optional[Dataset] = None | |
self.split = split | |
self.collator = collator | |
self.dataset = partial( | |
DTIDataset, | |
task=task, | |
n_class=n_class, | |
drug_featurizer=drug_featurizer, | |
protein_featurizer=protein_featurizer, | |
thresholds=thresholds, | |
discard_intermediate=discard_intermediate | |
) | |
if train_val_test_split: | |
# TODO test behavior for trainer.test and predict when this is passed | |
if len(train_val_test_split) not in [2, 3]: | |
raise ValueError('Length of `train_val_test_split` must be 2 (for training without testing) or 3.') | |
if all([data_file, split]): | |
if all(isinstance(split, Number) for split in train_val_test_split): | |
pass | |
else: | |
raise ValueError('`train_val_test_split` must be a sequence numbers ' | |
'(float for percentages and int for sample numbers) ' | |
'if both `data_file` and `split` have been specified.') | |
elif all(isinstance(split, str) for split in train_val_test_split) and not any([data_file, split]): | |
split_paths = [] | |
for split in train_val_test_split: | |
split = Path(split) | |
if not split.is_absolute(): | |
split = Path(data_dir, split) | |
split_paths.append(split) | |
self.train_data = self.dataset(data_path=split_paths[0]) | |
self.val_data = self.dataset(data_path=split_paths[1]) | |
if len(train_val_test_split) == 3: | |
self.test_data = self.dataset(data_path=split_paths[2]) | |
else: | |
raise ValueError('For training, you must specify either `data_file`, `split`, ' | |
'and `train_val_test_split` as a sequence of numbers or ' | |
'solely `train_val_test_split` as a sequence of data file paths.') | |
elif data_file and not any([split, train_val_test_split]): | |
data_file = Path(data_file) | |
if not data_file.is_absolute(): | |
data_file = Path(data_dir, data_file) | |
self.test_data = self.predict_data = self.dataset(data_path=data_file) | |
else: | |
raise ValueError("For training, you must specify `train_val_test_split`. " | |
"For testing/predicting, you must specify only `data_file` without " | |
"`train_val_test_split` or `split`.") | |
# this line allows to access init params with 'self.hparams' attribute | |
# also ensures init params will be stored in ckpt | |
self.save_hyperparameters(logger=False) # ignore=['split'] | |
def prepare_data(self): | |
""" | |
Download data if needed. | |
Do not use it to assign state (e.g., self.x = x). | |
""" | |
def setup(self, stage: Optional[str] = None, encoding: str = None): | |
""" | |
Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. | |
This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be | |
careful not to execute data splitting twice. | |
""" | |
# TODO test SafeBatchSampler (which skips samples with any None without introducing variable batch size) | |
# load and split datasets only if not loaded in initialization | |
if not any([self.train_data, self.test_data, self.val_data, self.predict_data]): | |
self.train_data, self.val_data, self.test_data = self.split( | |
dataset=self.dataset(data_path=Path(self.hparams.data_dir, self.hparams.data_file)), | |
lengths=self.hparams.train_val_test_split | |
) | |
def train_dataloader(self): | |
return DataLoader( | |
dataset=self.train_data, | |
batch_sampler=SafeBatchSampler( | |
data_source=self.train_data, | |
batch_size=self.hparams.batch_size, | |
# Dropping the last batch prevents problems caused by variable batch sizes in training, e.g., | |
# batch_size=1 in BatchNorm, and shuffling ensures the model be trained on all samples over epochs. | |
drop_last=True, | |
shuffle=True, | |
), | |
# batch_size=self.hparams.batch_size, | |
# shuffle=True, | |
num_workers=self.hparams.num_workers, | |
pin_memory=self.hparams.pin_memory, | |
collate_fn=self.collator, | |
persistent_workers=True if self.hparams.num_workers > 0 else False | |
) | |
def val_dataloader(self): | |
return DataLoader( | |
dataset=self.val_data, | |
batch_sampler=SafeBatchSampler( | |
data_source=self.val_data, | |
batch_size=self.hparams.batch_size, | |
drop_last=False, | |
shuffle=False | |
), | |
# batch_size=self.hparams.batch_size, | |
# shuffle=False, | |
num_workers=self.hparams.num_workers, | |
pin_memory=self.hparams.pin_memory, | |
collate_fn=self.collator, | |
persistent_workers=True if self.hparams.num_workers > 0 else False | |
) | |
def test_dataloader(self): | |
return DataLoader( | |
dataset=self.test_data, | |
batch_sampler=SafeBatchSampler( | |
data_source=self.test_data, | |
batch_size=self.hparams.batch_size, | |
drop_last=False, | |
shuffle=False | |
), | |
# batch_size=self.hparams.batch_size, | |
# shuffle=False, | |
num_workers=self.hparams.num_workers, | |
pin_memory=self.hparams.pin_memory, | |
collate_fn=self.collator, | |
persistent_workers=True if self.hparams.num_workers > 0 else False | |
) | |
def predict_dataloader(self): | |
return DataLoader( | |
dataset=self.predict_data, | |
batch_sampler=SafeBatchSampler( | |
data_source=self.predict_data, | |
batch_size=self.hparams.batch_size, | |
drop_last=False, | |
shuffle=False | |
), | |
# batch_size=self.hparams.batch_size, | |
# shuffle=False, | |
num_workers=self.hparams.num_workers, | |
pin_memory=self.hparams.pin_memory, | |
collate_fn=self.collator, | |
persistent_workers=True if self.hparams.num_workers > 0 else False | |
) | |
def teardown(self, stage: Optional[str] = None): | |
"""Clean up after fit or test.""" | |
pass | |
def state_dict(self): | |
"""Extra things to save to checkpoint.""" | |
return {} | |
def load_state_dict(self, state_dict: Dict[str, Any]): | |
"""Things to do when loading checkpoint.""" | |
pass | |