from functools import partial
from typing import Callable, Dict, Optional, Sequence, Union
from torch import nn, optim, Tensor
from lightning import LightningModule
from torchmetrics import Metric, MetricCollection
class DTILightningModule(LightningModule):
"""
Drug Target Interaction Prediction
optimizer: a partially or fully initialized instance of class torch.optim.Optimizer
drug_encoder: a fully initialized instance of class torch.nn.Module
protein_encoder: a fully initialized instance of class torch.nn.Module
classifier: a fully initialized instance of class torch.nn.Module
model: a fully initialized instance of class torch.nn.Module
metrics: a list of fully initialized instances of class torchmetrics.Metric
"""
    # Batch entries that are passed through unchanged to the step outputs (e.g. for prediction writers)
    extra_return_keys = ['ID1', 'X1', 'ID2', 'X2', 'N']
def __init__(
self,
        optimizer: Callable[..., optim.Optimizer],
        scheduler: Optional[Union[partial, Dict]],
        predictor: nn.Module,
        metrics: Optional[Dict[str, Metric]] = None,
        out: Optional[nn.Module] = None,
        loss: Optional[nn.Module] = None,
        activation: Optional[nn.Module] = None,
):
super().__init__()
self.predictor = predictor
self.out = out
self.loss = loss
self.activation = activation
# Automatically averaged over batches:
# Separate metric instances for train, val and test step to ensure a proper reduction over the epoch
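        # Hypothetical example (not from this repo): metrics={'auroc': BinaryAUROC()};
        # retrieval-style metrics that accept an `indexes` argument also work, since the steps pass it along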
        metrics = MetricCollection(dict(metrics or {}))
self.train_metrics = metrics.clone(prefix="train/")
self.val_metrics = metrics.clone(prefix="val/")
self.test_metrics = metrics.clone(prefix="test/")
# allows access to init params with 'self.hparams' attribute and ensures init params will be stored in ckpt
self.save_hyperparameters(logger=False,
ignore=['predictor', 'out', 'loss', 'activation', 'metrics'])
def setup(self, stage):
match stage:
case 'fit':
dataloader = self.trainer.datamodule.train_dataloader()
dummy_batch = next(iter(dataloader))
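                # Running one dummy batch through the network here presumably materializes any
                # lazily initialized layers (e.g. nn.LazyLinear) before real training begins.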
self.forward(dummy_batch)
# case 'validate':
# dataloader = self.trainer.datamodule.val_dataloader()
# case 'test':
# dataloader = self.trainer.datamodule.test_dataloader()
# case 'predict':
# dataloader = self.trainer.datamodule.predict_dataloader()
# for key, value in dummy_batch.items():
# if isinstance(value, Tensor):
# dummy_batch[key] = value.to(self.device)
def forward(self, batch):
output = self.predictor(batch['X1^'], batch['X2^'])
target = batch.get('Y')
indexes = batch.get('ID^')
preds = None
loss = None
if isinstance(output, Tensor):
output = self.out(output).squeeze(1)
preds = self.activation(output)
elif isinstance(output, Sequence):
output = list(output)
# If multi-objective, assume the zeroth element in `output` is main while the rest are auxiliary
output[0] = self.out(output[0]).squeeze(1)
# Downstream metrics evaluation only needs main-objective preds
preds = self.activation(output[0])
if target is not None:
loss = self.loss(output, target.float())
return preds, target, indexes, loss
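    # Note (assumption, for illustration): `out`, `loss` and `activation` are expected together,
    # e.g. out=nn.LazyLinear(1), loss=nn.BCEWithLogitsLoss(), activation=nn.Sigmoid() for binary
    # interaction labels; the loss is computed on the raw outputs while the metrics consume the
    # activated predictions.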
def training_step(self, batch, batch_idx):
preds, target, indexes, loss = self.forward(batch)
self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
self.train_metrics(preds=preds, target=target, indexes=indexes.long())
self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
return_dict = {
'Y^': preds,
'Y': target,
'loss': loss
}
for key in self.extra_return_keys:
if key in batch:
return_dict[key] = batch[key]
return return_dict
def on_train_epoch_end(self):
pass
def validation_step(self, batch, batch_idx):
preds, target, indexes, loss = self.forward(batch)
self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
self.val_metrics(preds=preds, target=target, indexes=indexes.long())
self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
return_dict = {
'Y^': preds,
'Y': target,
'loss': loss
}
for key in self.extra_return_keys:
if key in batch:
return_dict[key] = batch[key]
return return_dict
def on_validation_epoch_end(self):
pass
def test_step(self, batch, batch_idx):
preds, target, indexes, loss = self.forward(batch)
self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
self.test_metrics(preds=preds, target=target, indexes=indexes.long())
self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
return_dict = {
'Y^': preds,
'Y': target,
'loss': loss
}
for key in self.extra_return_keys:
if key in batch:
return_dict[key] = batch[key]
return return_dict
def on_test_epoch_end(self):
pass
def predict_step(self, batch, batch_idx, dataloader_idx=0):
preds, _, _, _ = self.forward(batch)
# return a dictionary for callbacks like BasePredictionWriter
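        # (illustrative note) such a callback could persist 'Y^' together with the pass-through
        # ID/X columns collected below via `extra_return_keys`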
return_dict = {
'Y^': preds,
}
for key in self.extra_return_keys:
if key in batch:
return_dict[key] = batch[key]
return return_dict
def configure_optimizers(self):
optimizers_config = {'optimizer': self.hparams.optimizer(params=self.parameters())}
if self.hparams.get('scheduler'):
if isinstance(self.hparams.scheduler, partial):
optimizers_config['lr_scheduler'] = {
"scheduler": self.hparams.scheduler(optimizer=optimizers_config['optimizer']),
"monitor": "val/loss",
"interval": "epoch",
"frequency": 1,
}
else:
self.hparams.scheduler['scheduler'] = self.hparams.scheduler['scheduler'](
optimizer=optimizers_config['optimizer']
)
optimizers_config['lr_scheduler'] = dict(self.hparams.scheduler)
return optimizers_config
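

if __name__ == '__main__':
    # Minimal smoke test: a hypothetical sketch (not part of the original training pipeline)
    # showing how the partially initialized optimizer/scheduler, the metrics dict and the
    # head/loss/activation modules are expected to be wired together. `_ConcatPredictor` and
    # all hyperparameters below are illustrative stand-ins, not components of this repository.
    import torch
    from torchmetrics.classification import BinaryAUROC

    class _ConcatPredictor(nn.Module):
        """Toy predictor that simply concatenates the two input representations."""
        def forward(self, x1, x2):
            return torch.cat([x1, x2], dim=-1)

    module = DTILightningModule(
        optimizer=partial(optim.AdamW, lr=1e-4),
        scheduler=partial(optim.lr_scheduler.ReduceLROnPlateau, mode='min', factor=0.5),
        predictor=_ConcatPredictor(),
        metrics={'auroc': BinaryAUROC()},
        out=nn.LazyLinear(1),
        loss=nn.BCEWithLogitsLoss(),
        activation=nn.Sigmoid(),
    )
    dummy_batch = {
        'X1^': torch.randn(4, 8),          # e.g. encoded drug features
        'X2^': torch.randn(4, 8),          # e.g. encoded protein features
        'Y': torch.randint(0, 2, (4,)),    # binary interaction labels
        'ID^': torch.arange(4),            # indexes passed to retrieval-style metrics
    }
    preds, target, indexes, loss = module(dummy_batch)
    print(preds.shape, loss.item())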