from collections.abc import Sequence
from functools import partial
from typing import Dict, Optional

from torch import nn, optim, Tensor
from lightning import LightningModule
from torchmetrics import Metric, MetricCollection


class DTILightningModule(LightningModule):
    """
        Drug Target Interaction Prediction

        optimizer: a partially or fully initialized instance of class torch.optim.Optimizer
        drug_encoder: a fully initialized instance of class torch.nn.Module
        protein_encoder: a fully initialized instance of class torch.nn.Module
        classifier: a fully initialized instance of class torch.nn.Module
        model: a fully initialized instance of class torch.nn.Module
        metrics: a list of fully initialized instances of class torchmetrics.Metric
    """
    extra_return_keys = ['ID1', 'X1', 'ID2', 'X2', 'N']

    def __init__(
            self,
            optimizer: optim.Optimizer,
            scheduler: Optional[optim.lr_scheduler.LRScheduler | Dict],
            predictor: nn.Module,
            metrics: Optional[Dict[str, Metric]] = None,
            out: Optional[nn.Module] = None,
            loss: Optional[nn.Module] = None,
            activation: Optional[nn.Module] = None,
    ):
        super().__init__()

        self.predictor = predictor
        self.out = out
        self.loss = loss
        self.activation = activation

        # Metrics are automatically averaged over batches;
        # separate metric instances for the train, val and test steps ensure a proper reduction over each epoch
        metrics = MetricCollection(dict(metrics or {}))
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        self.test_metrics = metrics.clone(prefix="test/")

        # allows access to init params with 'self.hparams' attribute and ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False,
                                  ignore=['predictor', 'out', 'loss', 'activation', 'metrics'])

    def setup(self, stage):
        match stage:
            case 'fit':
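                # Run one dummy batch through the model before fitting,
                # e.g. so that lazily initialized layers are materialized before optimizers are configured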
                dataloader = self.trainer.datamodule.train_dataloader()
                dummy_batch = next(iter(dataloader))
                self.forward(dummy_batch)
            # case 'validate':
            #     dataloader = self.trainer.datamodule.val_dataloader()
            # case 'test':
            #     dataloader = self.trainer.datamodule.test_dataloader()
            # case 'predict':
            #     dataloader = self.trainer.datamodule.predict_dataloader()

        # for key, value in dummy_batch.items():
        #     if isinstance(value, Tensor):
        #         dummy_batch[key] = value.to(self.device)

    def forward(self, batch):
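        # Assumed batch layout (following the datamodule's key convention): 'X1^' / 'X2^' are the processed
        # drug and target inputs, 'Y' the interaction label, and 'ID^' the group index used by retrieval-style metrics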
        output = self.predictor(batch['X1^'], batch['X2^'])
        target = batch.get('Y')
        indexes = batch.get('ID^')
        preds = None
        loss = None

        if isinstance(output, Tensor):
            output = self.out(output).squeeze(1)
            preds = self.activation(output)

        elif isinstance(output, Sequence):
            output = list(output)
            # If multi-objective, assume the zeroth element in `output` is main while the rest are auxiliary
            output[0] = self.out(output[0]).squeeze(1)
            # Downstream metrics evaluation only needs main-objective preds
            preds = self.activation(output[0])

        if target is not None:
            loss = self.loss(output, target.float())

        return preds, target, indexes, loss

    def training_step(self, batch, batch_idx):
        preds, target, indexes, loss = self.forward(batch)
        self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.train_metrics(preds=preds, target=target, indexes=indexes.long())
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        return_dict = {
            'Y^': preds,
            'Y': target,
            'loss': loss
        }

        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]

        return return_dict

    def on_train_epoch_end(self):
        pass

    def validation_step(self, batch, batch_idx):
        preds, target, indexes, loss = self.forward(batch)

        self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.val_metrics(preds=preds, target=target, indexes=indexes.long())
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        return_dict = {
            'Y^': preds,
            'Y': target,
            'loss': loss
        }

        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]

        return return_dict

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        preds, target, indexes, loss = self.forward(batch)

        self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.test_metrics(preds=preds, target=target, indexes=indexes.long())
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        return_dict = {
            'Y^': preds,
            'Y': target,
            'loss': loss
        }

        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]

        return return_dict

    def on_test_epoch_end(self):
        pass

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        preds, _, _, _ = self.forward(batch)
        # return a dictionary for callbacks like BasePredictionWriter
        return_dict = {
            'Y^': preds,
        }

        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]

        return return_dict

    def configure_optimizers(self):
        optimizers_config = {'optimizer': self.hparams.optimizer(params=self.parameters())}
        if self.hparams.get('scheduler'):
            if isinstance(self.hparams.scheduler, partial):
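                # A bare partial: instantiate the scheduler and wrap it in a default Lightning lr_scheduler config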
                optimizers_config['lr_scheduler'] = {
                    "scheduler": self.hparams.scheduler(optimizer=optimizers_config['optimizer']),
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                }
            else:
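                # A dict config: instantiate the partially initialized scheduler it contains and keep the remaining keys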
                self.hparams.scheduler['scheduler'] = self.hparams.scheduler['scheduler'](
                    optimizer=optimizers_config['optimizer']
                )
                optimizers_config['lr_scheduler'] = dict(self.hparams.scheduler)
        return optimizers_config
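

# Minimal usage sketch (illustrative only; `MyPredictor` and the hyperparameters below are hypothetical
# stand-ins, not part of this module):
#
#     from functools import partial
#     from torch import nn, optim
#     from torchmetrics.retrieval import RetrievalAUROC
#
#     module = DTILightningModule(
#         optimizer=partial(optim.Adam, lr=1e-3),
#         scheduler=partial(optim.lr_scheduler.ReduceLROnPlateau, mode='min'),
#         predictor=MyPredictor(),                      # any nn.Module taking the (X1^, X2^) batch inputs
#         metrics={'auroc': RetrievalAUROC()},
#         out=nn.Linear(256, 1),                        # output head matching the predictor's feature size
#         loss=nn.BCEWithLogitsLoss(),
#         activation=nn.Sigmoid(),
#     )
#     # trainer.fit(module, datamodule=dti_datamodule)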