from functools import partial
from typing import Optional, Sequence, Dict, Union

from torch import nn, optim, Tensor
from lightning import LightningModule
from torchmetrics import Metric, MetricCollection


class DTILightningModule(LightningModule):
    """
    Drug-Target Interaction Prediction

    optimizer: a partially or fully initialized instance of torch.optim.Optimizer
    scheduler: a partially initialized LR scheduler, or a Lightning lr_scheduler config dict
    predictor: a fully initialized instance of torch.nn.Module producing interaction scores
    metrics: a mapping of names to fully initialized instances of torchmetrics.Metric
    out: a fully initialized output head (torch.nn.Module) projecting predictor output to logits
    loss: a fully initialized loss module (torch.nn.Module)
    activation: a fully initialized activation (torch.nn.Module) mapping logits to predictions
    """
    extra_return_keys = ['ID1', 'X1', 'ID2', 'X2', 'N']

    def __init__(
            self,
            optimizer: optim.Optimizer,
            scheduler: Optional[Union[partial, Dict]],
            predictor: nn.Module,
            metrics: Optional[Dict[str, Metric]] = None,
            out: Optional[nn.Module] = None,
            loss: Optional[nn.Module] = None,
            activation: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.predictor = predictor
        self.out = out
        self.loss = loss
        self.activation = activation

        # Automatically averaged over batches:
        # separate metric instances for the train, val and test steps ensure a proper
        # reduction over each epoch.
        metrics = MetricCollection(dict(metrics or {}))
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        self.test_metrics = metrics.clone(prefix="test/")

        # Allows access to init params via the `self.hparams` attribute and ensures
        # init params are stored in the checkpoint.
        self.save_hyperparameters(logger=False, ignore=['predictor', 'out', 'loss', 'activation', 'metrics'])

    def setup(self, stage):
        match stage:
            case 'fit':
                # Run a dummy forward pass so any lazily initialized layers are
                # materialized before training starts.
                dataloader = self.trainer.datamodule.train_dataloader()
                dummy_batch = next(iter(dataloader))
                self.forward(dummy_batch)
            # case 'validate':
            #     dataloader = self.trainer.datamodule.val_dataloader()
            # case 'test':
            #     dataloader = self.trainer.datamodule.test_dataloader()
            # case 'predict':
            #     dataloader = self.trainer.datamodule.predict_dataloader()
        # for key, value in dummy_batch.items():
        #     if isinstance(value, Tensor):
        #         dummy_batch[key] = value.to(self.device)

    def forward(self, batch):
        output = self.predictor(batch['X1^'], batch['X2^'])
        target = batch.get('Y')
        indexes = batch.get('ID^')
        preds = None
        loss = None
        if isinstance(output, Tensor):
            output = self.out(output).squeeze(1)
            preds = self.activation(output)
        elif isinstance(output, Sequence):
            output = list(output)
            # If multi-objective, assume the zeroth element of `output` is the main
            # objective and the rest are auxiliary.
            output[0] = self.out(output[0]).squeeze(1)
            # Downstream metric evaluation only needs the main-objective preds.
            preds = self.activation(output[0])
        if target is not None:
            loss = self.loss(output, target.float())
        return preds, target, indexes, loss

    def training_step(self, batch, batch_idx):
        preds, target, indexes, loss = self.forward(batch)
        self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.train_metrics(preds=preds, target=target, indexes=indexes.long())
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return_dict = {
            'Y^': preds,
            'Y': target,
            'loss': loss,
        }
        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]
        return return_dict

    def on_train_epoch_end(self):
        pass

    def validation_step(self, batch, batch_idx):
        preds, target, indexes, loss = self.forward(batch)
        self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
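        # NOTE (assumption): `indexes` appears to follow the torchmetrics retrieval
        # convention, grouping predictions by a shared query index so that retrieval
        # metrics (e.g. RetrievalAUROC) reduce per group. MetricCollection filters
        # kwargs per metric, so plain classification metrics in the collection simply
        # ignore the extra `indexes` keyword.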
        self.val_metrics(preds=preds, target=target, indexes=indexes.long())
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return_dict = {
            'Y^': preds,
            'Y': target,
            'loss': loss,
        }
        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]
        return return_dict

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        preds, target, indexes, loss = self.forward(batch)
        self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.test_metrics(preds=preds, target=target, indexes=indexes.long())
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return_dict = {
            'Y^': preds,
            'Y': target,
            'loss': loss,
        }
        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]
        return return_dict

    def on_test_epoch_end(self):
        pass

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        preds, _, _, _ = self.forward(batch)
        # Return a dictionary for callbacks like BasePredictionWriter.
        return_dict = {
            'Y^': preds,
        }
        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]
        return return_dict

    def configure_optimizers(self):
        optimizers_config = {'optimizer': self.hparams.optimizer(params=self.parameters())}
        if self.hparams.get('scheduler'):
            if isinstance(self.hparams.scheduler, partial):
                # A partially initialized scheduler: wrap it in a default Lightning
                # lr_scheduler config monitoring the validation loss.
                optimizers_config['lr_scheduler'] = {
                    "scheduler": self.hparams.scheduler(optimizer=optimizers_config['optimizer']),
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                }
            else:
                # A user-supplied lr_scheduler config dict: initialize the scheduler
                # it contains with the optimizer, then pass the dict through.
                self.hparams.scheduler['scheduler'] = self.hparams.scheduler['scheduler'](
                    optimizer=optimizers_config['optimizer']
                )
                optimizers_config['lr_scheduler'] = dict(self.hparams.scheduler)
        return optimizers_config
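

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's API): a minimal
# smoke test showing how the class is wired together with partially
# initialized callables, as `configure_optimizers` expects. `TinyPredictor`,
# the feature sizes, and the metric choice below are assumptions, not values
# taken from this repository.
if __name__ == "__main__":
    import torch
    from torchmetrics.classification import BinaryAUROC

    class TinyPredictor(nn.Module):
        """Hypothetical predictor: concatenates drug and protein features."""

        def __init__(self, d1=8, d2=8, hidden=16):
            super().__init__()
            self.mlp = nn.Sequential(nn.Linear(d1 + d2, hidden), nn.ReLU())

        def forward(self, x1, x2):
            return self.mlp(torch.cat([x1, x2], dim=-1))

    module = DTILightningModule(
        optimizer=partial(optim.AdamW, lr=1e-4),
        scheduler=partial(optim.lr_scheduler.ReduceLROnPlateau, patience=5),
        predictor=TinyPredictor(),
        metrics={'auroc': BinaryAUROC()},
        out=nn.Linear(16, 1),               # project hidden features to a single logit
        loss=nn.BCEWithLogitsLoss(),
        activation=nn.Sigmoid(),
    )
    batch = {
        'X1^': torch.randn(4, 8),           # drug features
        'X2^': torch.randn(4, 8),           # protein features
        'Y': torch.randint(0, 2, (4,)),     # binary interaction labels
        'ID^': torch.arange(4),             # group indexes for retrieval metrics
    }
    preds, target, indexes, loss = module(batch)
    print(preds.shape, loss)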