from functools import partial
from typing import Optional, Sequence, Dict

from torch import nn, optim, Tensor
from lightning import LightningModule
from torchmetrics import Metric, MetricCollection


class DTILightningModule(LightningModule):
    """
    Drug-target interaction (DTI) prediction.

    optimizer: a partially initialized torch.optim.Optimizer (e.g. a functools.partial),
        called with the model parameters in configure_optimizers
    scheduler: an optional partially initialized LR scheduler, or a dict in Lightning's
        lr_scheduler format whose 'scheduler' entry is a partially initialized scheduler
    predictor: a fully initialized torch.nn.Module that maps the two inputs to features (or a sequence of outputs)
    out: a fully initialized torch.nn.Module mapping predictor features to raw scores
    loss: a fully initialized torch.nn.Module computing the loss from raw scores and targets
    activation: a fully initialized torch.nn.Module turning raw scores into predictions
    metrics: a dict of fully initialized torchmetrics.Metric instances, keyed by metric name
    """
    # Batch keys copied through unchanged to the step outputs (e.g. for prediction writers)
    extra_return_keys = ['ID1', 'X1', 'ID2', 'X2', 'N']

    def __init__(
        self,
        optimizer: optim.Optimizer,
        scheduler: Optional[partial | Dict],
        predictor: nn.Module,
        metrics: Optional[Dict[str, Metric]] = None,
        out: Optional[nn.Module] = None,
        loss: Optional[nn.Module] = None,
        activation: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.predictor = predictor
        self.out = out
        self.loss = loss
        self.activation = activation

        # Metrics are automatically averaged over batches.
        # Separate metric instances for the train, val and test steps ensure a proper reduction over each epoch.
        metrics = MetricCollection(dict(metrics or {}))
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        self.test_metrics = metrics.clone(prefix="test/")

        # Allow access to init params via `self.hparams` and ensure they are stored in the checkpoint
        self.save_hyperparameters(logger=False,
                                  ignore=['predictor', 'out', 'loss', 'activation', 'metrics'])

    def setup(self, stage):
        match stage:
            case 'fit':
                # Run one dummy forward pass on a training batch,
                # e.g. to materialize any lazily initialized layers before the optimizer is configured
                dataloader = self.trainer.datamodule.train_dataloader()
                dummy_batch = next(iter(dataloader))
                self.forward(dummy_batch)
            # case 'validate':
            #     dataloader = self.trainer.datamodule.val_dataloader()
            # case 'test':
            #     dataloader = self.trainer.datamodule.test_dataloader()
            # case 'predict':
            #     dataloader = self.trainer.datamodule.predict_dataloader()
        # for key, value in dummy_batch.items():
        #     if isinstance(value, Tensor):
        #         dummy_batch[key] = value.to(self.device)

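    # Expected batch layout, inferred from the usage in forward() below (not specified elsewhere
    # in this file): 'X1^' and 'X2^' hold the two inputs fed to the predictor (presumably the
    # drug and target representations), 'Y' holds optional labels, 'ID^' holds optional indexes
    # passed to retrieval-style metrics, and the keys listed in `extra_return_keys` are copied
    # through unchanged to the step outputs.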
    def forward(self, batch):
        output = self.predictor(batch['X1^'], batch['X2^'])
        target = batch.get('Y')
        indexes = batch.get('ID^')
        preds = None
        loss = None
        if isinstance(output, Tensor):
            output = self.out(output).squeeze(1)
            preds = self.activation(output)
        elif isinstance(output, Sequence):
            output = list(output)
            # If multi-objective, assume the zeroth element in `output` is the main objective
            # and the rest are auxiliary
            output[0] = self.out(output[0]).squeeze(1)
            # Downstream metric evaluation only needs the main-objective preds
            preds = self.activation(output[0])
        if target is not None:
            loss = self.loss(output, target.float())
        return preds, target, indexes, loss

    def training_step(self, batch, batch_idx):
        preds, target, indexes, loss = self.forward(batch)
        self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.train_metrics(preds=preds, target=target, indexes=indexes.long())
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return_dict = {
            'Y^': preds,
            'Y': target,
            'loss': loss
        }
        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]
        return return_dict

    def on_train_epoch_end(self):
        pass

    def validation_step(self, batch, batch_idx):
        preds, target, indexes, loss = self.forward(batch)
        self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.val_metrics(preds=preds, target=target, indexes=indexes.long())
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return_dict = {
            'Y^': preds,
            'Y': target,
            'loss': loss
        }
        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]
        return return_dict

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        preds, target, indexes, loss = self.forward(batch)
        self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.test_metrics(preds=preds, target=target, indexes=indexes.long())
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return_dict = {
            'Y^': preds,
            'Y': target,
            'loss': loss
        }
        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]
        return return_dict

    def on_test_epoch_end(self):
        pass

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        preds, _, _, _ = self.forward(batch)
        # Return a dictionary for callbacks like BasePredictionWriter
        return_dict = {
            'Y^': preds,
        }
        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]
        return return_dict

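    # configure_optimizers below accepts the scheduler hyperparameter either as a partial
    # (wrapped with default "monitor"/"interval"/"frequency" settings) or as a dict in
    # Lightning's lr_scheduler format whose 'scheduler' entry is a partial. A sketch of the
    # dict form, with illustrative values only:
    #   scheduler = {
    #       'scheduler': partial(optim.lr_scheduler.CosineAnnealingLR, T_max=100),
    #       'interval': 'epoch',
    #       'frequency': 1,
    #   }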
    def configure_optimizers(self):
        optimizers_config = {'optimizer': self.hparams.optimizer(params=self.parameters())}
        if self.hparams.get('scheduler'):
            if isinstance(self.hparams.scheduler, partial):
                optimizers_config['lr_scheduler'] = {
                    "scheduler": self.hparams.scheduler(optimizer=optimizers_config['optimizer']),
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                }
            else:
                self.hparams.scheduler['scheduler'] = self.hparams.scheduler['scheduler'](
                    optimizer=optimizers_config['optimizer']
                )
                optimizers_config['lr_scheduler'] = dict(self.hparams.scheduler)
        return optimizers_config
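

# Illustrative usage sketch, not part of the module itself: one way to wire the class together
# with partially initialized optimizer/scheduler objects and a torchmetrics dict, matching what
# configure_optimizers() expects. The predictor/out/loss/activation modules below are hypothetical
# stand-ins, not the project's actual models.
if __name__ == "__main__":
    from torchmetrics.retrieval import RetrievalAUROC

    module = DTILightningModule(
        optimizer=partial(optim.AdamW, lr=1e-3),
        scheduler=partial(optim.lr_scheduler.ReduceLROnPlateau, patience=5),
        predictor=nn.Bilinear(128, 128, 64),   # hypothetical two-input feature combiner
        out=nn.Linear(64, 1),                  # maps predictor features to a single raw score
        loss=nn.BCEWithLogitsLoss(),
        activation=nn.Sigmoid(),
        metrics={'auroc': RetrievalAUROC()},
    )
    print(module.hparams)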