xf3227 committed
Commit 6fc43ab · 1 Parent(s): c4c51f0
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. adrd/__init__.py +22 -0
  2. adrd/__pycache__/__init__.cpython-311.pyc +0 -0
  3. adrd/_ds/__init__.py +0 -0
  4. adrd/_ds/lddl.py +71 -0
  5. adrd/model/__init__.py +6 -0
  6. adrd/model/__pycache__/__init__.cpython-311.pyc +0 -0
  7. adrd/model/__pycache__/adrd_model.cpython-311.pyc +0 -0
  8. adrd/model/__pycache__/calibration.cpython-311.pyc +0 -0
  9. adrd/model/__pycache__/imaging_model.cpython-311.pyc +0 -0
  10. adrd/model/__pycache__/train_resnet.cpython-311.pyc +0 -0
  11. adrd/model/adrd_model.py +976 -0
  12. adrd/model/calibration.py +450 -0
  13. adrd/model/cnn_resnet3d_with_linear_classifier.py +533 -0
  14. adrd/model/imaging_model.py +843 -0
  15. adrd/model/train_resnet.py +484 -0
  16. adrd/model/transformer.py +600 -0
  17. adrd/nn/__init__.py +12 -0
  18. adrd/nn/__pycache__/__init__.cpython-311.pyc +0 -0
  19. adrd/nn/__pycache__/blocks.cpython-311.pyc +0 -0
  20. adrd/nn/__pycache__/c3d.cpython-311.pyc +0 -0
  21. adrd/nn/__pycache__/cnn_resnet3d.cpython-311.pyc +0 -0
  22. adrd/nn/__pycache__/cnn_resnet3d_with_linear_classifier.cpython-311.pyc +0 -0
  23. adrd/nn/__pycache__/dense_net.cpython-311.pyc +0 -0
  24. adrd/nn/__pycache__/focal_loss.cpython-311.pyc +0 -0
  25. adrd/nn/__pycache__/img_model_wrapper.cpython-311.pyc +0 -0
  26. adrd/nn/__pycache__/net_resnet3d.cpython-311.pyc +0 -0
  27. adrd/nn/__pycache__/resnet3d.cpython-311.pyc +0 -0
  28. adrd/nn/__pycache__/resnet_img_model.cpython-311.pyc +0 -0
  29. adrd/nn/__pycache__/selfattention.cpython-311.pyc +0 -0
  30. adrd/nn/__pycache__/transformer.cpython-311.pyc +0 -0
  31. adrd/nn/__pycache__/unet.cpython-311.pyc +0 -0
  32. adrd/nn/__pycache__/unet_3d.cpython-311.pyc +0 -0
  33. adrd/nn/__pycache__/unet_img_model.cpython-311.pyc +0 -0
  34. adrd/nn/__pycache__/vitautoenc.cpython-311.pyc +0 -0
  35. adrd/nn/blocks.py +57 -0
  36. adrd/nn/c3d.py +99 -0
  37. adrd/nn/cnn_resnet3d.py +81 -0
  38. adrd/nn/cnn_resnet3d_with_linear_classifier.py +56 -0
  39. adrd/nn/dense_net.py +211 -0
  40. adrd/nn/focal_loss.py +120 -0
  41. adrd/nn/img_model_wrapper.py +174 -0
  42. adrd/nn/net_resnet3d.py +338 -0
  43. adrd/nn/resnet3d.py +256 -0
  44. adrd/nn/resnet_img_model.py +81 -0
  45. adrd/nn/selfattention.py +62 -0
  46. adrd/nn/transformer.py +268 -0
  47. adrd/nn/unet.py +232 -0
  48. adrd/nn/unet_3d.py +63 -0
  49. adrd/nn/unet_img_model.py +211 -0
  50. adrd/nn/vitautoenc.py +163 -0
adrd/__init__.py ADDED
@@ -0,0 +1,22 @@
+ __version__ = '0.0.1'
+
+ from . import nn
+ from . import model
+
+ # # load pretrained transformer
+ # pretrained_transformer = model.Transformer.from_ckpt('{}/ckpt/ckpt.pt'.format(__path__[0]))
+ # from . import shap_adrd
+ # from .model import DynamicCalibratedClassifier
+ # from .model import StaticCalibratedClassifier
+
+ # load fitted transformer and calibrated wrapper
+ # try:
+ #     fitted_resnet3d = model.CNNResNet3DWithLinearClassifier.from_ckpt('{}/ckpt/ckpt_img_072523.pt'.format(__path__[0]))
+ #     fitted_calibrated_classifier_nonimg = StaticCalibratedClassifier.from_ckpt(
+ #         filepath_state_dict = '{}/ckpt/static_calibrated_classifier_073023.pkl'.format(__path__[0]),
+ #         filepath_wrapped_model = '{}/ckpt/ckpt_080823.pt'.format(__path__[0]),
+ #     )
+ #     fitted_transformer_nonimg = fitted_calibrated_classifier_nonimg.model
+ #     shap_explainer = shap_adrd.SamplingExplainer(fitted_transformer_nonimg)
+ # except:
+ #     print('Fail to load checkpoints.')
adrd/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (288 Bytes).
 
adrd/_ds/__init__.py ADDED
File without changes
adrd/_ds/lddl.py ADDED
@@ -0,0 +1,71 @@
+ from typing import Any, Self, overload
+
+
+ class lddl:
+     ''' ... '''
+     def __init__(self) -> None:
+         ''' ... '''
+         self.dat_ld: list[dict[str, Any]] = None
+         self.dat_dl: dict[str, list[Any]] = None
+
+     @overload
+     def __getitem__(self, idx: int) -> dict[str, Any]: ...
+
+     @overload
+     def __getitem__(self, idx: str) -> list[Any]: ...
+
+     def __getitem__(self, idx: str | int) -> list[Any] | dict[str, Any]:
+         ''' ... '''
+         if isinstance(idx, str):
+             return self.dat_dl[idx]
+         elif isinstance(idx, int):
+             return self.dat_ld[idx]
+         else:
+             raise TypeError('Unexpected key type: {}'.format(type(idx)))
+
+     @classmethod
+     def from_ld(cls, dat: list[dict[str, Any]]) -> Self:
+         ''' Construct from list of dicts. '''
+         obj = cls()
+         obj.dat_ld = dat
+         obj.dat_dl = {k: [dat[i][k] for i in range(len(dat))] for k in dat[0]}
+         return obj
+
+     @classmethod
+     def from_dl(cls, dat: dict[str, list[Any]]) -> Self:
+         ''' Construct from dict of lists. '''
+         obj = cls()
+         obj.dat_ld = [dict(zip(dat, v)) for v in zip(*dat.values())]
+         obj.dat_dl = dat
+         return obj
+
+
+ if __name__ == '__main__':
+     ''' for testing purpose only '''
+     dl = {
+         'a': [0, 1, 2],
+         'b': [3, 4, 5],
+     }
+
+     ld = [
+         {'a': 0, 'b': 1, 'c': 2},
+         {'a': 3, 'b': 4, 'c': 5},
+     ]
+
+     # test constructing from ld
+     dat_ld = lddl.from_ld(ld)
+     print(dat_ld.dat_ld)
+     print(dat_ld.dat_dl)
+
+     # test constructing from dl
+     dat_dl = lddl.from_dl(dl)
+     print(dat_dl.dat_ld)
+     print(dat_dl.dat_dl)
+
+     # test __getitem__
+     print(dat_dl['a'])
+     print(dat_dl[0])
+
+     # mouse hover to check if type hints are correct
+     v = dat_dl['a']
+     v = dat_dl[0]
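For reference, the two constructors are inverses of each other: from_dl transposes a dict of equal-length lists into records, and from_ld rebuilds the columns from the keys of the first record. A minimal round-trip sketch with toy values (not part of the commit):

    # toy round-trip check for lddl, mirroring the conversions above
    dat = lddl.from_dl({'a': [0, 1, 2], 'b': [3, 4, 5]})
    assert dat[0] == {'a': 0, 'b': 3}        # int index -> record dict
    assert dat['b'] == [3, 4, 5]             # str index -> column list
    assert lddl.from_ld(dat.dat_ld).dat_dl == dat.dat_dl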
adrd/model/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .adrd_model import ADRDModel
+ from .imaging_model import ImagingModel
+ from .train_resnet import TrainResNet
+ # from .transformer import Transformer
+ from .calibration import DynamicCalibratedClassifier
+ from .calibration import StaticCalibratedClassifier
adrd/model/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (512 Bytes).
 
adrd/model/__pycache__/adrd_model.cpython-311.pyc ADDED
Binary file (56.5 kB).
 
adrd/model/__pycache__/calibration.cpython-311.pyc ADDED
Binary file (27.8 kB).
 
adrd/model/__pycache__/imaging_model.cpython-311.pyc ADDED
Binary file (46.5 kB).
 
adrd/model/__pycache__/train_resnet.cpython-311.pyc ADDED
Binary file (26.1 kB).
 
adrd/model/adrd_model.py ADDED
@@ -0,0 +1,976 @@
1
+ __all__ = ['ADRDModel']
2
+
3
+ import wandb
4
+ import torch
5
+ from torch.utils.data import DataLoader
6
+ import numpy as np
7
+ import tqdm
8
+ import multiprocessing
9
+ from sklearn.base import BaseEstimator
10
+ from sklearn.utils.validation import check_is_fitted
11
+ from sklearn.model_selection import train_test_split
12
+ from scipy.special import expit
13
+ from copy import deepcopy
14
+ from contextlib import suppress
15
+ from typing import Any, Self, Type
16
+ from functools import wraps
17
+ from tqdm import tqdm
18
+ Tensor = Type[torch.Tensor]
19
+ Module = Type[torch.nn.Module]
20
+
21
+ # for DistributedDataParallel
22
+ import torch.distributed as dist
23
+ import torch.multiprocessing as mp
24
+ from torch.nn.parallel import DistributedDataParallel as DDP
25
+
26
+ from .. import nn
27
+ from ..nn import Transformer
28
+ from ..utils import TransformerTrainingDataset, TransformerBalancedTrainingDataset, TransformerValidationDataset, TransformerTestingDataset, Transformer2ndOrderBalancedTrainingDataset
29
+ from ..utils.misc import ProgressBar
30
+ from ..utils.misc import get_metrics_multitask, print_metrics_multitask
31
+ from ..utils.misc import convert_args_kwargs_to_kwargs
32
+
33
+
34
+ def _manage_ctx_fit(func):
35
+ ''' ... '''
36
+ @wraps(func)
37
+ def wrapper(*args, **kwargs):
38
+ # format arguments
39
+ kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs)
40
+
41
+ if kwargs['self']._device_ids is None:
42
+ return func(**kwargs)
43
+ else:
44
+ # change primary device
45
+ default_device = kwargs['self'].device
46
+ kwargs['self'].device = kwargs['self']._device_ids[0]
47
+ rtn = func(**kwargs)
48
+ kwargs['self'].to(default_device)
49
+ return rtn
50
+ return wrapper
51
+
52
+
53
+ class ADRDModel(BaseEstimator):
54
+ """Primary model class for ADRD framework.
55
+
56
+ The ADRDModel encapsulates the core pipeline of the ADRD framework,
57
+ permitting users to train and validate with the provided data. Designed for
58
+ user-friendly operation, the ADRDModel is derived from
59
+ ``sklearn.base.BaseEstimator``, ensuring compliance with the well-established
60
+ API design conventions of scikit-learn.
61
+ """
62
+ def __init__(self,
63
+ src_modalities: dict[str, dict[str, Any]],
64
+ tgt_modalities: dict[str, dict[str, Any]],
65
+ label_fractions: dict[str, float],
66
+ d_model: int = 32,
67
+ nhead: int = 1,
68
+ num_encoder_layers: int = 1,
69
+ num_decoder_layers: int = 1,
70
+ num_epochs: int = 32,
71
+ batch_size: int = 8,
72
+ batch_size_multiplier: int = 1,
73
+ lr: float = 1e-2,
74
+ weight_decay: float = 0.0,
75
+ beta: float = 0.9999,
76
+ gamma: float = 2.0,
77
+ criterion: str | None = None,
78
+ device: str = 'cpu',
79
+ cuda_devices: list = [1],
80
+ img_net: str | None = None,
81
+ imgnet_layers: int | None = 2,
82
+ img_size: int | None = 128,
83
+ fusion_stage: str = 'middle',
84
+ patch_size: int | None = 16,
85
+ imgnet_ckpt: str | None = None,
86
+ train_imgnet: bool = False,
87
+ ckpt_path: str = './adrd_tool/adrd/dev/ckpt/ckpt.pt',
88
+ load_from_ckpt: bool = False,
89
+ save_intermediate_ckpts: bool = False,
90
+ data_parallel: bool = False,
91
+ verbose: int = 0,
92
+ wandb_ = 0,
93
+ balanced_sampling: bool = False,
94
+ label_distribution: dict = {},
95
+ ranking_loss: bool = False,
96
+ _device_ids: list | None = None,
97
+
98
+ _dataloader_num_workers: int = 4,
99
+ _amp_enabled: bool = False,
100
+ ) -> None:
101
+ """Create a new ADRD model.
102
+
103
+ :param src_modalities: _description_
104
+ :type src_modalities: dict[str, dict[str, Any]]
105
+ :param tgt_modalities: _description_
106
+ :type tgt_modalities: dict[str, dict[str, Any]]
107
+ :param label_fractions: _description_
108
+ :type label_fractions: dict[str, float]
109
+ :param d_model: _description_, defaults to 32
110
+ :type d_model: int, optional
111
+ :param nhead: number of transformer heads, defaults to 1
112
+ :type nhead: int, optional
113
+ :param num_encoder_layers: number of encoder layers, defaults to 1
114
+ :type num_encoder_layers: int, optional
115
+ :param num_decoder_layers: number of decoder layers, defaults to 1
116
+ :type num_decoder_layers: int, optional
117
+ :param num_epochs: number of training epochs, defaults to 32
118
+ :type num_epochs: int, optional
119
+ :param batch_size: batch size, defaults to 8
120
+ :type batch_size: int, optional
121
+ :param batch_size_multiplier: _description_, defaults to 1
122
+ :type batch_size_multiplier: int, optional
123
+ :param lr: learning rate, defaults to 1e-2
124
+ :type lr: float, optional
125
+ :param weight_decay: _description_, defaults to 0.0
126
+ :type weight_decay: float, optional
127
+ :param beta: _description_, defaults to 0.9999
128
+ :type beta: float, optional
129
+ :param gamma: The focusing parameter for the focal loss. Higher values of gamma make easy-to-classify examples contribute less to the loss relative to hard-to-classify examples. Must be non-negative, defaults to 2.0
130
+ :type gamma: float, optional
131
+ :param criterion: The criterion to select the best model, defaults to None
132
+ :type criterion: str | None, optional
133
+ :param device: 'cuda' or 'cpu', defaults to 'cpu'
134
+ :type device: str, optional
135
+ :param cuda_devices: A list of GPU device indices to use for data parallel training. The device must be set to 'cuda' and data_parallel must be set to True, defaults to [1]
136
+ :type cuda_devices: list, optional
137
+ :param img_net: _description_, defaults to None
138
+ :type img_net: str | None, optional
139
+ :param imgnet_layers: _description_, defaults to 2
140
+ :type imgnet_layers: int | None, optional
141
+ :param img_size: _description_, defaults to 128
142
+ :type img_size: int | None, optional
143
+ :param fusion_stage: _description_, defaults to 'middle'
144
+ :type fusion_stage: str, optional
145
+ :param patch_size: _description_, defaults to 16
146
+ :type patch_size: int | None, optional
147
+ :param imgnet_ckpt: _description_, defaults to None
148
+ :type imgnet_ckpt: str | None, optional
149
+ :param train_imgnet: Set to True to finetune the img_net backbone, defaults to False
150
+ :type train_imgnet: bool, optional
151
+ :param ckpt_path: The model checkpoint path, defaults to './adrd_tool/adrd/dev/ckpt/ckpt.pt'
152
+ :type ckpt_path: str, optional
153
+ :param load_from_ckpt: Set to True to load the model weights from checkpoint ckpt_path, defaults to False
154
+ :type load_from_ckpt: bool, optional
155
+ :param save_intermediate_ckpts: Set to True to save intermediate model checkpoints, defaults to False
156
+ :type save_intermediate_ckpts: bool, optional
157
+ :param data_parallel: Set to True to enable data parallel training, defaults to False
158
+ :type data_parallel: bool, optional
159
+ :param verbose: _description_, defaults to 0
160
+ :type verbose: int, optional
161
+ :param wandb_: Set to 1 to track the loss and accuracy curves on wandb, defaults to 0
162
+ :type wandb_: int, optional
163
+ :param balanced_sampling: _description_, defaults to False
164
+ :type balanced_sampling: bool, optional
165
+ :param label_distribution: _description_, defaults to {}
166
+ :type label_distribution: dict, optional
167
+ :param ranking_loss: _description_, defaults to False
168
+ :type ranking_loss: bool, optional
169
+ :param _device_ids: _description_, defaults to None
170
+ :type _device_ids: list | None, optional
171
+ :param _dataloader_num_workers: _description_, defaults to 4
172
+ :type _dataloader_num_workers: int, optional
173
+ :param _amp_enabled: _description_, defaults to False
174
+ :type _amp_enabled: bool, optional
175
+ """
176
+ # for multiprocessing
177
+ self._rank = 0
178
+ self._lock = None
179
+
180
+ # positional parameters
181
+ self.src_modalities = src_modalities
182
+ self.tgt_modalities = tgt_modalities
183
+
184
+ # training parameters
185
+ self.label_fractions = label_fractions
186
+ self.d_model = d_model
187
+ self.nhead = nhead
188
+ self.num_encoder_layers = num_encoder_layers
189
+ self.num_decoder_layers = num_decoder_layers
190
+ self.num_epochs = num_epochs
191
+ self.batch_size = batch_size
192
+ self.batch_size_multiplier = batch_size_multiplier
193
+ self.lr = lr
194
+ self.weight_decay = weight_decay
195
+ self.beta = beta
196
+ self.gamma = gamma
197
+ self.criterion = criterion
198
+ self.device = device
199
+ self.cuda_devices = cuda_devices
200
+ self.img_net = img_net
201
+ self.patch_size = patch_size
202
+ self.img_size = img_size
203
+ self.fusion_stage = fusion_stage
204
+ self.imgnet_ckpt = imgnet_ckpt
205
+ self.imgnet_layers = imgnet_layers
206
+ self.train_imgnet = train_imgnet
207
+ self.ckpt_path = ckpt_path
208
+ self.load_from_ckpt = load_from_ckpt
209
+ self.save_intermediate_ckpts = save_intermediate_ckpts
210
+ self.data_parallel = data_parallel
211
+ self.verbose = verbose
212
+ self.label_distribution = label_distribution
213
+ self.wandb_ = wandb_
214
+ self.balanced_sampling = balanced_sampling
215
+ self.ranking_loss = ranking_loss
216
+ self._device_ids = _device_ids
217
+ self._dataloader_num_workers = _dataloader_num_workers
218
+ self._amp_enabled = _amp_enabled
219
+ self.scaler = torch.cuda.amp.GradScaler()
220
+ # self._init_net()
221
+
222
+ @_manage_ctx_fit
223
+ def fit(self, x_trn, x_vld, y_trn, y_vld, img_train_trans=None, img_vld_trans=None, img_mode=0) -> Self:
224
+ # def fit(self, x, y) -> Self:
225
+ ''' ... '''
226
+
227
+ # start a new wandb run to track this script
228
+ if self.wandb_ == 1:
229
+ wandb.init(
230
+ # set the wandb project where this run will be logged
231
+ project="ADRD_main",
232
+
233
+ # track hyperparameters and run metadata
234
+ config={
235
+ "Loss": 'Focalloss',
236
+ "ranking_loss": self.ranking_loss,
237
+ "img architecture": self.img_net,
238
+ "EMB": "ALL_SEQ",
239
+ "epochs": self.num_epochs,
240
+ "d_model": self.d_model,
241
+ # 'positional encoding': 'Diff PE',
242
+ 'Balanced Sampling': self.balanced_sampling,
243
+ 'Shared CNN': 'Yes',
244
+ }
245
+ )
246
+ wandb.run.log_code(".")
247
+ else:
248
+ wandb.init(mode="disabled")
249
+ # for PyTorch computational efficiency
250
+ torch.set_num_threads(1)
251
+ # print(img_train_trans)
252
+ # initialize neural network
253
+ print(self.criterion)
254
+ print(f"Ranking loss: {self.ranking_loss}")
255
+ print(f"Batch size: {self.batch_size}")
256
+ print(f"Batch size multiplier: {self.batch_size_multiplier}")
257
+
258
+ if img_mode in [0,1,2]:
259
+ for k, info in self.src_modalities.items():
260
+ if info['type'] == 'imaging':
261
+ if 'densenet' in self.img_net.lower() and 'emb' not in self.img_net.lower():
262
+ info['shape'] = (1,) + self.img_size
263
+ info['img_shape'] = (1,) + self.img_size
264
+ elif 'emb' not in self.img_net.lower():
265
+ info['shape'] = (1,) + (self.img_size,) * 3
266
+ info['img_shape'] = (1,) + (self.img_size,) * 3
267
+ elif 'swinunetr' in self.img_net.lower():
268
+ info['shape'] = (1, 768, 4, 4, 4)
269
+ info['img_shape'] = (1, 768, 4, 4, 4)
270
+
271
+
272
+
273
+ self._init_net()
274
+ ldr_trn, ldr_vld = self._init_dataloader(x_trn, x_vld, y_trn, y_vld, img_train_trans=img_train_trans, img_vld_trans=img_vld_trans)
275
+
276
+ # initialize optimizer and scheduler
277
+ if not self.load_from_ckpt:
278
+ self.optimizer = self._init_optimizer()
279
+ self.scheduler = self._init_scheduler(self.optimizer)
280
+
281
+ # gradient scaler for AMP
282
+ if self._amp_enabled:
283
+ self.scaler = torch.cuda.amp.GradScaler()
284
+
285
+ # initialize the focal losses
286
+ self.loss_fn = {}
287
+
288
+ for k in self.tgt_modalities:
289
+ if self.label_fractions[k] >= 0.3:
290
+ alpha = -1
291
+ else:
292
+ alpha = pow((1 - self.label_fractions[k]), 2)
293
+ # alpha = -1
294
+ self.loss_fn[k] = nn.SigmoidFocalLoss(
295
+ alpha = alpha,
296
+ gamma = self.gamma,
297
+ reduction = 'none'
298
+ )
299
+
300
+ # to record the best validation performance criterion
301
+ if self.criterion is not None:
302
+ best_crit = None
303
+ best_crit_AUPR = None
304
+
305
+ # progress bar for epoch loops
306
+ if self.verbose == 1:
307
+ with self._lock if self._lock is not None else suppress():
308
+ pbr_epoch = tqdm.tqdm(
309
+ desc = 'Rank {:02d}'.format(self._rank),
310
+ total = self.num_epochs,
311
+ position = self._rank,
312
+ ascii = True,
313
+ leave = False,
314
+ bar_format='{l_bar}{r_bar}'
315
+ )
316
+
317
+ self.skip_embedding = {}
318
+ for k, info in self.src_modalities.items():
319
+ # if info['type'] == 'imaging':
320
+ # if not self.img_net:
321
+ # self.skip_embedding[k] = True
322
+ # else:
323
+ self.skip_embedding[k] = False
324
+
325
+ self.grad_list = []
326
+ # Define a hook function to print and store the gradient of a layer
327
+ def print_and_store_grad(grad):
328
+ self.grad_list.append(grad)
329
+ # print(grad)
330
+
331
+
332
+ # initialize the ranking loss
333
+ self.lambda_coeff = 0.005
334
+ self.margin = 0.25
335
+ self.margin_loss = torch.nn.MarginRankingLoss(reduction='sum', margin=self.margin)
336
+
337
+ # training loop
338
+ for epoch in range(self.start_epoch, self.num_epochs):
339
+ met_trn = self.train_one_epoch(ldr_trn, epoch)
340
+ met_vld = self.validate_one_epoch(ldr_vld, epoch)
341
+
342
+ print(self.ckpt_path.split('/')[-1])
343
+
344
+ # save the model if it has the best validation performance criterion by far
345
+ if self.criterion is None: continue
346
+
347
+ # is current criterion better than previous best?
348
+ curr_crit = np.mean([met_vld[i][self.criterion] for i in range(len(self.tgt_modalities))])
349
+ curr_crit_AUPR = np.mean([met_vld[i]["AUC (PR)"] for i in range(len(self.tgt_modalities))])
350
+ # AUROC
351
+ if best_crit is None or np.isnan(best_crit):
352
+ is_better = True
353
+ elif self.criterion == 'Loss' and best_crit >= curr_crit:
354
+ is_better = True
355
+ elif self.criterion != 'Loss' and best_crit <= curr_crit :
356
+ is_better = True
357
+ else:
358
+ is_better = False
359
+
360
+ # AUPR
361
+ if best_crit_AUPR is None or np.isnan(best_crit_AUPR):
362
+ is_better_AUPR = True
363
+ elif best_crit_AUPR <= curr_crit_AUPR :
364
+ is_better_AUPR = True
365
+ else:
366
+ is_better_AUPR = False
367
+ # update best criterion
368
+ if is_better_AUPR:
369
+ best_crit_AUPR = curr_crit_AUPR
370
+ if self.save_intermediate_ckpts:
371
+ print(f"Saving the model to {self.ckpt_path[:-3]}_AUPR.pt...")
372
+ self.save(self.ckpt_path[:-3]+"_AUPR.pt", epoch)
373
+ if is_better:
374
+ best_crit = curr_crit
375
+ best_state_dict = deepcopy(self.net_.state_dict())
376
+ if self.save_intermediate_ckpts:
377
+ print(f"Saving the model to {self.ckpt_path}...")
378
+ self.save(self.ckpt_path, epoch)
379
+
380
+ if self.verbose > 2:
381
+ print('Best {}: {}'.format(self.criterion, best_crit))
382
+ print('Best {}: {}'.format('AUC (PR)', best_crit_AUPR))
383
+
384
+ if self.verbose == 1:
385
+ with self._lock if self._lock is not None else suppress():
386
+ pbr_epoch.update(1)
387
+ pbr_epoch.refresh()
388
+
389
+ if self.verbose == 1:
390
+ with self._lock if self._lock is not None else suppress():
391
+ pbr_epoch.close()
392
+
393
+ return self
394
+
395
+ def train_one_epoch(self, ldr_trn, epoch):
396
+ # progress bar for batch loops
397
+ if self.verbose > 1:
398
+ pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch))
399
+
400
+ # set model to train mode
401
+ torch.set_grad_enabled(True)
402
+ self.net_.train()
403
+
404
+ scores_trn, y_true_trn, y_mask_trn = [], [], []
405
+ losses_trn = [[] for _ in self.tgt_modalities]
406
+ iters = len(ldr_trn)
407
+ for n_iter, (x_batch, y_batch, mask, y_mask) in enumerate(ldr_trn):
408
+
409
+ # mount data to the proper device
410
+ x_batch = {k: x_batch[k].to(self.device) for k in x_batch}
411
+ y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch}
412
+ mask = {k: mask[k].to(self.device) for k in mask}
413
+ y_mask = {k: y_mask[k].to(self.device) for k in y_mask}
414
+
415
+ with torch.autocast(
416
+ device_type = 'cpu' if self.device == 'cpu' else 'cuda',
417
+ dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
418
+ enabled = self._amp_enabled,
419
+ ):
420
+
421
+ outputs = self.net_(x_batch, mask, skip_embedding=self.skip_embedding)
422
+
423
+ # calculate multitask loss
424
+ loss = 0
425
+
426
+ # for the initial 10 epochs, only the focal loss is used for stable training
427
+ if self.ranking_loss:
428
+ if epoch < 10:
429
+ loss = 0
430
+ else:
431
+ for i, k in enumerate(self.tgt_modalities):
432
+ for ii, kk in enumerate(self.tgt_modalities):
433
+ if ii>i:
434
+ pairs = (y_mask[k] == 1) & (y_mask[kk] == 1)
435
+ total_elements = (torch.abs(y_batch[k][pairs]-y_batch[kk][pairs])).sum()
436
+ if total_elements != 0:
437
+ loss += self.lambda_coeff * (self.margin_loss(torch.sigmoid(outputs[k])[pairs],torch.sigmoid(outputs[kk][pairs]),y_batch[k][pairs]-y_batch[kk][pairs]))/total_elements
438
+
439
+ for i, k in enumerate(self.tgt_modalities):
440
+ loss_task = self.loss_fn[k](outputs[k], y_batch[k])
441
+ msk_loss_task = loss_task * y_mask[k]
442
+ msk_loss_mean = msk_loss_task.sum() / y_mask[k].sum()
443
+ # msk_loss_mean = msk_loss_task.sum()
444
+ loss += msk_loss_mean
445
+ losses_trn[i] += msk_loss_task.detach().cpu().numpy().tolist()
446
+
447
+ # backward
448
+ loss = loss / self.batch_size_multiplier
449
+ if self._amp_enabled:
450
+ self.scaler.scale(loss).backward()
451
+ else:
452
+ loss.backward()
453
+
454
+ if len(self.grad_list) > 0:
455
+ print(len(self.grad_list), len(self.grad_list[-1]))
456
+ print(f"Gradient at {n_iter}: {self.grad_list[-1][0]}")
457
+
458
+ # print("img_MRI_T1_1 ", self.net_.modules_emb_src.img_MRI_T1_1.img_model.features[0].weight)
459
+ # print("img_MRI_T1_1 ", self.net_.modules_emb_src.img_MRI_T1_1.downsample[0].weight)
460
+
461
+ # update parameters
462
+ if n_iter != 0 and n_iter % self.batch_size_multiplier == 0:
463
+ if self._amp_enabled:
464
+ self.scaler.step(self.optimizer)
465
+ self.scaler.update()
466
+ self.optimizer.zero_grad()
467
+ else:
468
+ self.optimizer.step()
469
+ self.optimizer.zero_grad()
470
+
471
+ # set self.scheduler
472
+ self.scheduler.step(epoch + n_iter / iters)
473
+
474
+ ''' TODO: change array to dictionary later '''
475
+ outputs = torch.stack(list(outputs.values()), dim=1)
476
+ y_batch = torch.stack(list(y_batch.values()), dim=1)
477
+ y_mask = torch.stack(list(y_mask.values()), dim=1)
478
+
479
+ # save outputs to evaluate performance later
480
+ scores_trn.append(outputs.detach().to(torch.float).cpu())
481
+ y_true_trn.append(y_batch.cpu())
482
+ y_mask_trn.append(y_mask.cpu())
483
+
484
+ # update progress bar
485
+ if self.verbose > 1:
486
+ batch_size = len(next(iter(x_batch.values())))
487
+ pbr_batch.update(batch_size, {})
488
+ pbr_batch.refresh()
489
+
490
+ # clear cuda cache
491
+ if "cuda" in self.device:
492
+ torch.cuda.empty_cache()
493
+
494
+ # for better tqdm progress bar display
495
+ if self.verbose > 1:
496
+ pbr_batch.close()
497
+
498
+ # calculate and print training performance metrics
499
+ scores_trn = torch.cat(scores_trn)
500
+ y_true_trn = torch.cat(y_true_trn)
501
+ y_mask_trn = torch.cat(y_mask_trn)
502
+ y_pred_trn = (scores_trn > 0).to(torch.int)
503
+ y_prob_trn = torch.sigmoid(scores_trn)
504
+ met_trn = get_metrics_multitask(
505
+ y_true_trn.numpy(),
506
+ y_pred_trn.numpy(),
507
+ y_prob_trn.numpy(),
508
+ y_mask_trn.numpy()
509
+ )
510
+
511
+ # add loss to metrics
512
+ for i in range(len(self.tgt_modalities)):
513
+ met_trn[i]['Loss'] = np.mean(losses_trn[i])
514
+
515
+ # log metrics to wandb
516
+ wandb.log({f"Train loss {list(self.tgt_modalities)[i]}": met_trn[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch)
517
+ wandb.log({f"Train Balanced Accuracy {list(self.tgt_modalities)[i]}": met_trn[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch)
518
+
519
+ wandb.log({f"Train AUC (ROC) {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch)
520
+ wandb.log({f"Train AUPR {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch)
521
+
522
+ if self.verbose > 2:
523
+ print_metrics_multitask(met_trn)
524
+
525
+ return met_trn
526
+
527
+ def validate_one_epoch(self, ldr_vld, epoch):
528
+ # # progress bar for validation
529
+ if self.verbose > 1:
530
+ pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch))
531
+
532
+ # set model to validation mode
533
+ torch.set_grad_enabled(False)
534
+ self.net_.eval()
535
+
536
+ scores_vld, y_true_vld, y_mask_vld = [], [], []
537
+ losses_vld = [[] for _ in self.tgt_modalities]
538
+ for x_batch, y_batch, mask, y_mask in ldr_vld:
539
+ # if len(next(iter(x_batch.values()))) < self.batch_size:
540
+ # break
541
+ # mount data to the proper device
542
+ x_batch = {k: x_batch[k].to(self.device) for k in x_batch} # if 'img' not in k}
543
+ # x_img_batch = {k: x_img_batch[k].to(self.device) for k in x_img_batch}
544
+ y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch}
545
+ mask = {k: mask[k].to(self.device) for k in mask}
546
+ y_mask = {k: y_mask[k].to(self.device) for k in y_mask}
547
+
548
+ # forward
549
+ with torch.autocast(
550
+ device_type = 'cpu' if self.device == 'cpu' else 'cuda',
551
+ dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
552
+ enabled = self._amp_enabled
553
+ ):
554
+
555
+ outputs = self.net_(x_batch, mask, skip_embedding=self.skip_embedding)
556
+
557
+ # calculate multitask loss
558
+ for i, k in enumerate(self.tgt_modalities):
559
+ loss_task = self.loss_fn[k](outputs[k], y_batch[k])
560
+ msk_loss_task = loss_task * y_mask[k]
561
+ losses_vld[i] += msk_loss_task.detach().cpu().numpy().tolist()
562
+
563
+ ''' TODO: change array to dictionary later '''
564
+ outputs = torch.stack(list(outputs.values()), dim=1)
565
+ y_batch = torch.stack(list(y_batch.values()), dim=1)
566
+ y_mask = torch.stack(list(y_mask.values()), dim=1)
567
+
568
+ # save outputs to evaluate performance later
569
+ scores_vld.append(outputs.detach().to(torch.float).cpu())
570
+ y_true_vld.append(y_batch.cpu())
571
+ y_mask_vld.append(y_mask.cpu())
572
+
573
+ # update progress bar
574
+ if self.verbose > 1:
575
+ batch_size = len(next(iter(x_batch.values())))
576
+ pbr_batch.update(batch_size, {})
577
+ pbr_batch.refresh()
578
+
579
+ # clear cuda cache
580
+ if "cuda" in self.device:
581
+ torch.cuda.empty_cache()
582
+
583
+ # for better tqdm progress bar display
584
+ if self.verbose > 1:
585
+ pbr_batch.close()
586
+
587
+ # calculate and print validation performance metrics
588
+ scores_vld = torch.cat(scores_vld)
589
+ y_true_vld = torch.cat(y_true_vld)
590
+ y_mask_vld = torch.cat(y_mask_vld)
591
+ y_pred_vld = (scores_vld > 0).to(torch.int)
592
+ y_prob_vld = torch.sigmoid(scores_vld)
593
+ met_vld = get_metrics_multitask(
594
+ y_true_vld.numpy(),
595
+ y_pred_vld.numpy(),
596
+ y_prob_vld.numpy(),
597
+ y_mask_vld.numpy()
598
+ )
599
+
600
+ # add loss to metrics
601
+ for i in range(len(self.tgt_modalities)):
602
+ met_vld[i]['Loss'] = np.mean(losses_vld[i])
603
+
604
+ wandb.log({f"Validation loss {list(self.tgt_modalities)[i]}": met_vld[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch)
605
+ wandb.log({f"Validation Balanced Accuracy {list(self.tgt_modalities)[i]}": met_vld[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch)
606
+
607
+ wandb.log({f"Validation AUC (ROC) {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch)
608
+ wandb.log({f"Validation AUPR {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch)
609
+
610
+ if self.verbose > 2:
611
+ print_metrics_multitask(met_vld)
612
+
613
+ return met_vld
614
+
615
+
616
+ def predict_logits(self,
617
+ x: list[dict[str, Any]],
618
+ _batch_size: int | None = None,
619
+ skip_embedding: dict | None = None,
620
+ img_transform: Any | None = None,
621
+ ) -> list[dict[str, float]]:
622
+ '''
623
+ The input x can be a single sample or a list of samples.
624
+ '''
625
+ # input validation
626
+ check_is_fitted(self)
627
+ print(self.device)
628
+
629
+ # for PyTorch computational efficiency
630
+ torch.set_num_threads(1)
631
+
632
+ # set model to eval mode
633
+ torch.set_grad_enabled(False)
634
+ self.net_.eval()
635
+
636
+ # initialize dataset and dataloader object
637
+ dat = TransformerTestingDataset(x, self.src_modalities, img_transform=img_transform)
638
+ ldr = DataLoader(
639
+ dataset = dat,
640
+ batch_size = _batch_size if _batch_size is not None else len(x),
641
+ shuffle = False,
642
+ drop_last = False,
643
+ num_workers = 0,
644
+ collate_fn = TransformerTestingDataset.collate_fn,
645
+ )
646
+ # print("dataloader done")
647
+
648
+ # run model and collect results
649
+ logits: list[dict[str, float]] = []
650
+ for x_batch, mask in tqdm(ldr):
651
+ # mount data to the proper device
652
+ # print(x_batch['his_SEX'])
653
+ x_batch = {k: x_batch[k].to(self.device) for k in x_batch}
654
+ mask = {k: mask[k].to(self.device) for k in mask}
655
+
656
+ # forward
657
+ output: dict[str, Tensor] = self.net_(x_batch, mask, skip_embedding)
658
+
659
+ # convert output from dict-of-list to list of dict, then append
660
+ tmp = {k: output[k].tolist() for k in self.tgt_modalities}
661
+ tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))]
662
+ logits += tmp
663
+
664
+ return logits
665
+
666
+ def predict_proba(self,
667
+ x: list[dict[str, Any]],
668
+ skip_embedding: dict | None = None,
669
+ temperature: float = 1.0,
670
+ _batch_size: int | None = None,
671
+ img_transform: Any | None = None,
672
+ ) -> list[dict[str, float]]:
673
+ ''' ... '''
674
+ logits = self.predict_logits(x=x, _batch_size=_batch_size, img_transform=img_transform, skip_embedding=skip_embedding)
675
+ print("got logits")
676
+ return logits, [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]
677
+
678
+ def predict(self,
679
+ x: list[dict[str, Any]],
680
+ skip_embedding: dict | None = None,
681
+ fpr: dict[str, Any] | None = None,
682
+ tpr: dict[str, Any] | None = None,
683
+ thresholds: dict[str, Any] | None = None,
684
+ _batch_size: int | None = None,
685
+ img_transform: Any | None = None,
686
+ ) -> list[dict[str, int]]:
687
+ ''' ... '''
688
+ if fpr is None or tpr is None or thresholds is None:
689
+ logits, proba = self.predict_proba(x, _batch_size=_batch_size, img_transform=img_transform, skip_embedding=skip_embedding)
690
+ print("got proba")
691
+ return logits, proba, [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba]
692
+ else:
693
+ logits, proba = self.predict_proba(x, _batch_size=_batch_size, img_transform=img_transform, skip_embedding=skip_embedding)
694
+ print("got proba")
695
+ youden_index = {}
696
+ thr = {}
697
+ for i, k in enumerate(self.tgt_modalities):
698
+ youden_index[k] = tpr[i] - fpr[i]
699
+ thr[k] = thresholds[i][np.argmax(youden_index[k])]
700
+ # print(thr[k])
701
+ # print(thr)
702
+ return logits, proba, [{k: int(smp[k] > thr[k]) for k in self.tgt_modalities} for smp in proba]
703
+
704
+ def save(self, filepath: str, epoch: int) -> None:
705
+ """Save the model to the given file stream.
706
+
707
+ :param filepath: _description_
708
+ :type filepath: str
709
+ :param epoch: _description_
710
+ :type epoch: int
711
+ """
712
+ check_is_fitted(self)
713
+ if self.data_parallel:
714
+ state_dict = self.net_.module.state_dict()
715
+ else:
716
+ state_dict = self.net_.state_dict()
717
+
718
+ # attach model hyper parameters
719
+ state_dict['src_modalities'] = self.src_modalities
720
+ state_dict['tgt_modalities'] = self.tgt_modalities
721
+ state_dict['d_model'] = self.d_model
722
+ state_dict['nhead'] = self.nhead
723
+ state_dict['num_encoder_layers'] = self.num_encoder_layers
724
+ state_dict['num_decoder_layers'] = self.num_decoder_layers
725
+ state_dict['optimizer'] = self.optimizer
726
+ state_dict['img_net'] = self.img_net
727
+ state_dict['imgnet_layers'] = self.imgnet_layers
728
+ state_dict['img_size'] = self.img_size
729
+ state_dict['patch_size'] = self.patch_size
730
+ state_dict['imgnet_ckpt'] = self.imgnet_ckpt
731
+ state_dict['train_imgnet'] = self.train_imgnet
732
+ state_dict['epoch'] = epoch
733
+
734
+ if self.scaler is not None:
735
+ state_dict['scaler'] = self.scaler.state_dict()
736
+ if self.label_distribution:
737
+ state_dict['label_distribution'] = self.label_distribution
738
+
739
+ torch.save(state_dict, filepath)
740
+
741
+ def load(self, filepath: str, map_location: str = 'cpu', img_dict=None) -> None:
742
+ """Load a model from the given file stream.
743
+
744
+ :param filepath: _description_
745
+ :type filepath: str
746
+ :param map_location: _description_, defaults to 'cpu'
747
+ :type map_location: str, optional
748
+ :param img_dict: _description_, defaults to None
749
+ :type img_dict: _type_, optional
750
+ """
751
+ # load state_dict
752
+ state_dict = torch.load(filepath, map_location=map_location)
753
+
754
+ # load data modalities
755
+ self.src_modalities: dict[str, dict[str, Any]] = state_dict.pop('src_modalities')
756
+ self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities')
757
+ if 'label_distribution' in state_dict:
758
+ self.label_distribution: dict[str, dict[int, int]] = state_dict.pop('label_distribution')
759
+ if 'optimizer' in state_dict:
760
+ self.optimizer = state_dict.pop('optimizer')
761
+
762
+ # initialize model
763
+ self.d_model = state_dict.pop('d_model')
764
+ self.nhead = state_dict.pop('nhead')
765
+ self.num_encoder_layers = state_dict.pop('num_encoder_layers')
766
+ self.num_decoder_layers = state_dict.pop('num_decoder_layers')
767
+ if 'epoch' in state_dict.keys():
768
+ self.start_epoch = state_dict.pop('epoch')
769
+ if img_dict is None:
770
+ self.img_net = state_dict.pop('img_net')
771
+ self.imgnet_layers = state_dict.pop('imgnet_layers')
772
+ self.img_size = state_dict.pop('img_size')
773
+ self.patch_size = state_dict.pop('patch_size')
774
+ self.imgnet_ckpt = state_dict.pop('imgnet_ckpt')
775
+ self.train_imgnet = state_dict.pop('train_imgnet')
776
+ else:
777
+ self.img_net = img_dict['img_net']
778
+ self.imgnet_layers = img_dict['imgnet_layers']
779
+ self.img_size = img_dict['img_size']
780
+ self.patch_size = img_dict['patch_size']
781
+ self.imgnet_ckpt = img_dict['imgnet_ckpt']
782
+ self.train_imgnet = img_dict['train_imgnet']
783
+ state_dict.pop('img_net')
784
+ state_dict.pop('imgnet_layers')
785
+ state_dict.pop('img_size')
786
+ state_dict.pop('patch_size')
787
+ state_dict.pop('imgnet_ckpt')
788
+ state_dict.pop('train_imgnet')
789
+
790
+ for k, info in self.src_modalities.items():
791
+ if info['type'] == 'imaging':
792
+ if 'emb' not in self.img_net.lower():
793
+ info['shape'] = (1,) + (self.img_size,) * 3
794
+ info['img_shape'] = (1,) + (self.img_size,) * 3
795
+ elif 'swinunetr' in self.img_net.lower():
796
+ info['shape'] = (1, 768, 4, 4, 4)
797
+ info['img_shape'] = (1, 768, 4, 4, 4)
798
+ # print(info['shape'])
799
+
800
+ self.net_ = Transformer(self.src_modalities, self.tgt_modalities, self.d_model, self.nhead, self.num_encoder_layers, self.num_decoder_layers, self.device, self.cuda_devices, self.img_net, self.imgnet_layers, self.img_size, self.patch_size, self.imgnet_ckpt, self.train_imgnet, self.fusion_stage)
801
+
802
+
803
+ if 'scaler' in state_dict and state_dict['scaler']:
804
+ self.scaler.load_state_dict(state_dict.pop('scaler'))
805
+ self.net_.load_state_dict(state_dict)
806
+ check_is_fitted(self)
807
+ self.net_.to(self.device)
808
+
809
+ def to(self, device: str) -> Self:
810
+ """Mount the model to the given device.
811
+
812
+ :param device: _description_
813
+ :type device: str
814
+ :return: _description_
815
+ :rtype: Self
816
+ """
817
+ self.device = device
818
+ if hasattr(self, 'model'): self.net_ = self.net_.to(device)
819
+ if hasattr(self, 'img_model'): self.img_model = self.img_model.to(device)
820
+ return self
821
+
822
+ @classmethod
823
+ def from_ckpt(cls, filepath: str, device='cpu', img_dict=None) -> Self:
824
+ """Create a new ADRD model and load parameters from the checkpoint.
825
+
826
+ This is an alternative constructor.
827
+
828
+ :param filepath: _description_
829
+ :type filepath: str
830
+ :param device: _description_, defaults to 'cpu'
831
+ :type device: str, optional
832
+ :param img_dict: _description_, defaults to None
833
+ :type img_dict: _type_, optional
834
+ :return: _description_
835
+ :rtype: Self
836
+ """
837
+ obj = cls(None, None, None,device=device)
838
+ if device == 'cuda':
839
+ obj.device = "{}:{}".format(obj.device, str(obj.cuda_devices[0]))
840
+ print(obj.device)
841
+ obj.load(filepath, map_location=obj.device, img_dict=img_dict)
842
+ return obj
843
+
844
+ def _init_net(self):
845
+ """ ... """
846
+ # set the device for use
847
+ if self.device == 'cuda':
848
+ self.device = "{}:{}".format(self.device, str(self.cuda_devices[0]))
849
+ print("Device: " + self.device)
850
+
851
+ self.start_epoch = 0
852
+ if self.load_from_ckpt:
853
+ try:
854
+ print("Loading model from checkpoint...")
855
+ self.load(self.ckpt_path, map_location=self.device)
856
+ except:
857
+ print("Cannot load from checkpoint. Initializing new model...")
858
+ self.load_from_ckpt = False
859
+
860
+ if not self.load_from_ckpt:
861
+ self.net_ = nn.Transformer(
862
+ src_modalities = self.src_modalities,
863
+ tgt_modalities = self.tgt_modalities,
864
+ d_model = self.d_model,
865
+ nhead = self.nhead,
866
+ num_encoder_layers = self.num_encoder_layers,
867
+ num_decoder_layers = self.num_decoder_layers,
868
+ device = self.device,
869
+ cuda_devices = self.cuda_devices,
870
+ img_net = self.img_net,
871
+ layers = self.imgnet_layers,
872
+ img_size = self.img_size,
873
+ patch_size = self.patch_size,
874
+ imgnet_ckpt = self.imgnet_ckpt,
875
+ train_imgnet = self.train_imgnet,
876
+ fusion_stage = self.fusion_stage,
877
+ )
878
+
879
+ # initialize model parameters using xavier_uniform
880
+ for name, p in self.net_.named_parameters():
881
+ if p.dim() > 1:
882
+ torch.nn.init.xavier_uniform_(p)
883
+
884
+ self.net_.to(self.device)
885
+
886
+ # Initialize the number of GPUs
887
+ if self.data_parallel and torch.cuda.device_count() > 1:
888
+ print("Available", torch.cuda.device_count(), "GPUs!")
889
+ self.net_ = torch.nn.DataParallel(self.net_, device_ids=self.cuda_devices)
890
+
891
+ # return net
892
+
893
+ def _init_dataloader(self, x_trn, x_vld, y_trn, y_vld, img_train_trans=None, img_vld_trans=None):
894
+ # initialize dataset and dataloader
895
+ if self.balanced_sampling:
896
+ dat_trn = Transformer2ndOrderBalancedTrainingDataset(
897
+ x_trn, y_trn,
898
+ self.src_modalities,
899
+ self.tgt_modalities,
900
+ dropout_rate = .5,
901
+ dropout_strategy = 'permutation',
902
+ img_transform=img_train_trans,
903
+ )
904
+ else:
905
+ dat_trn = TransformerTrainingDataset(
906
+ x_trn, y_trn,
907
+ self.src_modalities,
908
+ self.tgt_modalities,
909
+ dropout_rate = .5,
910
+ dropout_strategy = 'permutation',
911
+ img_transform=img_train_trans,
912
+ )
913
+
914
+ dat_vld = TransformerValidationDataset(
915
+ x_vld, y_vld,
916
+ self.src_modalities,
917
+ self.tgt_modalities,
918
+ img_transform=img_vld_trans,
919
+ )
920
+
921
+ ldr_trn = DataLoader(
922
+ dataset = dat_trn,
923
+ batch_size = self.batch_size,
924
+ shuffle = True,
925
+ drop_last = False,
926
+ num_workers = self._dataloader_num_workers,
927
+ collate_fn = TransformerTrainingDataset.collate_fn,
928
+ # pin_memory = True
929
+ )
930
+
931
+ ldr_vld = DataLoader(
932
+ dataset = dat_vld,
933
+ batch_size = self.batch_size,
934
+ shuffle = False,
935
+ drop_last = False,
936
+ num_workers = self._dataloader_num_workers,
937
+ collate_fn = TransformerValidationDataset.collate_fn,
938
+ # pin_memory = True
939
+ )
940
+
941
+ return ldr_trn, ldr_vld
942
+
943
+ def _init_optimizer(self):
944
+ """ ... """
945
+ params = list(self.net_.parameters())
946
+ return torch.optim.AdamW(
947
+ params,
948
+ lr = self.lr,
949
+ betas = (0.9, 0.98),
950
+ weight_decay = self.weight_decay
951
+ )
952
+
953
+ def _init_scheduler(self, optimizer):
954
+ """ ... """
955
+
956
+ return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
957
+ optimizer=optimizer,
958
+ T_0=64,
959
+ T_mult=2,
960
+ eta_min = 0,
961
+ verbose=(self.verbose > 2)
962
+ )
963
+
964
+ def _init_loss_func(self,
965
+ num_per_cls: dict[str, tuple[int, int]],
966
+ ) -> dict[str, Module]:
967
+ """ ... """
968
+ return {k: nn.SigmoidFocalLossBeta(
969
+ beta = self.beta,
970
+ gamma = self.gamma,
971
+ num_per_cls = num_per_cls[k],
972
+ reduction = 'none',
973
+ ) for k in self.tgt_modalities}
974
+
975
+ def _proc_fit(self):
976
+ """ ... """
adrd/model/calibration.py ADDED
@@ -0,0 +1,450 @@
1
+ import numpy as np
2
+ from sklearn.base import BaseEstimator
3
+ from sklearn.utils.validation import check_is_fitted
4
+ from sklearn.linear_model import LogisticRegression
5
+ from sklearn.isotonic import IsotonicRegression
6
+ from functools import lru_cache
7
+ from functools import cached_property
8
+ from typing import Self, Any
9
+ from pickle import dump
10
+ from pickle import load
11
+ from abc import ABC, abstractmethod
12
+
13
+ from . import ADRDModel
14
+ from ..utils import Formatter
15
+ from ..utils import MissingMasker
16
+
17
+
18
+ def calibration_curve(
19
+ y_true: list[int],
20
+ y_pred: list[float],
21
+ n_bins: int = 10,
22
+ ratio: float = 1.0,
23
+ ) -> tuple[list[float], list[float]]:
24
+ """
25
+ Compute true and predicted probabilities for a calibration curve. The method
26
+ assumes the inputs come from a binary classifier and discretizes the [0, 1]
27
+ interval into bins.
28
+
29
+ Note that this function is an alternative to
30
+ sklearn.calibration.calibration_curve() which can only estimate the absolute
31
+ proportion of positive cases in each bin.
32
+
33
+ Parameters
34
+ ----------
35
+ y_true : list[int]
36
+ True targets.
37
+ y_pred : list[float]
38
+ Probabilities of the positive class.
39
+ n_bins : int, default=10
40
+ Number of bins to discretize the [0, 1] interval. A bigger number
41
+ requires more data. Bins with no samples (i.e. without corresponding
42
+ values in y_pred) will not be returned, thus the returned arrays may
43
+ have fewer than n_bins values.
44
+ ratio : float, default=1.0
45
+ Used to adjust the class balance.
46
+
47
+ Returns
48
+ -------
49
+ prob_true : list[float]
50
+ The proportion of positive samples in each bin.
51
+ prob_pred : list[float]
52
+ The mean predicted probability in each bin.
53
+ """
54
+ # generate "n_bins" intervals
55
+ tmp = np.around(np.linspace(0, 1, n_bins + 1), decimals=6)
56
+ intvs = [(tmp[i - 1], tmp[i]) for i in range(1, len(tmp))]
57
+
58
+ # pair up (pred, true) and group them by intervals
59
+ tmp = list(zip(y_pred, y_true))
60
+ intv_pairs = {(l, r): [p for p in tmp if l <= p[0] < r] for l, r in intvs}
61
+
62
+ # calculate balanced proportion of POSITIVE cases for each interval
63
+ # along with the balanced averaged predictions
64
+ intv_prob_true: dict[tuple, float] = dict()
65
+ intv_prob_pred: dict[tuple, float] = dict()
66
+ for intv, pairs in intv_pairs.items():
67
+ # number of cases that fall into the interval
68
+ n_pairs = len(pairs)
69
+
70
+ # it's likely that no predictions fall into the interval
71
+ if n_pairs == 0: continue
72
+
73
+ # count number of positives and negatives in the interval
74
+ n_pos = sum([p[1] for p in pairs])
75
+ n_neg = n_pairs - n_pos
76
+
77
+ # calculate adjusted proportion of positives
78
+ intv_prob_true[intv] = n_pos / (n_pos + n_neg * ratio)
79
+
80
+ # calculate adjusted avg. predictions
81
+ sum_pred_pos = sum([p[0] for p in pairs if p[1] == 1])
82
+ sum_pred_neg = sum([p[0] for p in pairs if p[1] == 0])
83
+ intv_prob_pred[intv] = (sum_pred_pos + sum_pred_neg * ratio)
84
+ intv_prob_pred[intv] /= (n_pos + n_neg * ratio)
85
+
86
+ prob_true = list(intv_prob_true.values())
87
+ prob_pred = list(intv_prob_pred.values())
88
+ return prob_true, prob_pred
89
+
90
+
91
+ class CalibrationCore(BaseEstimator):
92
+ """
93
+ A wrapper class of multiple regressors to predict the proportions of
94
+ positive samples from the predicted probabilities. The method for
95
+ calibration can be 'sigmoid' which corresponds to Platt's method (i.e. a
96
+ logistic regression model) or 'isotonic' which is a non-parametric approach.
97
+ It is not advised to use isotonic calibration with too few calibration
98
+ samples (<<1000) since it tends to overfit.
99
+
100
+ TODO
101
+ ----
102
+ - 'sigmoid' method is not trivial to implement.
103
+ """
104
+ def __init__(self,
105
+ method: str = 'isotonic',
106
+ ) -> None:
107
+ """
108
+ Initialization function of CalibrationCore class.
109
+
110
+ Parameters
111
+ ----------
112
+ method : {'sigmoid', 'isotonic'}, default='isotonic'
113
+ The method to use for calibration. can be 'sigmoid' which
114
+ corresponds to Platt's method (i.e. a logistic regression model) or
115
+ 'isotonic' which is a non-parametric approach. It is not advised to
116
+ use isotonic calibration with too few calibration samples (<<1000)
117
+ since it tends to overfit.
118
+
119
+ Raises
120
+ ------
121
+ ValueError
122
+ Sigmoid approach has not been implemented.
123
+ """
124
+ assert method in ('sigmoid', 'isotonic')
125
+ if method == 'sigmoid':
126
+ raise ValueError('Sigmoid approach has not been implemented.')
127
+ self.method = method
128
+
129
+ def fit(self,
130
+ prob_pred: list[float],
131
+ prob_true: list[float],
132
+ ) -> Self:
133
+ """
134
+ Fit the underlying regressor using prob_pred, prob_true as training
135
+ data.
136
+
137
+ Parameters
138
+ ----------
139
+ prob_pred : list[float]
140
+ Probabilities predicted directly by a model.
141
+ prob_true : list[float]
142
+ Target probabilities to calibrate to.
143
+
144
+ Returns
145
+ -------
146
+ Self
147
+ CalibrationCore object.
148
+ """
149
+ # using Platt's method for calibration
150
+ if self.method == 'sigmoid':
151
+ self.model_ = LogisticRegression()
152
+ self.model_.fit(prob_pred, prob_true)
153
+
154
+ # using isotonic calibration
155
+ elif self.method == 'isotonic':
156
+ self.model_ = IsotonicRegression(y_min=0, y_max=1, out_of_bounds='clip')
157
+ self.model_.fit(prob_pred, prob_true)
158
+
159
+ return self
160
+
161
+ def predict(self,
162
+ prob_pred: list[float],
163
+ ) -> list[float]:
164
+ """
165
+ Calibrate the input probabilities using the fitted regressor.
166
+
167
+ Parameters
168
+ ----------
169
+ prob_pred : list[float]
170
+ Probabilities predicted directly by a model.
171
+
172
+ Returns
173
+ -------
174
+ prob_cali : list[float]
175
+ Calibrated probabilities.
176
+ """
177
+ # as usual, the core needs to be fitted
178
+ check_is_fitted(self)
179
+
180
+ # note that logistic regression is a classification model, so we need to call
181
+ # 'predict_proba' instead of 'predict' to get the calibrated results
182
+ if self.method == 'sigmoid':
183
+ prob_cali = self.model_.predict_proba(prob_pred)
184
+ elif self.method == 'isotonic':
185
+ prob_cali = self.model_.predict(prob_pred)
186
+
187
+ return prob_cali
188
+
189
+
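calibration_curve and CalibrationCore are meant to be used together: the curve turns raw predictions into (predicted, observed) probability pairs with an optional class-balance adjustment, and the core fits a monotone mapping on those pairs. A small sketch with made-up numbers (not data from this repository):

    # toy example: build a calibration curve, then fit an isotonic core on it
    y_true = [0, 0, 1, 0, 1, 1, 1, 0, 1, 1]
    y_pred = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]

    # ratio=1.0 keeps the observed class balance; ratio < 1 would down-weight negatives
    prob_true, prob_pred = calibration_curve(y_true, y_pred, n_bins=5, ratio=1.0)

    core = CalibrationCore(method='isotonic').fit(prob_pred, prob_true)
    calibrated = core.predict([0.25, 0.55, 0.85])   # calibrated probabilities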
190
+ class CalibratedClassifier(ABC):
191
+ """
192
+ Abstract class of calibrated classifier.
193
+ """
194
+ def __init__(self,
195
+ model: ADRDModel,
196
+ background_src: list[dict[str, Any]],
197
+ background_tgt: list[dict[str, Any]],
198
+ background_is_embedding: dict[str, bool] | None = None,
199
+ method: str = 'isotonic',
200
+ ) -> None:
201
+ """
202
+ Constructor of the CalibratedClassifier class.
203
+
204
+ Parameters
205
+ ----------
206
+ model : ADRDModel
207
+ Fitted model to calibrate.
208
+ background_src : list[dict[str, Any]]
209
+ Features of the background dataset.
210
+ background_tgt : list[dict[str, Any]]
211
+ Labels of the background dataset.
212
+ method : {'sigmoid', 'isotonic'}, default='isotonic'
213
+ Method used by the underlying regressor.
214
+ """
215
+ self.method = method
216
+ self.model = model
217
+ self.src_modalities = model.src_modalities
218
+ self.tgt_modalities = model.tgt_modalities
219
+ self.background_is_embedding = background_is_embedding
220
+
221
+ # format background data
222
+ fmt_src = Formatter(self.src_modalities)
223
+ fmt_tgt = Formatter(self.tgt_modalities)
224
+ self.background_src = [fmt_src(smp) for smp in background_src]
225
+ self.background_tgt = [fmt_tgt(smp) for smp in background_tgt]
226
+
227
+ @abstractmethod
228
+ def predict_proba(self,
229
+ src: list[dict[str, Any]],
230
+ is_embedding: dict[str, bool] | None = None,
231
+ ) -> list[dict[str, float]]:
232
+ """
233
+ This method returns calibrated probabilities of classification.
234
+
235
+ Parameters
236
+ ----------
237
+ src : list[dict[str, Any]]
238
+ Features of the input samples.
239
+
240
+ Returns
241
+ -------
242
+ list[dict[str, float]]
243
+ Calibrated probabilities.
244
+ """
245
+ pass
246
+
247
+ def predict(self,
248
+ src: list[dict[str, Any]],
249
+ is_embedding: dict[str, bool] | None = None,
250
+ ) -> list[dict[str, int]]:
251
+ """
252
+ Make predictions based on the results of predict_proba().
253
+
254
+ Parameters
255
+ ----------
256
+ src : list[dict[str, Any]]
257
+ Input features.
258
+
259
+ Returns
260
+ -------
261
+ list[dict[str, int]]
262
+ Calibrated predictions.
263
+ """
264
+ proba = self.predict_proba(src, is_embedding)
265
+ return [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba]
266
+
267
+ def save(self,
268
+ filepath_state_dict: str,
269
+ ) -> None:
270
+ """
271
+ Save the state dict and the underlying model to the given paths.
272
+
273
+ Parameters
274
+ ----------
275
+ filepath_state_dict : str
276
+ File path to save the state_dict which includes the background
277
+ dataset and the regressor information.
278
+ filepath_wrapped_model : str | None, default=None
279
+ File path to save the wrapped model. If None, the model won't be
280
+ saved.
281
+ """
282
+ # save state dict
283
+ state_dict = {
284
+ 'background_src': self.background_src,
285
+ 'background_tgt': self.background_tgt,
286
+ 'background_is_embedding': self.background_is_embedding,
287
+ 'method': self.method,
288
+ }
289
+ with open(filepath_state_dict, 'wb') as f:
290
+ dump(state_dict, f)
291
+
292
+ @classmethod
293
+ def from_ckpt(cls,
294
+ filepath_state_dict: str,
295
+ filepath_wrapped_model: str,
296
+ ) -> Self:
297
+ """
298
+ Alternative constructor which loads from checkpoint.
299
+
300
+ Parameters
301
+ ----------
302
+ filepath_state_dict : str
303
+ File path to load the state_dict which includes the background
304
+ dataset and the regressor information.
305
+ filepath_wrapped_model : str
306
+ File path of the wrapped model.
307
+
308
+ Returns
309
+ -------
310
+ Self
311
+ CalibratedClassifier class object.
312
+ """
313
+ with open(filepath_state_dict, 'rb') as f:
314
+ kwargs = load(f)
315
+ kwargs['model'] = ADRDModel.from_ckpt(filepath_wrapped_model)
316
+ return cls(**kwargs)
317
+
318
+
319
+ class DynamicCalibratedClassifier(CalibratedClassifier):
320
+ """
321
+ The dynamic approach generates background predictions based on the
322
+ missingness pattern of each input. Because the number of possible
323
+ missingness patterns is astronomically large, each input requires running
324
+ the ADRDModel on most of the background data and fitting a dedicated
325
+ regressor for its pattern, which makes calibration computationally
326
+ expensive.
327
+ """
328
+ def predict_proba(self,
329
+ src: list[dict[str, Any]],
330
+ is_embedding: dict[str, bool] | None = None,
331
+ ) -> list[dict[str, float]]:
332
+
333
+ # initialize mask generator and format inputs
334
+ msk_gen = MissingMasker(self.src_modalities)
335
+ fmt_src = Formatter(self.src_modalities)
336
+ src = [fmt_src(smp) for smp in src]
337
+
338
+ # calculate calibrated probabilities
339
+ calibrated_prob: list[dict[str, float]] = []
340
+ for smp in src:
341
+ # model output and missingness pattern
342
+ prob = self.model.predict_proba([smp], is_embedding)[0]
343
+ mask = tuple(msk_gen(smp).values())
344
+
345
+ # get/fit core and calculate calibrated probabilities
346
+ core = self._fit_core(mask)
347
+ calibrated_prob.append({k: core[k].predict([prob[k]])[0] for k in self.tgt_modalities})
348
+
349
+ return calibrated_prob
350
+
351
+ # @lru_cache(maxsize = None)
352
+ def _fit_core(self,
353
+ missingness_pattern: tuple[bool],
354
+ ) -> dict[str, CalibrationCore]:
355
+ ''' ... '''
356
+ # remove features from all background samples accordingly
357
+ background_src, background_tgt = [], []
358
+ for src, tgt in zip(self.background_src, self.background_tgt):
359
+ src = {k: v for j, (k, v) in enumerate(src.items()) if missingness_pattern[j] == False}
360
+
361
+ # make sure there is at least one feature available
362
+ if sum(v is not None for v in src.values()) == 0: continue
363
+ background_src.append(src)
364
+ background_tgt.append(tgt)
365
+
366
+ # run model on background samples and collect predictions
367
+ background_prob = self.model.predict_proba(background_src, self.background_is_embedding, _batch_size=1024)
368
+
369
+ # list[dict] -> dict[list]
370
+ N = len(background_src)
371
+ background_prob = {k: [background_prob[i][k] for i in range(N)] for k in self.tgt_modalities}
372
+ background_true = {k: [background_tgt[i][k] for i in range(N)] for k in self.tgt_modalities}
373
+
374
+ # now, fit cores
375
+ core: dict[str, CalibrationCore] = dict()
376
+ for k in self.tgt_modalities:
377
+ prob_true, prob_pred = calibration_curve(
378
+ background_true[k], background_prob[k],
379
+ ratio = self.background_ratio[k],
380
+ )
381
+ core[k] = CalibrationCore(self.method).fit(prob_pred, prob_true)
382
+
383
+ return core
384
+
385
+ @cached_property
386
+ def background_ratio(self) -> dict[str, float]:
387
+ ''' The ratio of positives over negatives in the background dataset. '''
388
+ return {k: self.background_n_pos[k] / self.background_n_neg[k] for k in self.tgt_modalities}
389
+
390
+ @cached_property
391
+ def background_n_pos(self) -> dict[str, int]:
392
+ ''' Number of positives w.r.t each target in the background dataset. '''
393
+ return {k: sum([d[k] for d in self.background_tgt]) for k in self.tgt_modalities}
394
+
395
+ @cached_property
396
+ def background_n_neg(self) -> dict[str, int]:
397
+ ''' Number of negatives w.r.t each target in the background dataset. '''
398
+ return {k: len(self.background_tgt) - self.background_n_pos[k] for k in self.tgt_modalities}
399
+
400
+
401
+ class StaticCalibratedClassifier(CalibratedClassifier):
402
+ """
403
+ The static approach generates background predictions without considering the
404
+ missingness patterns.
405
+ """
406
+ def predict_proba(self,
407
+ src: list[dict[str, Any]],
408
+ is_embedding: dict[str, bool] | None = None,
409
+ ) -> list[dict[str, float]]:
410
+
411
+ # number of input samples
412
+ N = len(src)
413
+
414
+ # format inputs, and run ADRDModel, and convert to dict[list]
415
+ fmt_src = Formatter(self.src_modalities)
416
+ src = [fmt_src(smp) for smp in src]
417
+ prob = self.model.predict_proba(src, is_embedding)
418
+ prob = {k: [prob[i][k] for i in range(N)] for k in self.tgt_modalities}
419
+
420
+ # calibrate probabilities
421
+ core = self._fit_core()
422
+ calibrated_prob = {k: core[k].predict(prob[k]) for k in self.tgt_modalities}
423
+
424
+ # convert back to list[dict]
425
+ calibrated_prob: list[dict[str, float]] = [
426
+ {k: calibrated_prob[k][i] for k in self.tgt_modalities} for i in range(N)
427
+ ]
428
+ return calibrated_prob
429
+
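+ # the static core is fitted once on the full background set and cached via lru_cache,
+ # so repeated predict_proba() calls reuse the same calibration curves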
430
+ @lru_cache(maxsize = None)
431
+ def _fit_core(self) -> dict[str, CalibrationCore]:
432
+ ''' ... '''
433
+ # run model on background samples and collect predictions
434
+ background_prob = self.model.predict_proba(self.background_src, self.background_is_embedding, _batch_size=1024)
435
+
436
+ # list[dict] -> dict[list]
437
+ N = len(self.background_src)
438
+ background_prob = {k: [background_prob[i][k] for i in range(N)] for k in self.tgt_modalities}
439
+ background_true = {k: [self.background_tgt[i][k] for i in range(N)] for k in self.tgt_modalities}
440
+
441
+ # now, fit cores
442
+ core: dict[str, CalibrationCore] = dict()
443
+ for k in self.tgt_modalities:
444
+ prob_true, prob_pred = calibration_curve(
445
+ background_true[k], background_prob[k],
446
+ ratio = 1.0,
447
+ )
448
+ core[k] = CalibrationCore(self.method).fit(prob_pred, prob_true)
449
+
450
+ return core
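Note: a minimal usage sketch of the calibration wrappers above. The import path, the checkpoint/output paths, and the background/test lists (bg_src, bg_tgt, x_test) are assumptions for illustration and are not part of this commit; the keyword arguments follow the constructor docstring above.

    from adrd.model import ADRDModel, StaticCalibratedClassifier   # assumed export path

    model = ADRDModel.from_ckpt('./ckpt/ckpt.pt')          # hypothetical checkpoint path
    calib = StaticCalibratedClassifier(
        model = model,
        background_src = bg_src,                           # list[dict[str, Any]]: background features
        background_tgt = bg_tgt,                           # list[dict[str, Any]]: background labels
        background_is_embedding = None,
        method = 'isotonic',                               # or 'sigmoid'
    )
    proba = calib.predict_proba(x_test)                    # calibrated probabilities, list[dict[str, float]]
    preds = calib.predict(x_test)                          # calibrated probabilities thresholded at 0.5
    calib.save('./static_calibrated_classifier.pkl')       # persists background data and regressor settings

DynamicCalibratedClassifier is used the same way, but it refits one CalibrationCore per missingness pattern of the incoming samples, which is considerably more expensive.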
adrd/model/cnn_resnet3d_with_linear_classifier.py ADDED
@@ -0,0 +1,533 @@
1
+ __all__ = ['CNNResNet3DWithLinearClassifier']
2
+
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ import numpy as np
6
+ import tqdm
7
+ from sklearn.base import BaseEstimator
8
+ from sklearn.utils.validation import check_is_fitted
9
+ from sklearn.model_selection import train_test_split
10
+ from scipy.special import expit
11
+ from copy import deepcopy
12
+ from contextlib import suppress
13
+ from typing import Any, Self, Type
14
+ from functools import wraps
15
+ Tensor = Type[torch.Tensor]
16
+ Module = Type[torch.nn.Module]
17
+
18
+ from ..utils.misc import ProgressBar
19
+ from ..utils.misc import get_metrics_multitask, print_metrics_multitask
20
+
21
+ from .. import nn
22
+ from ..utils import TransformerTrainingDataset
23
+ from ..utils import Transformer2ndOrderBalancedTrainingDataset
24
+ from ..utils import TransformerValidationDataset
25
+ from ..utils import TransformerTestingDataset
26
+ from ..utils.misc import convert_args_kwargs_to_kwargs
29
+
30
+
31
+ def _manage_ctx_fit(func):
32
+ ''' ... '''
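+ # NOTE: when _device_ids is provided, fit() runs with the first listed device as the
+ # primary device; afterwards the DataParallel wrapper is unwrapped (net_.module) and
+ # the model is moved back to the original device.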
33
+ @wraps(func)
34
+ def wrapper(*args, **kwargs):
35
+ # format arguments
36
+ kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs)
37
+
38
+ if kwargs['self']._device_ids is None:
39
+ return func(**kwargs)
40
+ else:
41
+ # change primary device
42
+ default_device = kwargs['self'].device
43
+ kwargs['self'].device = kwargs['self']._device_ids[0]
44
+ rtn = func(**kwargs)
45
+
46
+ # the actual module is wrapped
47
+ kwargs['self'].net_ = kwargs['self'].net_.module
48
+ kwargs['self'].to(default_device)
49
+ return rtn
50
+ return wrapper
51
+
52
+
53
+ class CNNResNet3DWithLinearClassifier(BaseEstimator):
54
+
55
+ def __init__(self,
56
+ src_modalities: dict[str, dict[str, Any]],
57
+ tgt_modalities: dict[str, dict[str, Any]],
58
+ num_epochs: int = 32,
59
+ batch_size: int = 8,
60
+ batch_size_multiplier: int = 1,
61
+ lr: float = 1e-2,
62
+ weight_decay: float = 0.0,
63
+ beta: float = 0.9999,
64
+ gamma: float = 2.0,
65
+ scale: float = 1.0,
66
+ criterion: str | None = None,
67
+ device: str = 'cpu',
68
+ verbose: int = 0,
69
+ _device_ids: list | None = None,
70
+ _dataloader_num_workers: int = 0,
71
+ _amp_enabled: bool = False,
72
+ _tmp_ckpt_filepath: str | None = None,
73
+ ) -> None:
74
+ ''' ... '''
75
+ # for multiprocessing
76
+ self._rank = 0
77
+ self._lock = None
78
+
79
+ # positional parameters
80
+ self.src_modalities = src_modalities
81
+ self.tgt_modalities = tgt_modalities
82
+
83
+ # training parameters
84
+ self.num_epochs = num_epochs
85
+ self.batch_size = batch_size
86
+ self.batch_size_multiplier = batch_size_multiplier
87
+ self.lr = lr
88
+ self.weight_decay = weight_decay
89
+ self.beta = beta
90
+ self.gamma = gamma
91
+ self.scale = scale
92
+ self.criterion = criterion
93
+ self.device = device
94
+ self.verbose = verbose
95
+ self._device_ids = _device_ids
96
+ self._dataloader_num_workers = _dataloader_num_workers
97
+ self._amp_enabled = _amp_enabled
98
+ self._tmp_ckpt_filepath = _tmp_ckpt_filepath
99
+
100
+
101
+ @_manage_ctx_fit
102
+ def fit(self, x, y) -> Self:
103
+ ''' ... '''
104
+ # for PyTorch computational efficiency
105
+ torch.set_num_threads(1)
106
+
107
+ # initialize neural network
108
+ self.net_ = self._init_net()
109
+
110
+ # initialize dataloaders
111
+ ldr_trn, ldr_vld = self._init_dataloader(x, y)
112
+
113
+ # initialize optimizer and scheduler
114
+ optimizer = self._init_optimizer()
115
+ scheduler = self._init_scheduler(optimizer)
116
+
117
+ # gradient scaler for AMP
118
+ if self._amp_enabled: scaler = torch.cuda.amp.GradScaler()
119
+
120
+ # initialize loss functions (sigmoid focal loss, one per target)
121
+ loss_func = self._init_loss_func({
122
+ k: (
123
+ sum([_[k] == 0 for _ in ldr_trn.dataset.tgt]),
124
+ sum([_[k] == 1 for _ in ldr_trn.dataset.tgt]),
125
+ ) for k in self.tgt_modalities
126
+ })
127
+
128
+ # to record the best validation performance criterion
129
+ if self.criterion is not None: best_crit = None
130
+
131
+ # progress bar for epoch loops
132
+ if self.verbose == 1:
133
+ with self._lock if self._lock is not None else suppress():
134
+ pbr_epoch = tqdm.tqdm(
135
+ desc = 'Rank {:02d}'.format(self._rank),
136
+ total = self.num_epochs,
137
+ position = self._rank,
138
+ ascii = True,
139
+ leave = False,
140
+ bar_format='{l_bar}{r_bar}'
141
+ )
142
+
143
+ # training loop
144
+ for epoch in range(self.num_epochs):
145
+ # progress bar for batch loops
146
+ if self.verbose > 1:
147
+ pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch))
148
+
149
+ # set model to train mode
150
+ torch.set_grad_enabled(True)
151
+ self.net_.train()
152
+
153
+ scores_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
154
+ y_true_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
155
+ losses_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
156
+ for n_iter, (x_batch, y_batch, _, mask_y) in enumerate(ldr_trn):
157
+ # mount data to the proper device
158
+ x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
159
+ y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities}
160
+ # mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
161
+ mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities}
162
+
163
+ # forward
164
+ with torch.autocast(
165
+ device_type = 'cpu' if self.device == 'cpu' else 'cuda',
166
+ dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
167
+ enabled = self._amp_enabled,
168
+ ):
169
+ outputs = self.net_(x_batch)
170
+
171
+ # calculate multitask loss
172
+ loss = 0
173
+ for i, tgt_k in enumerate(self.tgt_modalities):
174
+ loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k])
175
+ loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze()))
176
+ loss += loss_k.mean()
177
+ losses_trn[tgt_k] += loss_k.detach().cpu().numpy().tolist()
178
+
179
+ # backward
180
+ if self._amp_enabled:
181
+ scaler.scale(loss).backward()
182
+ else:
183
+ loss.backward()
184
+
185
+ # update parameters
186
+ if n_iter != 0 and n_iter % self.batch_size_multiplier == 0:
187
+ if self._amp_enabled:
188
+ scaler.step(optimizer)
189
+ scaler.update()
190
+ optimizer.zero_grad()
191
+ else:
192
+ optimizer.step()
193
+ optimizer.zero_grad()
194
+
195
+ # save outputs to evaluate performance later
196
+ for tgt_k in self.tgt_modalities:
197
+ tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
198
+ scores_trn[tgt_k] += tmp.detach().cpu().numpy().tolist()
199
+ tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
200
+ y_true_trn[tgt_k] += tmp.cpu().numpy().tolist()
201
+
202
+ # update progress bar
203
+ if self.verbose > 1:
204
+ batch_size = len(next(iter(x_batch.values())))
205
+ pbr_batch.update(batch_size, {})
206
+ pbr_batch.refresh()
207
+
208
+ # for better tqdm progress bar display
209
+ if self.verbose > 1:
210
+ pbr_batch.close()
211
+
212
+ # set scheduler
213
+ scheduler.step()
214
+
215
+ # calculate and print training performance metrics
216
+ y_pred_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
217
+ y_prob_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
218
+ for tgt_k in self.tgt_modalities:
219
+ for i in range(len(scores_trn[tgt_k])):
220
+ y_pred_trn[tgt_k].append(1 if scores_trn[tgt_k][i] > 0 else 0)
221
+ y_prob_trn[tgt_k].append(expit(scores_trn[tgt_k][i]))
222
+ met_trn = get_metrics_multitask(y_true_trn, y_pred_trn, y_prob_trn)
223
+
224
+ # add loss to metrics
225
+ for tgt_k in self.tgt_modalities:
226
+ met_trn[tgt_k]['Loss'] = np.mean(losses_trn[tgt_k])
227
+
228
+ if self.verbose > 2:
229
+ print_metrics_multitask(met_trn)
230
+
231
+ # progress bar for validation
232
+ if self.verbose > 1:
233
+ pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch))
234
+
235
+ # set model to validation mode
236
+ torch.set_grad_enabled(False)
237
+ self.net_.eval()
238
+
239
+ scores_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
240
+ y_true_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
241
+ losses_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
242
+ for x_batch, y_batch, _, mask_y in ldr_vld:
243
+ # mount data to the proper device
244
+ x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
245
+ y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities}
246
+ # mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
247
+ mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities}
248
+
249
+ # forward
250
+ with torch.autocast(
251
+ device_type = 'cpu' if self.device == 'cpu' else 'cuda',
252
+ dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
253
+ enabled = self._amp_enabled
254
+ ):
255
+ outputs = self.net_(x_batch)
256
+
257
+ # calculate multitask loss
258
+ for i, tgt_k in enumerate(self.tgt_modalities):
259
+ loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k])
260
+ loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze()))
261
+ losses_vld[tgt_k] += loss_k.detach().cpu().numpy().tolist()
262
+
263
+ # save outputs to evaluate performance later
264
+ for tgt_k in self.tgt_modalities:
265
+ tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
266
+ scores_vld[tgt_k] += tmp.detach().cpu().numpy().tolist()
267
+ tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
268
+ y_true_vld[tgt_k] += tmp.cpu().numpy().tolist()
269
+
270
+ # update progress bar
271
+ if self.verbose > 1:
272
+ batch_size = len(next(iter(x_batch.values())))
273
+ pbr_batch.update(batch_size, {})
274
+ pbr_batch.refresh()
275
+
276
+ # for better tqdm progress bar display
277
+ if self.verbose > 1:
278
+ pbr_batch.close()
279
+
280
+ # calculate and print validation performance metrics
281
+ y_pred_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
282
+ y_prob_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
283
+ for tgt_k in self.tgt_modalities:
284
+ for i in range(len(scores_vld[tgt_k])):
285
+ y_pred_vld[tgt_k].append(1 if scores_vld[tgt_k][i] > 0 else 0)
286
+ y_prob_vld[tgt_k].append(expit(scores_vld[tgt_k][i]))
287
+ met_vld = get_metrics_multitask(y_true_vld, y_pred_vld, y_prob_vld)
288
+
289
+ # add loss to metrics
290
+ for tgt_k in self.tgt_modalities:
291
+ met_vld[tgt_k]['Loss'] = np.mean(losses_vld[tgt_k])
292
+
293
+ if self.verbose > 2:
294
+ print_metrics_multitask(met_vld)
295
+
296
+ # save the model if it has the best validation performance criterion by far
297
+ if self.criterion is None: continue
298
+
299
+ # is current criterion better than previous best?
300
+ curr_crit = np.mean([met_vld[k][self.criterion] for k in self.tgt_modalities])
301
+ if best_crit is None or np.isnan(best_crit):
302
+ is_better = True
303
+ elif self.criterion == 'Loss' and best_crit >= curr_crit:
304
+ is_better = True
305
+ elif self.criterion != 'Loss' and best_crit <= curr_crit:
306
+ is_better = True
307
+ else:
308
+ is_better = False
309
+
310
+ # update best criterion
311
+ if is_better:
312
+ best_crit = curr_crit
313
+ best_state_dict = deepcopy(self.net_.state_dict())
314
+
315
+ if self._tmp_ckpt_filepath is not None:
316
+ self.save(self._tmp_ckpt_filepath)
317
+
318
+ if self.verbose > 2:
319
+ print('Best {}: {}'.format(self.criterion, best_crit))
320
+
321
+ if self.verbose == 1:
322
+ with self._lock if self._lock is not None else suppress():
323
+ pbr_epoch.update(1)
324
+ pbr_epoch.refresh()
325
+
326
+ if self.verbose == 1:
327
+ with self._lock if self._lock is not None else suppress():
328
+ pbr_epoch.close()
329
+
330
+ # restore the model with the best validation performance across all epochs
331
+ if ldr_vld is not None and self.criterion is not None:
332
+ self.net_.load_state_dict(best_state_dict)
333
+
334
+ return self
335
+
336
+ def predict_logits(self,
337
+ x: list[dict[str, Any]],
338
+ _batch_size: int | None = None,
339
+ ) -> list[dict[str, float]]:
340
+ """
341
+ The input x can be a single sample or a list of samples.
342
+ """
343
+ # input validation
344
+ check_is_fitted(self)
345
+
346
+ # for PyTorch computational efficiency
347
+ torch.set_num_threads(1)
348
+
349
+ # set model to eval mode
350
+ torch.set_grad_enabled(False)
351
+ self.net_.eval()
352
+
353
+ # initialize dataset and dataloader objects
354
+ dat = TransformerTestingDataset(x, self.src_modalities)
355
+ ldr = DataLoader(
356
+ dataset = dat,
357
+ batch_size = _batch_size if _batch_size is not None else len(x),
358
+ shuffle = False,
359
+ drop_last = False,
360
+ num_workers = 0,
361
+ collate_fn = TransformerTestingDataset.collate_fn,
362
+ )
363
+
364
+ # run model and collect results
365
+ logits: list[dict[str, float]] = []
366
+ for x_batch, _ in ldr:
367
+ # mount data to the proper device
368
+ x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
369
+
370
+ # forward
371
+ output: dict[str, Tensor] = self.net_(x_batch)
372
+
373
+ # convert output from dict-of-list to list of dict, then append
374
+ tmp = {k: output[k].tolist() for k in self.tgt_modalities}
375
+ tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))]
376
+ logits += tmp
377
+
378
+ return logits
379
+
380
+ def predict_proba(self,
381
+ x: list[dict[str, Any]],
382
+ temperature: float = 1.0,
383
+ _batch_size: int | None = None,
384
+ ) -> list[dict[str, float]]:
385
+ ''' ... '''
386
+ logits = self.predict_logits(x, _batch_size)
387
+ return [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]
388
+
389
+ def predict(self,
390
+ x: list[dict[str, Any]],
391
+ _batch_size: int | None = None,
392
+ ) -> list[dict[str, int]]:
393
+ ''' ... '''
394
+ logits = self.predict_logits(x, _batch_size)
395
+ return [{k: int(smp[k] > 0.0) for k in self.tgt_modalities} for smp in logits]
396
+
397
+ def save(self, filepath: str) -> None:
398
+ ''' ... '''
399
+ check_is_fitted(self)
400
+ state_dict = self.net_.state_dict()
401
+
402
+ # attach model hyper parameters
403
+ state_dict['src_modalities'] = self.src_modalities
404
+ state_dict['tgt_modalities'] = self.tgt_modalities
405
+ print('Saving model checkpoint to {} ... '.format(filepath), end='')
406
+ torch.save(state_dict, filepath)
407
+ print('Done.')
408
+
409
+ def load(self, filepath: str) -> None:
410
+ ''' ... '''
411
+ # load state_dict
412
+ state_dict = torch.load(filepath, map_location='cpu')
413
+
414
+ # load essential parameters
415
+ self.src_modalities: dict[str, dict[str, Any]] = state_dict.pop('src_modalities')
416
+ self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities')
417
+
418
+ # initialize model
419
+ self.net_ = nn.CNNResNet3DWithLinearClassifier(
420
+ self.src_modalities,
421
+ self.tgt_modalities,
422
+ )
423
+
424
+ # load model parameters
425
+ self.net_.load_state_dict(state_dict)
426
+ self.to(self.device)
427
+
428
+ def to(self, device: str) -> Self:
429
+ ''' Mount model to the given device. '''
430
+ self.device = device
431
+ if hasattr(self, 'net_'): self.net_ = self.net_.to(device)
432
+ return self
433
+
434
+ @classmethod
435
+ def from_ckpt(cls, filepath: str) -> Self:
436
+ ''' ... '''
437
+ obj = cls(None, None)
438
+ obj.load(filepath)
439
+ return obj
440
+
441
+ def _init_net(self):
442
+ """ ... """
443
+ net = nn.CNNResNet3DWithLinearClassifier(
444
+ self.src_modalities,
445
+ self.tgt_modalities,
446
+ ).to(self.device)
447
+
448
+ # train on multiple GPUs using torch.nn.DataParallel
449
+ if self._device_ids is not None:
450
+ net = torch.nn.DataParallel(net, device_ids=self._device_ids)
451
+
452
+ # initialize model parameters using xavier_uniform
453
+ for p in net.parameters():
454
+ if p.dim() > 1:
455
+ torch.nn.init.xavier_uniform_(p)
456
+
457
+ return net
458
+
459
+ def _init_dataloader(self, x, y):
460
+ """ ... """
461
+ # split dataset
462
+ x_trn, x_vld, y_trn, y_vld = train_test_split(
463
+ x, y, test_size = 0.2, random_state = 0,
464
+ )
465
+
466
+ # initialize dataset and dataloader
467
+ # dat_trn = TransformerTrainingDataset(
468
+ dat_trn = Transformer2ndOrderBalancedTrainingDataset(
469
+ x_trn, y_trn,
470
+ self.src_modalities,
471
+ self.tgt_modalities,
472
+ dropout_rate = .5,
473
+ # dropout_strategy = 'compensated',
474
+ dropout_strategy = 'permutation',
475
+ )
476
+
477
+ dat_vld = TransformerValidationDataset(
478
+ x_vld, y_vld,
479
+ self.src_modalities,
480
+ self.tgt_modalities,
481
+ )
482
+
483
+ ldr_trn = DataLoader(
484
+ dataset = dat_trn,
485
+ batch_size = self.batch_size,
486
+ shuffle = True,
487
+ drop_last = False,
488
+ num_workers = self._dataloader_num_workers,
489
+ collate_fn = TransformerTrainingDataset.collate_fn,
490
+ # pin_memory = True
491
+ )
492
+
493
+ ldr_vld = DataLoader(
494
+ dataset = dat_vld,
495
+ batch_size = self.batch_size,
496
+ shuffle = False,
497
+ drop_last = False,
498
+ num_workers = self._dataloader_num_workers,
499
+ collate_fn = TransformerValidationDataset.collate_fn,
500
+ # pin_memory = True
501
+ )
502
+
503
+ return ldr_trn, ldr_vld
504
+
505
+ def _init_optimizer(self):
506
+ """ ... """
507
+ return torch.optim.AdamW(
508
+ self.net_.parameters(),
509
+ lr = self.lr,
510
+ betas = (0.9, 0.98),
511
+ weight_decay = self.weight_decay
512
+ )
513
+
514
+ def _init_scheduler(self, optimizer):
515
+ """ ... """
516
+ return torch.optim.lr_scheduler.OneCycleLR(
517
+ optimizer = optimizer,
518
+ max_lr = self.lr,
519
+ total_steps = self.num_epochs,
520
+ verbose = (self.verbose > 2)
521
+ )
522
+
523
+ def _init_loss_func(self,
524
+ num_per_cls: dict[str, tuple[int, int]],
525
+ ) -> dict[str, Module]:
526
+ """ ... """
527
+ return {k: nn.SigmoidFocalLoss(
528
+ beta = self.beta,
529
+ gamma = self.gamma,
530
+ scale = self.scale,
531
+ num_per_cls = num_per_cls[k],
532
+ reduction = 'none',
533
+ ) for k in self.tgt_modalities}
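Note: a minimal sketch of the sklearn-style workflow this estimator exposes (fit / predict_proba / predict / save / from_ckpt). The modality dictionaries, the sample lists (x_trn, y_trn, x_tst) and the checkpoint path are assumptions for illustration.

    from adrd.model import CNNResNet3DWithLinearClassifier   # assumed export path

    clf = CNNResNet3DWithLinearClassifier(
        src_modalities, tgt_modalities,       # dict[str, dict[str, Any]] metadata, prepared upstream
        num_epochs = 32,
        batch_size = 8,
        batch_size_multiplier = 4,            # gradients are accumulated over 4 mini-batches per optimizer step
        lr = 1e-2,
        criterion = 'Loss',                   # keep the epoch with the best mean validation loss
        device = 'cuda',
    )
    clf.fit(x_trn, y_trn)                     # x_trn, y_trn: list[dict[str, Any]]
    proba = clf.predict_proba(x_tst)          # expit(logit / temperature) per target
    preds = clf.predict(x_tst)                # raw logits thresholded at 0
    clf.save('./ckpt_img.pt')
    clf_restored = CNNResNet3DWithLinearClassifier.from_ckpt('./ckpt_img.pt')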
adrd/model/imaging_model.py ADDED
@@ -0,0 +1,843 @@
1
+ __all__ = ['ImagingModel']
2
+
3
+ import wandb
4
+ import torch
5
+ import numpy as np
6
+ import functools
7
+ import inspect
8
+ import monai
9
+ import random
+ from icecream import ic
10
+
11
+ from tqdm import tqdm
12
+ from functools import wraps
13
+ from sklearn.base import BaseEstimator
14
+ from sklearn.utils.validation import check_is_fitted
15
+ from sklearn.model_selection import train_test_split
16
+ from scipy.special import expit
17
+ from copy import deepcopy
18
+ from contextlib import suppress
19
+ from typing import Any, Self, Type
20
+ Tensor = Type[torch.Tensor]
21
+ Module = Type[torch.nn.Module]
22
+ from torch.utils.data import DataLoader
23
+ from monai.utils.type_conversion import convert_to_tensor
24
+ from monai.transforms import (
25
+ LoadImaged,
26
+ Compose,
27
+ CropForegroundd,
28
+ CopyItemsd,
29
+ SpatialPadd,
30
+ EnsureChannelFirstd,
31
+ Spacingd,
32
+ OneOf,
33
+ ScaleIntensityRanged,
34
+ HistogramNormalized,
35
+ RandSpatialCropSamplesd,
36
+ RandSpatialCropd,
37
+ CenterSpatialCropd,
38
+ RandCoarseDropoutd,
39
+ RandCoarseShuffled,
40
+ Resized,
41
+ )
42
+
43
+ # for DistributedDataParallel
44
+ import torch.distributed as dist
45
+ import torch.multiprocessing as mp
46
+ from torch.nn.parallel import DistributedDataParallel as DDP
47
+
48
+ from .. import nn
49
+ from ..utils.misc import ProgressBar
50
+ from ..utils.misc import get_metrics_multitask, print_metrics_multitask
51
+ from ..utils.misc import convert_args_kwargs_to_kwargs
52
+
53
+ import warnings
54
+ warnings.filterwarnings("ignore")
55
+
56
+
57
+ def _manage_ctx_fit(func):
58
+ ''' ... '''
59
+ @wraps(func)
60
+ def wrapper(*args, **kwargs):
61
+ # format arguments
62
+ kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs)
63
+
64
+ if kwargs['self']._device_ids is None:
65
+ return func(**kwargs)
66
+ else:
67
+ # change primary device
68
+ default_device = kwargs['self'].device
69
+ kwargs['self'].device = kwargs['self']._device_ids[0]
70
+ rtn = func(**kwargs)
71
+ kwargs['self'].to(default_device)
72
+ return rtn
73
+ return wrapper
74
+
75
+ def collate_handle_corrupted(samples_list, dataset, labels, dtype=torch.half):
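+ # Custom collate_fn: drop samples that failed to load (None) or whose image contains NaNs,
+ # stack the remaining images, and build per-target label/mask tensors; if the entire batch
+ # is unusable, redraw random samples from the dataset and retry.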
76
+ # print(len(samples_list))
77
+ orig_len = len(samples_list)
78
+ # for the loss to be consistent, we drop samples with NaN values in any of their corresponding crops
79
+ for i, s in enumerate(samples_list):
80
+ ic(s is None)
81
+ if s is None:
82
+ continue
83
+ samples_list = list(filter(lambda x: x is not None, samples_list))
84
+
85
+ if len(samples_list) == 0:
86
+ ic('recursive call')
87
+ return collate_handle_corrupted([dataset[random.randint(0, len(dataset)-1)] for _ in range(orig_len)], dataset, labels)
88
+
89
+ # collated_images = torch.stack([convert_to_tensor(s["image"]) for s in samples_list])
90
+ try:
91
+ if "image" in samples_list[0]:
92
+ samples_list = [s for s in samples_list if not torch.isnan(s["image"]).any()]
93
+ # print('samples list: ', len(samples_list))
94
+ collated_images = torch.stack([convert_to_tensor(s["image"]) for s in samples_list])
95
+ # print("here1")
96
+ collated_labels = {k: torch.Tensor([s["label"][k] if s["label"][k] is not None else 0 for s in samples_list]) for k in labels}
97
+ # print("here2")
98
+ collated_mask = {k: torch.Tensor([1 if s["label"][k] is not None else 0 for s in samples_list]) for k in labels}
99
+ # print("here3")
100
+ return {"image": collated_images,
101
+ "label": collated_labels,
102
+ "mask": collated_mask}
103
+ except:
104
+ return collate_handle_corrupted([dataset[random.randint(0, len(dataset)-1)] for _ in range(orig_len)], dataset, labels)
105
+
106
+
107
+
108
+ def get_backend(img_backend):
109
+ if img_backend == 'C3D':
110
+ return nn.C3D
111
+ elif img_backend == 'DenseNet':
112
+ return nn.DenseNet
113
+
114
+
115
+ class ImagingModel(BaseEstimator):
116
+ ''' ... '''
117
+ def __init__(self,
118
+ tgt_modalities: list[str],
119
+ label_fractions: dict[str, float],
120
+ num_epochs: int = 32,
121
+ batch_size: int = 8,
122
+ batch_size_multiplier: int = 1,
123
+ lr: float = 1e-2,
124
+ weight_decay: float = 0.0,
125
+ beta: float = 0.9999,
126
+ gamma: float = 2.0,
127
+ bn_size: int = 4,
128
+ growth_rate: int = 12,
129
+ block_config: tuple = (3, 3, 3),
130
+ compression: float = 0.5,
131
+ num_init_features: int = 16,
132
+ drop_rate: float = 0.2,
133
+ criterion: str | None = None,
134
+ device: str = 'cpu',
135
+ cuda_devices: list = [1],
136
+ ckpt_path: str = '/home/skowshik/ADRD_repo/adrd_tool/dev/ckpt/ckpt.pt',
137
+ load_from_ckpt: bool = True,
138
+ save_intermediate_ckpts: bool = False,
139
+ data_parallel: bool = False,
140
+ verbose: int = 0,
141
+ img_backend: str | None = None,
142
+ label_distribution: dict = {},
143
+ wandb_ = 1,
144
+ _device_ids: list | None = None,
145
+ _dataloader_num_workers: int = 4,
146
+ _amp_enabled: bool = False,
147
+ ) -> None:
148
+ ''' ... '''
149
+ # for multiprocessing
150
+ self._rank = 0
151
+ self._lock = None
152
+
153
+ # positional parameters
154
+ self.tgt_modalities = tgt_modalities
155
+
156
+ # training parameters
157
+ self.label_fractions = label_fractions
158
+ self.num_epochs = num_epochs
159
+ self.batch_size = batch_size
160
+ self.batch_size_multiplier = batch_size_multiplier
161
+ self.lr = lr
162
+ self.weight_decay = weight_decay
163
+ self.beta = beta
164
+ self.gamma = gamma
165
+ self.bn_size = bn_size
166
+ self.growth_rate = growth_rate
167
+ self.block_config = block_config
168
+ self.compression = compression
169
+ self.num_init_features = num_init_features
170
+ self.drop_rate = drop_rate
171
+ self.criterion = criterion
172
+ self.device = device
173
+ self.cuda_devices = cuda_devices
174
+ self.ckpt_path = ckpt_path
175
+ self.load_from_ckpt = load_from_ckpt
176
+ self.save_intermediate_ckpts = save_intermediate_ckpts
177
+ self.data_parallel = data_parallel
178
+ self.verbose = verbose
179
+ self.img_backend = img_backend
180
+ self.label_distribution = label_distribution
181
+ self.wandb_ = wandb_
182
+ self._device_ids = _device_ids
183
+ self._dataloader_num_workers = _dataloader_num_workers
184
+ self._amp_enabled = _amp_enabled
185
+ self.scaler = torch.cuda.amp.GradScaler()
186
+
187
+ @_manage_ctx_fit
188
+ def fit(self, trn_list, vld_list, img_train_trans=None, img_vld_trans=None) -> Self:
189
+ # def fit(self, x, y) -> Self:
190
+ ''' ... '''
191
+
192
+ # start a new wandb run to track this script
193
+ if self.wandb_ == 1:
194
+ wandb.init(
195
+ # set the wandb project where this run will be logged
196
+ project="ADRD_main",
197
+
198
+ # track hyperparameters and run metadata
199
+ config={
200
+ "Model": "DenseNet",
201
+ "Loss": 'Focalloss',
202
+ "EMB": "ALL_EMB",
203
+ "epochs": 256,
204
+ }
205
+ )
206
+ wandb.run.log_code("/home/skowshik/ADRD_repo/pipeline_v1_main/adrd_tool")
207
+ else:
208
+ wandb.init(mode="disabled")
209
+ # for PyTorch computational efficiency
210
+ torch.set_num_threads(1)
211
+ print(self.criterion)
212
+
213
+ # initialize neural network
214
+ self._init_net()
215
+
216
+ # for k, info in self.src_modalities.items():
217
+ # if info['type'] == 'imaging' and self.img_net != 'EMB':
218
+ # info['shape'] = (1,) + (self.img_size,) * 3
219
+ # info['img_shape'] = (1,) + (self.img_size,) * 3
220
+ # print(info['shape'])
221
+
222
+ # initialize dataloaders
223
+ # ldr_trn, ldr_vld = self._init_dataloader(x, y)
224
+ # ldr_trn, ldr_vld = self._init_dataloader(x_trn, x_vld, y_trn, y_vld)
225
+ ldr_trn, ldr_vld = self._init_dataloader(trn_list, vld_list, img_train_trans=img_train_trans, img_vld_trans=img_vld_trans)
226
+
227
+ # initialize optimizer and scheduler
228
+ if not self.load_from_ckpt:
229
+ self.optimizer = self._init_optimizer()
230
+ self.scheduler = self._init_scheduler(self.optimizer)
231
+
232
+ # gradient scaler for AMP
233
+ if self._amp_enabled:
234
+ self.scaler = torch.cuda.amp.GradScaler()
235
+
236
+ # initialize focal loss function
237
+ self.loss_fn = {}
238
+
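+ # per-target focal-loss alpha: frequent labels (positive fraction >= 0.3) use alpha = -1
+ # (no class weighting), while rarer labels are weighted with alpha = (1 - label fraction)^2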
239
+ for k in self.tgt_modalities:
240
+ if self.label_fractions[k] >= 0.3:
241
+ alpha = -1
242
+ else:
243
+ alpha = pow((1 - self.label_fractions[k]), 2)
244
+ # alpha = -1
245
+ self.loss_fn[k] = nn.SigmoidFocalLoss(
246
+ alpha = alpha,
247
+ gamma = self.gamma,
248
+ reduction = 'none'
249
+ )
250
+
251
+ # to record the best validation performance criterion
252
+ if self.criterion is not None:
253
+ best_crit = None
254
+ best_crit_AUPR = None
255
+
256
+ # progress bar for epoch loops
257
+ if self.verbose == 1:
258
+ with self._lock if self._lock is not None else suppress():
259
+ pbr_epoch = tqdm(
260
+ desc = 'Rank {:02d}'.format(self._rank),
261
+ total = self.num_epochs,
262
+ position = self._rank,
263
+ ascii = True,
264
+ leave = False,
265
+ bar_format='{l_bar}{r_bar}'
266
+ )
267
+
268
+ # Define a hook function to print and store the gradient of a layer
269
+ def print_and_store_grad(grad, grad_list):
270
+ grad_list.append(grad)
271
+ # print(grad)
272
+
273
+ # grad_list = []
274
+ # self.net_.modules_emb_src['img_MRI_T1'].downsample[0].weight.register_hook(lambda grad: print_and_store_grad(grad, grad_list))
275
+
276
+ # lambda_coeff = 0.0001
277
+ # margin_loss = torch.nn.MarginRankingLoss(reduction='sum', margin=0.05)
278
+
279
+ # training loop
280
+ for epoch in range(self.start_epoch, self.num_epochs):
281
+ met_trn = self.train_one_epoch(ldr_trn, epoch)
282
+ met_vld = self.validate_one_epoch(ldr_vld, epoch)
283
+
284
+ print(self.ckpt_path.split('/')[-1])
285
+
286
+ # save the model if it has the best validation performance criterion by far
287
+ if self.criterion is None: continue
288
+
289
+
290
+ # is current criterion better than previous best?
291
+ curr_crit = np.mean([met_vld[i][self.criterion] for i in range(len(self.tgt_modalities))])
292
+ curr_crit_AUPR = np.mean([met_vld[i]["AUC (PR)"] for i in range(len(self.tgt_modalities))])
293
+ # AUROC
294
+ if best_crit is None or np.isnan(best_crit):
295
+ is_better = True
296
+ elif self.criterion == 'Loss' and best_crit >= curr_crit:
297
+ is_better = True
298
+ elif self.criterion != 'Loss' and best_crit <= curr_crit :
299
+ is_better = True
300
+ else:
301
+ is_better = False
302
+
303
+ # AUPR
304
+ if best_crit_AUPR is None or np.isnan(best_crit_AUPR):
305
+ is_better_AUPR = True
306
+ elif best_crit_AUPR <= curr_crit_AUPR :
307
+ is_better_AUPR = True
308
+ else:
309
+ is_better_AUPR = False
310
+
311
+ # update best criterion
312
+ if is_better_AUPR:
313
+ best_crit_AUPR = curr_crit_AUPR
314
+ if self.save_intermediate_ckpts:
315
+ print(f"Saving the model to {self.ckpt_path[:-3]}_AUPR.pt...")
316
+ self.save(self.ckpt_path[:-3]+"_AUPR.pt", epoch)
317
+
318
+ if is_better:
319
+ best_crit = curr_crit
320
+ best_state_dict = deepcopy(self.net_.state_dict())
321
+ if self.save_intermediate_ckpts:
322
+ print(f"Saving the model to {self.ckpt_path}...")
323
+ self.save(self.ckpt_path, epoch)
324
+
325
+ if self.verbose > 2:
326
+ print('Best {}: {}'.format(self.criterion, best_crit))
327
+ print('Best {}: {}'.format('AUC (PR)', best_crit_AUPR))
328
+
329
+ if self.verbose == 1:
330
+ with self._lock if self._lock is not None else suppress():
331
+ pbr_epoch.update(1)
332
+ pbr_epoch.refresh()
333
+
334
+ return self
335
+
336
+ def train_one_epoch(self, ldr_trn, epoch):
337
+
338
+ # progress bar for batch loops
339
+ if self.verbose > 1:
340
+ pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch))
341
+
342
+ torch.set_grad_enabled(True)
343
+ self.net_.train()
344
+
345
+ scores_trn, y_true_trn, y_mask_trn = [], [], []
346
+ losses_trn = [[] for _ in self.tgt_modalities]
347
+ iters = len(ldr_trn)
348
+ print(iters)
349
+ for n_iter, batch_data in enumerate(ldr_trn):
350
+ # if len(batch_data["image"]) < self.batch_size:
351
+ # continue
352
+
353
+ x_batch = batch_data["image"].to(self.device, non_blocking=True)
354
+ y_batch = {k: v.to(self.device, non_blocking=True) for k,v in batch_data["label"].items()}
355
+ y_mask = {k: v.to(self.device, non_blocking=True) for k,v in batch_data["mask"].items()}
356
+
357
+ with torch.autocast(
358
+ device_type = 'cpu' if self.device == 'cpu' else 'cuda',
359
+ dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
360
+ enabled = self._amp_enabled,
361
+ ):
362
+
363
+ outputs = self.net_(x_batch, shap=False)
364
+ # print(outputs.shape)
365
+ # calculate multitask loss
366
+ loss = 0
367
+ for i, k in enumerate(self.tgt_modalities):
368
+ loss_task = self.loss_fn[k](outputs[k], y_batch[k])
369
+ msk_loss_task = loss_task * y_mask[k]
370
+ msk_loss_mean = msk_loss_task.sum() / y_mask[k].sum()
371
+ loss += msk_loss_mean
372
+ losses_trn[i] += msk_loss_task.detach().cpu().numpy().tolist()
373
+
374
+ # backward
375
+ if self._amp_enabled:
376
+ self.scaler.scale(loss).backward()
377
+ else:
378
+ loss.backward()
379
+
380
+ # print(len(grad_list), len(grad_list[-1]))
381
+ # print(f"Gradient at {n_iter}: {grad_list[-1][0]}")
382
+
383
+ # update parameters
384
+ if n_iter != 0 and n_iter % self.batch_size_multiplier == 0:
385
+ if self._amp_enabled:
386
+ self.scaler.step(self.optimizer)
387
+ self.scaler.update()
388
+ self.optimizer.zero_grad()
389
+ else:
390
+ self.optimizer.step()
391
+ self.optimizer.zero_grad()
392
+ # set self.scheduler
393
+ self.scheduler.step(epoch + n_iter / iters)
394
+ # print(f"Weight: {self.net_.module.features[0].weight[0]}")
395
+
396
+ ''' TODO: change array to dictionary later '''
397
+ outputs = torch.stack(list(outputs.values()), dim=1)
398
+ y_batch = torch.stack(list(y_batch.values()), dim=1)
399
+ y_mask = torch.stack(list(y_mask.values()), dim=1)
400
+
401
+ # save outputs to evaluate performance later
402
+ scores_trn.append(outputs.detach().to(torch.float).cpu())
403
+ y_true_trn.append(y_batch.cpu())
404
+ y_mask_trn.append(y_mask.cpu())
405
+
406
+ # log metrics to wandb
407
+
408
+ # update progress bar
409
+ if self.verbose > 1:
410
+ batch_size = len(x_batch)
411
+ pbr_batch.update(batch_size, {})
412
+ pbr_batch.refresh()
413
+
414
+ # clear cuda cache
415
+ if "cuda" in self.device:
416
+ torch.cuda.empty_cache()
417
+
418
+ # for better tqdm progress bar display
419
+ if self.verbose > 1:
420
+ pbr_batch.close()
421
+
422
+ # # set self.scheduler
423
+ # self.scheduler.step()
424
+
425
+ # calculate and print training performance metrics
426
+ scores_trn = torch.cat(scores_trn)
427
+ y_true_trn = torch.cat(y_true_trn)
428
+ y_mask_trn = torch.cat(y_mask_trn)
429
+ y_pred_trn = (scores_trn > 0).to(torch.int)
430
+ y_prob_trn = torch.sigmoid(scores_trn)
431
+ met_trn = get_metrics_multitask(
432
+ y_true_trn.numpy(),
433
+ y_pred_trn.numpy(),
434
+ y_prob_trn.numpy(),
435
+ y_mask_trn.numpy()
436
+ )
437
+
438
+ # add loss to metrics
439
+ for i in range(len(self.tgt_modalities)):
440
+ met_trn[i]['Loss'] = np.mean(losses_trn[i])
441
+
442
+ wandb.log({f"Train loss {list(self.tgt_modalities)[i]}": met_trn[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch)
443
+ wandb.log({f"Train Balanced Accuracy {list(self.tgt_modalities)[i]}": met_trn[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch)
444
+
445
+ wandb.log({f"Train AUC (ROC) {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch)
446
+ wandb.log({f"Train AUPR {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch)
447
+
448
+ if self.verbose > 2:
449
+ print_metrics_multitask(met_trn)
450
+
451
+ return met_trn
452
+
453
+ # @torch.no_grad()
454
+ def validate_one_epoch(self, ldr_vld, epoch):
455
+ # progress bar for validation
456
+ if self.verbose > 1:
457
+ pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch))
458
+
459
+ # set model to validation mode
460
+ torch.set_grad_enabled(False)
461
+ self.net_.eval()
462
+
463
+ scores_vld, y_true_vld, y_mask_vld = [], [], []
464
+ losses_vld = [[] for _ in self.tgt_modalities]
465
+ for batch_data in ldr_vld:
466
+ # if len(batch_data["image"]) < self.batch_size:
467
+ # continue
468
+ x_batch = batch_data["image"].to(self.device, non_blocking=True)
469
+ y_batch = {k: v.to(self.device, non_blocking=True) for k,v in batch_data["label"].items()}
470
+ y_mask = {k: v.to(self.device, non_blocking=True) for k,v in batch_data["mask"].items()}
471
+
472
+ # forward
473
+ with torch.autocast(
474
+ device_type = 'cpu' if self.device == 'cpu' else 'cuda',
475
+ dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
476
+ enabled = self._amp_enabled
477
+ ):
478
+
479
+ outputs = self.net_(x_batch, shap=False)
480
+
481
+ # calculate multitask loss
482
+ for i, k in enumerate(self.tgt_modalities):
483
+ loss_task = self.loss_fn[k](outputs[k], y_batch[k])
484
+ msk_loss_task = loss_task * y_mask[k]
485
+ losses_vld[i] += msk_loss_task.detach().cpu().numpy().tolist()
486
+
487
+ ''' TODO: change array to dictionary later '''
488
+ outputs = torch.stack(list(outputs.values()), dim=1)
489
+ y_batch = torch.stack(list(y_batch.values()), dim=1)
490
+ y_mask = torch.stack(list(y_mask.values()), dim=1)
491
+
492
+ # save outputs to evaluate performance later
493
+ scores_vld.append(outputs.detach().to(torch.float).cpu())
494
+ y_true_vld.append(y_batch.cpu())
495
+ y_mask_vld.append(y_mask.cpu())
496
+
497
+ # update progress bar
498
+ if self.verbose > 1:
499
+ batch_size = len(x_batch)
500
+ pbr_batch.update(batch_size, {})
501
+ pbr_batch.refresh()
502
+
503
+ # clear cuda cache
504
+ if "cuda" in self.device:
505
+ torch.cuda.empty_cache()
506
+
507
+ # for better tqdm progress bar display
508
+ if self.verbose > 1:
509
+ pbr_batch.close()
510
+
511
+ # calculate and print validation performance metrics
512
+ scores_vld = torch.cat(scores_vld)
513
+ y_true_vld = torch.cat(y_true_vld)
514
+ y_mask_vld = torch.cat(y_mask_vld)
515
+ y_pred_vld = (scores_vld > 0).to(torch.int)
516
+ y_prob_vld = torch.sigmoid(scores_vld)
517
+ met_vld = get_metrics_multitask(
518
+ y_true_vld.numpy(),
519
+ y_pred_vld.numpy(),
520
+ y_prob_vld.numpy(),
521
+ y_mask_vld.numpy()
522
+ )
523
+
524
+ # add loss to metrics
525
+ for i in range(len(self.tgt_modalities)):
526
+ met_vld[i]['Loss'] = np.mean(losses_vld[i])
527
+
528
+ wandb.log({f"Validation loss {list(self.tgt_modalities)[i]}": met_vld[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch)
529
+ wandb.log({f"Validation Balanced Accuracy {list(self.tgt_modalities)[i]}": met_vld[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch)
530
+
531
+ wandb.log({f"Validation AUC (ROC) {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch)
532
+ wandb.log({f"Validation AUPR {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch)
533
+
534
+ if self.verbose > 2:
535
+ print_metrics_multitask(met_vld)
536
+
537
+ return met_vld
538
+
539
+
540
+ def save(self, filepath: str, epoch: int = 0) -> None:
541
+ ''' ... '''
542
+ check_is_fitted(self)
543
+ if self.data_parallel:
544
+ state_dict = self.net_.module.state_dict()
545
+ else:
546
+ state_dict = self.net_.state_dict()
547
+
548
+ # attach model hyper parameters
549
+ state_dict['tgt_modalities'] = self.tgt_modalities
550
+ state_dict['optimizer'] = self.optimizer
551
+ state_dict['bn_size'] = self.bn_size
552
+ state_dict['growth_rate'] = self.growth_rate
553
+ state_dict['block_config'] = self.block_config
554
+ state_dict['compression'] = self.compression
555
+ state_dict['num_init_features'] = self.num_init_features
556
+ state_dict['drop_rate'] = self.drop_rate
557
+ state_dict['epoch'] = epoch
558
+
559
+ if self.scaler is not None:
560
+ state_dict['scaler'] = self.scaler.state_dict()
561
+ if self.label_distribution:
562
+ state_dict['label_distribution'] = self.label_distribution
563
+
564
+ torch.save(state_dict, filepath)
565
+
566
+ def load(self, filepath: str, map_location: str = 'cpu', how='latest') -> None:
567
+ ''' ... '''
568
+ # load state_dict
569
+ if how == 'latest':
570
+ if torch.load(filepath)['epoch'] > torch.load(f'{filepath[:-3]}_AUPR.pt')['epoch']:
571
+ print("Loading model saved using AUROC")
572
+ state_dict = torch.load(filepath, map_location=map_location)
573
+ else:
574
+ print("Loading model saved using AUPR")
575
+ state_dict = torch.load(f'{filepath[:-3]}_AUPR.pt', map_location=map_location)
576
+ else:
577
+ state_dict = torch.load(filepath, map_location=map_location)
578
+
579
+ # load data modalities
580
+ self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities')
581
+ if 'label_distribution' in state_dict:
582
+ self.label_distribution: dict[str, dict[int, int]] = state_dict.pop('label_distribution')
583
+ if 'optimizer' in state_dict:
584
+ self.optimizer = state_dict.pop('optimizer')
585
+ if 'bn_size' in state_dict:
586
+ self.bn_size = state_dict.pop('bn_size')
587
+ if 'growth_rate' in state_dict:
588
+ self.growth_rate = state_dict.pop('growth_rate')
589
+ if 'block_config' in state_dict:
590
+ self.block_config = state_dict.pop('block_config')
591
+ if 'compression' in state_dict:
592
+ self.compression = state_dict.pop('compression')
593
+ if 'num_init_features' in state_dict:
594
+ self.num_init_features = state_dict.pop('num_init_features')
595
+ if 'drop_rate' in state_dict:
596
+ self.drop_rate = state_dict.pop('drop_rate')
597
+ if 'epoch' in state_dict:
598
+ self.start_epoch = state_dict.pop('epoch')
599
+ print(f'Epoch: {self.start_epoch}')
600
+
601
+ # initialize model
602
+
603
+ self.net_ = get_backend(self.img_backend)(
604
+ tgt_modalities = self.tgt_modalities,
605
+ bn_size = self.bn_size,
606
+ growth_rate=self.growth_rate,
607
+ block_config=self.block_config,
608
+ compression=self.compression,
609
+ num_init_features=self.num_init_features,
610
+ drop_rate=self.drop_rate,
611
+ load_from_ckpt=self.load_from_ckpt
612
+ )
613
+ print(self.net_)
614
+
615
+ if 'scaler' in state_dict and state_dict['scaler']:
616
+ self.scaler.load_state_dict(state_dict.pop('scaler'))
617
+ self.net_.load_state_dict(state_dict)
618
+ check_is_fitted(self)
619
+ self.net_.to(self.device)
620
+
621
+ def to(self, device: str) -> Self:
622
+ ''' Mount model to the given device. '''
623
+ self.device = device
624
+ if hasattr(self, 'net_'): self.net_ = self.net_.to(device)
625
+ return self
626
+
627
+ @classmethod
628
+ def from_ckpt(cls, filepath: str, device='cpu', img_backend=None, load_from_ckpt=True, how='latest') -> Self:
629
+ ''' ... '''
630
+ obj = cls(None, None, None, device=device)
631
+ if device == 'cuda':
632
+ obj.device = "{}:{}".format(obj.device, str(obj.cuda_devices[0]))
633
+ print(obj.device)
634
+ obj.img_backend=img_backend
635
+ obj.load_from_ckpt = load_from_ckpt
636
+ obj.load(filepath, map_location=obj.device, how=how)
637
+ return obj
638
+
639
+ def _init_net(self):
640
+ """ ... """
641
+ self.start_epoch = 0
642
+ # set the device for use
643
+ if self.device == 'cuda':
644
+ self.device = "{}:{}".format(self.device, str(self.cuda_devices[0]))
645
+ # self.load(self.ckpt_path, map_location=self.device)
646
+ # print("Loading model from checkpoint...")
647
+ # self.load(self.ckpt_path, map_location=self.device)
648
+
649
+ if self.load_from_ckpt:
650
+ try:
651
+ print("Loading model from checkpoint...")
652
+ self.load(self.ckpt_path, map_location=self.device)
653
+ except:
654
+ print("Cannot load from checkpoint. Initializing new model...")
655
+ self.load_from_ckpt = False
656
+
657
+ if not self.load_from_ckpt:
658
+ self.net_ = get_backend(self.img_backend)(
659
+ tgt_modalities = self.tgt_modalities,
660
+ bn_size = self.bn_size,
661
+ growth_rate=self.growth_rate,
662
+ block_config=self.block_config,
663
+ compression=self.compression,
664
+ num_init_features=self.num_init_features,
665
+ drop_rate=self.drop_rate,
666
+ load_from_ckpt=self.load_from_ckpt
667
+ )
668
+
669
+ # # initialize model parameters using xavier_uniform
670
+ # for p in self.net_.parameters():
671
+ # if p.dim() > 1:
672
+ # torch.nn.init.xavier_uniform_(p)
673
+
674
+ self.net_.to(self.device)
675
+
676
+ # Initialize the number of GPUs
677
+ if self.data_parallel and torch.cuda.device_count() > 1:
678
+ print("Available", torch.cuda.device_count(), "GPUs!")
679
+ self.net_ = torch.nn.DataParallel(self.net_, device_ids=self.cuda_devices)
680
+
681
+ # return net
682
+
683
+ def _init_dataloader(self, trn_list, vld_list, img_train_trans=None, img_vld_trans=None):
684
+ # def _init_dataloader(self, x, y):
685
+ """ ... """
686
+ # # split dataset
687
+ # x_trn, x_vld, y_trn, y_vld = train_test_split(
688
+ # x, y, test_size = 0.2, random_state = 0,
689
+ # )
690
+
691
+ # # initialize dataset and dataloader
692
+ # dat_trn = CNNTrainingValidationDataset(
693
+ # x_trn, y_trn,
694
+ # self.tgt_modalities,
695
+ # img_transform=img_train_trans,
696
+ # )
697
+
698
+ # dat_vld = CNNTrainingValidationDataset(
699
+ # x_vld, y_vld,
700
+ # self.tgt_modalities,
701
+ # img_transform=img_vld_trans,
702
+ # )
703
+
704
+ dat_trn = monai.data.Dataset(data=trn_list, transform=img_train_trans)
705
+ dat_vld = monai.data.Dataset(data=vld_list, transform=img_vld_trans)
706
+ collate_fn_trn = functools.partial(collate_handle_corrupted, dataset=dat_trn, dtype=torch.FloatTensor, labels=self.tgt_modalities)
707
+ collate_fn_vld = functools.partial(collate_handle_corrupted, dataset=dat_vld, dtype=torch.FloatTensor, labels=self.tgt_modalities)
708
+
709
+ ldr_trn = DataLoader(
710
+ dataset = dat_trn,
711
+ batch_size = self.batch_size,
712
+ shuffle = True,
713
+ drop_last = False,
714
+ num_workers = self._dataloader_num_workers,
715
+ collate_fn = collate_fn_trn,
716
+ # pin_memory = True
717
+ )
718
+
719
+ ldr_vld = DataLoader(
720
+ dataset = dat_vld,
721
+ batch_size = self.batch_size,
722
+ shuffle = False,
723
+ drop_last = False,
724
+ num_workers = self._dataloader_num_workers,
725
+ collate_fn = collate_fn_vld,
726
+ # pin_memory = True
727
+ )
728
+
729
+ return ldr_trn, ldr_vld
730
+
731
+ def _init_optimizer(self):
732
+ """ ... """
733
+ params = list(self.net_.parameters())
734
+ # for p in params:
735
+ # print(p.requires_grad)
736
+ return torch.optim.AdamW(
737
+ params,
738
+ lr = self.lr,
739
+ betas = (0.9, 0.98),
740
+ weight_decay = self.weight_decay
741
+ )
742
+
743
+ def _init_scheduler(self, optimizer):
744
+ """ ... """
745
+ # return torch.optim.lr_scheduler.OneCycleLR(
746
+ # optimizer = optimizer,
747
+ # max_lr = self.lr,
748
+ # total_steps = self.num_epochs,
749
+ # verbose = (self.verbose > 2)
750
+ # )
751
+
752
+ # return torch.optim.lr_scheduler.CosineAnnealingLR(
753
+ # optimizer=optimizer,
754
+ # T_max=64,
755
+ # verbose=(self.verbose > 2)
756
+ # )
757
+
758
+ return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
759
+ optimizer=optimizer,
760
+ T_0=64,
761
+ T_mult=2,
762
+ eta_min = 0,
763
+ verbose=(self.verbose > 2)
764
+ )
765
+
766
+ def _init_loss_func(self,
767
+ num_per_cls: dict[str, tuple[int, int]],
768
+ ) -> dict[str, Module]:
769
+ """ ... """
770
+ return {k: nn.SigmoidFocalLossBeta(
771
+ beta = self.beta,
772
+ gamma = self.gamma,
773
+ num_per_cls = num_per_cls[k],
774
+ reduction = 'none',
775
+ ) for k in self.tgt_modalities}
776
+
777
+ def _proc_fit(self):
778
+ """ ... """
779
+
780
+ def _init_test_dataloader(self, batch_size, tst_list, img_tst_trans=None):
781
+ # input validation
782
+ check_is_fitted(self)
783
+ print(self.device)
784
+
785
+ # for PyTorch computational efficiency
786
+ torch.set_num_threads(1)
787
+
788
+ # set model to eval mode
789
+ torch.set_grad_enabled(False)
790
+ self.net_.eval()
791
+
792
+ dat_tst = monai.data.Dataset(data=tst_list, transform=img_tst_trans)
793
+ collate_fn_tst = functools.partial(collate_handle_corrupted, dataset=dat_tst, dtype=torch.FloatTensor, labels=self.tgt_modalities)
794
+ # print(collate_fn_tst)
795
+
796
+ ldr_tst = DataLoader(
797
+ dataset = dat_tst,
798
+ batch_size = batch_size,
799
+ shuffle = False,
800
+ drop_last = False,
801
+ num_workers = self._dataloader_num_workers,
802
+ collate_fn = collate_fn_tst,
803
+ # pin_memory = True
804
+ )
805
+ return ldr_tst
806
+
807
+
808
+ def predict_logits(self,
809
+ ldr_tst: Any | None = None,
810
+ ) -> list[dict[str, float]]:
811
+
812
+ # run model and collect results
813
+ logits: list[dict[str, float]] = []
814
+ for batch_data in tqdm(ldr_tst):
815
+ # print(batch_data["image"])
816
+ if len(batch_data) == 0:
817
+ continue
818
+ x_batch = batch_data["image"].to(self.device, non_blocking=True)
819
+ outputs = self.net_(x_batch, shap=False)
820
+
821
+ # convert output from dict-of-list to list of dict, then append
822
+ tmp = {k: outputs[k].tolist() for k in self.tgt_modalities}
823
+ tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))]
824
+ logits += tmp
825
+
826
+ return logits
827
+
828
+ def predict_proba(self,
829
+ ldr_tst: Any | None = None,
830
+ temperature: float = 1.0,
831
+ ) -> list[dict[str, float]]:
832
+ ''' ... '''
833
+ logits = self.predict_logits(ldr_tst)
834
+ print("got logits")
835
+ return logits, [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]
836
+
837
+ def predict(self,
838
+ ldr_tst: Any | None = None,
839
+ ) -> list[dict[str, int]]:
840
+ ''' ... '''
841
+ logits, proba = self.predict_proba(ldr_tst)
842
+ print("got proba")
843
+ return logits, proba, [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba]
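Note: a minimal sketch of how ImagingModel is wired to MONAI-style sample lists and transforms. The label names, label fractions, and the trn_list / vld_list / tst_list and transform pipelines (trn_trans, vld_trans, tst_trans) are assumptions for illustration.

    from adrd.model import ImagingModel                     # assumed export path

    tgt_modalities = ['NC', 'MCI', 'DE']                    # hypothetical target labels
    label_fractions = {k: 0.33 for k in tgt_modalities}     # hypothetical positive-label fractions
    mdl = ImagingModel(
        tgt_modalities, label_fractions,
        img_backend = 'DenseNet',                           # or 'C3D'
        device = 'cpu',
        load_from_ckpt = False,                             # train from scratch
        wandb_ = 0,                                         # disables wandb logging
    )
    # each element of trn_list / vld_list is a MONAI-style dict; after the transform it must
    # provide {"image": tensor, "label": {target: 0/1/None}} so collate_handle_corrupted can
    # build the label and mask tensors
    mdl.fit(trn_list, vld_list, img_train_trans=trn_trans, img_vld_trans=vld_trans)
    ldr_tst = mdl._init_test_dataloader(8, tst_list, img_tst_trans=tst_trans)
    logits, proba, preds = mdl.predict(ldr_tst)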
adrd/model/train_resnet.py ADDED
@@ -0,0 +1,484 @@
1
+ import torch
2
+ import numpy as np
3
+ import tqdm
4
+ from sklearn.base import BaseEstimator
5
+ from sklearn.utils.validation import check_is_fitted
6
+ from sklearn.model_selection import train_test_split
7
+ from scipy.special import expit
8
+ from copy import deepcopy
9
+ from contextlib import suppress
10
+ from typing import Any, Self
11
+ from icecream import ic
12
+
13
+ from .. import nn
14
+ from ..utils import TransformerTrainingDataset
15
+ from ..utils import TransformerValidationDataset
16
+ from ..utils import MissingMasker
17
+ from ..utils import ConstantImputer
18
+ from ..utils import Formatter
19
+ from ..utils.misc import ProgressBar
20
+ from ..utils.misc import get_metrics_multitask, print_metrics_multitask
21
+
22
+
23
+ class TrainResNet(BaseEstimator):
24
+ ''' ... '''
25
+ def __init__(self,
26
+ src_modalities: dict[str, dict[str, Any]],
27
+ tgt_modalities: dict[str, dict[str, Any]],
28
+ label_fractions: dict[str, float],
29
+ num_epochs: int = 32,
30
+ batch_size: int = 8,
31
+ lr: float = 1e-2,
32
+ weight_decay: float = 0.0,
33
+ gamma: float = 0.0,
34
+ criterion: str | None = None,
35
+ device: str = 'cpu',
36
+ cuda_devices: list = [1,2],
37
+ mri_feature: str = 'img_MRI_T1',
38
+ ckpt_path: str = '/home/skowshik/ADRD_repo/adrd_tool/adrd/dev/ckpt/ckpt.pt',
39
+ load_from_ckpt: bool = True,
40
+ save_intermediate_ckpts: bool = False,
41
+ data_parallel: bool = False,
42
+ verbose: int = 0,
43
+ ):
44
+ ''' ... '''
45
+ # for multiprocessing
46
+ self._rank = 0
47
+ self._lock = None
48
+
49
+ # positional parameters
50
+ self.src_modalities = src_modalities
51
+ self.tgt_modalities = tgt_modalities
52
+
53
+ # training parameters
54
+ self.label_fractions = label_fractions
55
+ self.num_epochs = num_epochs
56
+ self.batch_size = batch_size
57
+ self.lr = lr
58
+ self.weight_decay = weight_decay
59
+ self.gamma = gamma
60
+ self.criterion = criterion
61
+ self.device = device
62
+ self.cuda_devices = cuda_devices
63
+ self.mri_feature = mri_feature
64
+ self.ckpt_path = ckpt_path
65
+ self.load_from_ckpt = load_from_ckpt
66
+ self.save_intermediate_ckpts = save_intermediate_ckpts
67
+ self.data_parallel = data_parallel
68
+ self.verbose = verbose
69
+
70
+ def fit(self, x, y):
71
+ ''' ... '''
72
+ # for PyTorch computational efficiency
73
+ torch.set_num_threads(1)
74
+
75
+ # set the device for use
76
+ if self.device == 'cuda':
77
+ self.device = "{}:{}".format(self.device, str(self.cuda_devices[0]))
78
+
79
+ # initialize model
80
+ if self.load_from_ckpt:
81
+ try:
82
+ print("Loading model from checkpoint...")
83
+ self.load(self.ckpt_path, map_location=self.device)
84
+ except Exception:
85
+ print("Cannot load from checkpoint. Initializing new model...")
86
+ self.load_from_ckpt = False
87
+
88
+ # initialize model
89
+ if not self.load_from_ckpt:
90
+ self.net_ = nn.ResNetModel(
91
+ self.tgt_modalities,
92
+ mri_feature = self.mri_feature
93
+ )
94
+ # initialize model parameters using xavier_uniform
95
+ for p in self.net_.parameters():
96
+ if p.dim() > 1:
97
+ torch.nn.init.xavier_uniform_(p)
98
+
99
+ self.net_.to(self.device)
100
+
101
+ # Initialize the number of GPUs
102
+ if self.data_parallel and torch.cuda.device_count() > 1:
103
+ print("Available", torch.cuda.device_count(), "GPUs!")
104
+ self.net_ = torch.nn.DataParallel(self.net_, device_ids=self.cuda_devices)
105
+
106
+
107
+ # split dataset
108
+ x_trn, x_vld, y_trn, y_vld = train_test_split(
109
+ x, y, test_size = 0.2, random_state = 0,
110
+ )
111
+
112
+ # initialize dataset and dataloader
113
+ dat_trn = TransformerTrainingDataset(
114
+ x_trn, y_trn,
115
+ self.src_modalities,
116
+ self.tgt_modalities,
117
+ dropout_rate = .5,
118
+ dropout_strategy = 'compensated',
119
+ mri_feature = self.mri_feature,
120
+ )
121
+
122
+ dat_vld = TransformerValidationDataset(
123
+ x_vld, y_vld,
124
+ self.src_modalities,
125
+ self.tgt_modalities,
126
+ mri_feature = self.mri_feature,
127
+ )
128
+
129
+ # ic(dat_trn[0])
130
+
131
+ ldr_trn = torch.utils.data.DataLoader(
132
+ dat_trn,
133
+ batch_size = self.batch_size,
134
+ shuffle = True,
135
+ drop_last = False,
136
+ num_workers = 0,
137
+ collate_fn = TransformerTrainingDataset.collate_fn,
138
+ # pin_memory = True
139
+ )
140
+
141
+ ldr_vld = torch.utils.data.DataLoader(
142
+ dat_vld,
143
+ batch_size = self.batch_size,
144
+ shuffle = False,
145
+ drop_last = False,
146
+ num_workers = 0,
147
+ collate_fn = TransformerTrainingDataset.collate_fn,
148
+ # pin_memory = True
149
+ )
150
+
151
+ # initialize optimizer
152
+ optimizer = torch.optim.AdamW(
153
+ self.net_.parameters(),
154
+ lr = self.lr,
155
+ betas = (0.9, 0.98),
156
+ weight_decay = self.weight_decay
157
+ )
158
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=64, verbose=(self.verbose > 2))
159
+
160
+ # initialize loss function (binary cross entropy)
161
+ loss_fn = {}
162
+
163
+ for k in self.tgt_modalities:
164
+ alpha = pow((1 - self.label_fractions[k]), self.gamma)
165
+ # if alpha < 0.5:
166
+ # alpha = -1
167
+ loss_fn[k] = nn.SigmoidFocalLoss(
168
+ alpha = alpha,
169
+ gamma = self.gamma,
170
+ reduction = 'none'
171
+ )
172
+
173
+ # to record the best validation performance criterion
174
+ if self.criterion is not None:
175
+ best_crit = None
176
+
177
+ # progress bar for epoch loops
178
+ if self.verbose == 1:
179
+ with self._lock if self._lock is not None else suppress():
180
+ pbr_epoch = tqdm.tqdm(
181
+ desc = 'Rank {:02d}'.format(self._rank),
182
+ total = self.num_epochs,
183
+ position = self._rank,
184
+ ascii = True,
185
+ leave = False,
186
+ bar_format='{l_bar}{r_bar}'
187
+ )
188
+
189
+ # Define a hook function to print and store the gradient of a layer
190
+ def print_and_store_grad(grad, grad_list):
191
+ grad_list.append(grad)
192
+ # print(grad)
193
+
194
+ # grad_list = []
195
+ # self.net_.module.img_net_.featurizer.down_tr64.ops[0].conv1.weight.register_hook(lambda grad: print_and_store_grad(grad, grad_list))
196
+ # self.net_.module.modules_emb_src['gender'].weight.register_hook(lambda grad: print_and_store_grad(grad, grad_list))
197
+
198
+
199
+ # training loop
200
+ for epoch in range(self.num_epochs):
201
+ # progress bar for batch loops
202
+ if self.verbose > 1:
203
+ pbr_batch = ProgressBar(len(dat_trn), 'Epoch {:03d} (TRN)'.format(epoch))
204
+
205
+ # set model to train mode
206
+ torch.set_grad_enabled(True)
207
+ self.net_.train()
208
+
209
+ scores_trn, y_true_trn = [], []
210
+ losses_trn = [[] for _ in self.tgt_modalities]
211
+ for x_batch, y_batch, mask in ldr_trn:
212
+
213
+ # mount data to the proper device
214
+ x_batch = {k: x_batch[k].to(self.device) for k in x_batch}
215
+ y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch}
216
+
217
+ # forward
218
+ outputs = self.net_(x_batch)
219
+
220
+ # calculate multitask loss
221
+ loss = 0
222
+ for i, k in enumerate(self.tgt_modalities):
223
+ loss_task = loss_fn[k](outputs[k], y_batch[k])
224
+ loss += loss_task.mean()
225
+ losses_trn[i] += loss_task.detach().cpu().numpy().tolist()
226
+
227
+ # backward
228
+ optimizer.zero_grad(set_to_none=True)
229
+ loss.backward()
230
+ optimizer.step()
231
+
232
+ ''' TODO: change array to dictionary later '''
233
+ outputs = torch.stack(list(outputs.values()), dim=1)
234
+ y_batch = torch.stack(list(y_batch.values()), dim=1)
235
+
236
+ # save outputs to evaluate performance later
237
+ scores_trn.append(outputs.detach().to(torch.float).cpu())
238
+ y_true_trn.append(y_batch.cpu())
239
+
240
+ # update progress bar
241
+ if self.verbose > 1:
242
+ batch_size = len(next(iter(x_batch.values())))
243
+ pbr_batch.update(batch_size, {})
244
+ pbr_batch.refresh()
245
+
246
+ # clear cuda cache
247
+ if "cuda" in self.device:
248
+ torch.cuda.empty_cache()
249
+
250
+ # for better tqdm progress bar display
251
+ if self.verbose > 1:
252
+ pbr_batch.close()
253
+
254
+ # set scheduler
255
+ scheduler.step()
256
+
257
+ # calculate and print training performance metrics
258
+ scores_trn = torch.cat(scores_trn)
259
+ y_true_trn = torch.cat(y_true_trn)
260
+ y_pred_trn = (scores_trn > 0).to(torch.int)
261
+ y_prob_trn = torch.sigmoid(scores_trn)
262
+ met_trn = get_metrics_multitask(
263
+ y_true_trn.numpy(),
264
+ y_pred_trn.numpy(),
265
+ y_prob_trn.numpy()
266
+ )
267
+
268
+ # add loss to metrics
269
+ for i in range(len(self.tgt_modalities)):
270
+ met_trn[i]['Loss'] = np.mean(losses_trn[i])
271
+
272
+ if self.verbose > 2:
273
+ print_metrics_multitask(met_trn)
274
+
275
+ # progress bar for validation
276
+ if self.verbose > 1:
277
+ pbr_batch = ProgressBar(len(dat_vld), 'Epoch {:03d} (VLD)'.format(epoch))
278
+
279
+ # set model to validation mode
280
+ torch.set_grad_enabled(False)
281
+ self.net_.eval()
282
+
283
+ scores_vld, y_true_vld = [], []
284
+ losses_vld = [[] for _ in self.tgt_modalities]
285
+ for x_batch, y_batch, mask in ldr_vld:
286
+ # mount data to the proper device
287
+ x_batch = {k: x_batch[k].to(self.device) for k in x_batch}
288
+ y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch}
289
+
290
+ # forward
291
+ outputs = self.net_(x_batch)
292
+
293
+ # calculate multitask loss
294
+ for i, k in enumerate(self.tgt_modalities):
295
+ loss_task = loss_fn[k](outputs[k], y_batch[k])
296
+ losses_vld[i] += loss_task.detach().cpu().numpy().tolist()
297
+
298
+ ''' TODO: change array to dictionary later '''
299
+ outputs = torch.stack(list(outputs.values()), dim=1)
300
+ y_batch = torch.stack(list(y_batch.values()), dim=1)
301
+
302
+ # save outputs to evaluate performance later
303
+ scores_vld.append(outputs.detach().to(torch.float).cpu())
304
+ y_true_vld.append(y_batch.cpu())
305
+
306
+ # update progress bar
307
+ if self.verbose > 1:
308
+ batch_size = len(next(iter(x_batch.values())))
309
+ pbr_batch.update(batch_size, {})
310
+ pbr_batch.refresh()
311
+
312
+ # clear cuda cache
313
+ if "cuda" in self.device:
314
+ torch.cuda.empty_cache()
315
+
316
+ # for better tqdm progress bar display
317
+ if self.verbose > 1:
318
+ pbr_batch.close()
319
+
320
+ # calculate and print validation performance metrics
321
+ scores_vld = torch.cat(scores_vld)
322
+ y_true_vld = torch.cat(y_true_vld)
323
+ y_pred_vld = (scores_vld > 0).to(torch.int)
324
+ y_prob_vld = torch.sigmoid(scores_vld)
325
+ met_vld = get_metrics_multitask(
326
+ y_true_vld.numpy(),
327
+ y_pred_vld.numpy(),
328
+ y_prob_vld.numpy()
329
+ )
330
+
331
+ # add loss to metrics
332
+ for i in range(len(self.tgt_modalities)):
333
+ met_vld[i]['Loss'] = np.mean(losses_vld[i])
334
+
335
+ if self.verbose > 2:
336
+ print_metrics_multitask(met_vld)
337
+
338
+ # save the model if it has the best validation performance criterion by far
339
+ if self.criterion is None: continue
340
+
341
+ # is current criterion better than previous best?
342
+ curr_crit = np.mean([met_vld[i][self.criterion] for i in range(len(self.tgt_modalities))])
343
+ if best_crit is None or np.isnan(best_crit):
344
+ is_better = True
345
+ elif self.criterion == 'Loss' and best_crit >= curr_crit:
346
+ is_better = True
347
+ elif self.criterion != 'Loss' and best_crit <= curr_crit:
348
+ is_better = True
349
+ else:
350
+ is_better = False
351
+
352
+ # update best criterion
353
+ if is_better:
354
+ best_crit = curr_crit
355
+ best_state_dict = deepcopy(self.net_.state_dict())
356
+ if self.save_intermediate_ckpts:
357
+ print("Saving the model...")
358
+ self.save(self.ckpt_path)
359
+
360
+ if self.verbose > 2:
361
+ print('Best {}: {}'.format(self.criterion, best_crit))
362
+
363
+ if self.verbose == 1:
364
+ with self._lock if self._lock is not None else suppress():
365
+ pbr_epoch.update(1)
366
+ pbr_epoch.refresh()
367
+
368
+ if self.verbose == 1:
369
+ with self._lock if self._lock is not None else suppress():
370
+ pbr_epoch.close()
371
+
372
+ # restore the model of the best validation performance across all epoches
373
+ if ldr_vld is not None and self.criterion is not None:
374
+ self.net_.load_state_dict(best_state_dict)
375
+
376
+ return self
377
+
378
+ def predict_logits(self,
379
+ x: list[dict[str, Any]],
380
+ ) -> list[dict[str, float]]:
381
+ '''
382
+ The input x can be a single sample or a list of samples.
383
+ '''
384
+ # input validation
385
+ check_is_fitted(self)
386
+
387
+ # for PyTorch computational efficiency
388
+ torch.set_num_threads(1)
389
+
390
+ # set model to eval mode
391
+ torch.set_grad_enabled(False)
392
+ self.net_.eval()
393
+
394
+ # number of samples to evaluate
395
+ n_samples = len(x)
396
+
397
+ # format x
398
+ fmt = Formatter(self.src_modalities)
399
+ x = [fmt(smp) for smp in x]
400
+
401
+ # generate missing mask (BEFORE IMPUTATION)
402
+ msk = MissingMasker(self.src_modalities)
403
+ mask = [msk(smp) for smp in x]
404
+
405
+ # reformat x and then impute by 0s
406
+ imp = ConstantImputer(self.src_modalities)
407
+ x = [imp(smp) for smp in x]
408
+
409
+ # convert list-of-dict to dict-of-list
410
+ x = {k: [smp[k] for smp in x] for k in self.src_modalities}
411
+ mask = {k: [smp[k] for smp in mask] for k in self.src_modalities}
412
+
413
+ # to tensor
414
+ x = {k: torch.as_tensor(np.array(v)).to(self.device) for k, v in x.items()}
415
+ mask = {k: torch.as_tensor(np.array(v)).to(self.device) for k, v in mask.items()}
416
+
417
+ # calculate logits
418
+ logits = self.net_(x)
419
+
420
+ # convert dict-of-list to list-of-dict
421
+ logits = {k: logits[k].tolist() for k in self.tgt_modalities}
422
+ logits = [{k: logits[k][i] for k in self.tgt_modalities} for i in range(n_samples)]
423
+
424
+ return logits
425
+
426
+ def predict_proba(self,
427
+ x: list[dict[str, Any]],
428
+ temperature: float = 1.0
429
+ ) -> list[dict[str, float]]:
430
+ ''' ... '''
431
+ # calculate logits
432
+ logits = self.predict_logits(x)
433
+
434
+ # convert logits to probabilities
435
+ proba = [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]
436
+ return proba
437
+
438
+ def predict(self,
439
+ x: list[dict[str, Any]],
440
+ ) -> list[dict[str, int]]:
441
+ ''' ... '''
442
+ proba = self.predict_proba(x)
443
+ return [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba]
444
+
445
+ def save(self, filepath: str) -> None:
446
+ ''' ... '''
447
+ check_is_fitted(self)
448
+ if self.data_parallel:
449
+ state_dict = self.net_.module.state_dict()
450
+ else:
451
+ state_dict = self.net_.state_dict()
452
+
453
+ # attach model hyper parameters
454
+ state_dict['src_modalities'] = self.src_modalities
455
+ state_dict['tgt_modalities'] = self.tgt_modalities
456
+ state_dict['mri_feature'] = self.mri_feature
457
+
458
+ torch.save(state_dict, filepath)
459
+
460
+ def load(self, filepath: str, map_location: str='cpu') -> None:
461
+ ''' ... '''
462
+ # load state_dict
463
+ state_dict = torch.load(filepath, map_location=map_location)
464
+
465
+ # load data modalities
466
+ self.src_modalities = state_dict.pop('src_modalities')
467
+ self.tgt_modalities = state_dict.pop('tgt_modalities')
468
+
469
+ # initialize model
470
+ self.net_ = nn.ResNetModel(
471
+ self.tgt_modalities,
472
+ mri_feature = state_dict.pop('mri_feature')
473
+ )
474
+
475
+ # load model parameters
476
+ self.net_.load_state_dict(state_dict)
477
+ self.net_.to(self.device)
478
+
479
+ @classmethod
480
+ def from_ckpt(cls, filepath: str, device='cpu') -> Self:
481
+ ''' ... '''
482
+ obj = cls(None, None, None, device=device)
483
+ obj.load(filepath)
484
+ return obj
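A hedged end-to-end sketch of driving TrainResNet; the modality dictionaries, label fractions, and sample lists below are illustrative assumptions, not values taken from this repository:

from adrd.model import TrainResNet   # assuming the class is re-exported by adrd.model

tgt_modalities = {'NC': {}, 'MCI': {}, 'DE': {}}        # only the keys are used for the output heads
label_fractions = {'NC': 0.4, 'MCI': 0.3, 'DE': 0.3}
mdl = TrainResNet(
    src_modalities=src_modalities,                      # dataset-specific feature metadata (assumed)
    tgt_modalities=tgt_modalities,
    label_fractions=label_fractions,
    num_epochs=2, batch_size=2, lr=1e-3, gamma=2.0,
    criterion='Loss', device='cpu',
    load_from_ckpt=False, ckpt_path='./ckpt_resnet.pt',
)
mdl.fit(x_trn, y_trn)       # x: list of per-sample feature dicts (incl. the 'img_MRI_T1' volume)
preds = mdl.predict(x_tst)  # y: list of per-sample {target: 0/1} dicts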
adrd/model/transformer.py ADDED
@@ -0,0 +1,600 @@
1
+ __all__ = ['Transformer']
2
+
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ import numpy as np
6
+ import tqdm
7
+ from sklearn.base import BaseEstimator
8
+ from sklearn.utils.validation import check_is_fitted
9
+ from sklearn.model_selection import train_test_split
10
+ from scipy.special import expit
11
+ from copy import deepcopy
12
+ from contextlib import suppress
13
+ from typing import Any, Self, Type
14
+ from functools import wraps
15
+ Tensor = Type[torch.Tensor]
16
+ Module = Type[torch.nn.Module]
17
+
18
+ from .. import nn
19
+ from ..utils import TransformerTrainingDataset
20
+ from ..utils import TransformerBalancedTrainingDataset
21
+ from ..utils import Transformer2ndOrderBalancedTrainingDataset
22
+ from ..utils import TransformerValidationDataset
23
+ from ..utils import TransformerTestingDataset
24
+ from ..utils.misc import ProgressBar
25
+ from ..utils.misc import get_metrics_multitask, print_metrics_multitask
26
+ from ..utils.misc import convert_args_kwargs_to_kwargs
27
+
28
+
29
+ def _manage_ctx_fit(func):
30
+ ''' ... '''
31
+ @wraps(func)
32
+ def wrapper(*args, **kwargs):
33
+ # format arguments
34
+ kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs)
35
+
36
+ if kwargs['self']._device_ids is None:
37
+ return func(**kwargs)
38
+ else:
39
+ # change primary device
40
+ default_device = kwargs['self'].device
41
+ kwargs['self'].device = kwargs['self']._device_ids[0]
42
+ rtn = func(**kwargs)
43
+ kwargs['self'].to(default_device)
44
+ return rtn
45
+ return wrapper
46
+
47
+
48
+ class Transformer(BaseEstimator):
49
+ ''' ... '''
50
+ def __init__(self,
51
+ src_modalities: dict[str, dict[str, Any]],
52
+ tgt_modalities: dict[str, dict[str, Any]],
53
+ d_model: int = 32,
54
+ nhead: int = 1,
55
+ num_layers: int = 1,
56
+ num_epochs: int = 32,
57
+ batch_size: int = 8,
58
+ batch_size_multiplier: int = 1,
59
+ lr: float = 1e-2,
60
+ weight_decay: float = 0.0,
61
+ beta: float = 0.9999,
62
+ gamma: float = 2.0,
63
+ scale: float = 1.0,
64
+ lambd: float = 0.0,
65
+ criterion: str | None = None,
66
+ device: str = 'cpu',
67
+ verbose: int = 0,
68
+ _device_ids: list | None = None,
69
+ _dataloader_num_workers: int = 0,
70
+ _amp_enabled: bool = False,
71
+ ) -> None:
72
+ ''' ... '''
73
+ # for multiprocessing
74
+ self._rank = 0
75
+ self._lock = None
76
+
77
+ # positional parameters
78
+ self.src_modalities = src_modalities
79
+ self.tgt_modalities = tgt_modalities
80
+
81
+ # training parameters
82
+ self.d_model = d_model
83
+ self.nhead = nhead
84
+ self.num_layers = num_layers
85
+ self.num_epochs = num_epochs
86
+ self.batch_size = batch_size
87
+ self.batch_size_multiplier = batch_size_multiplier
88
+ self.lr = lr
89
+ self.weight_decay = weight_decay
90
+ self.beta = beta
91
+ self.gamma = gamma
92
+ self.scale = scale
93
+ self.lambd = lambd
94
+ self.criterion = criterion
95
+ self.device = device
96
+ self.verbose = verbose
97
+ self._device_ids = _device_ids
98
+ self._dataloader_num_workers = _dataloader_num_workers
99
+ self._amp_enabled = _amp_enabled
100
+
101
+ @_manage_ctx_fit
102
+ def fit(self,
103
+ x, y,
104
+ is_embedding: dict[str, bool] | None = None,
105
+ ) -> Self:
106
+ ''' ... '''
107
+ # for PyTorch computational efficiency
108
+ torch.set_num_threads(1)
109
+
110
+ # initialize neural network
111
+ self.net_ = self._init_net()
112
+
113
+ # initialize dataloaders
114
+ ldr_trn, ldr_vld = self._init_dataloader(x, y, is_embedding)
115
+
116
+ # initialize optimizer and scheduler
117
+ optimizer = self._init_optimizer()
118
+ scheduler = self._init_scheduler(optimizer)
119
+
120
+ # gradient scaler for AMP
121
+ if self._amp_enabled: scaler = torch.cuda.amp.GradScaler()
122
+
123
+ # initialize loss function (binary cross entropy)
124
+ loss_func = self._init_loss_func({
125
+ k: (
126
+ sum([_[k] == 0 for _ in ldr_trn.dataset.tgt]),
127
+ sum([_[k] == 1 for _ in ldr_trn.dataset.tgt]),
128
+ ) for k in self.tgt_modalities
129
+ })
130
+
131
+ # to record the best validation performance criterion
132
+ if self.criterion is not None: best_crit = None
133
+
134
+ # progress bar for epoch loops
135
+ if self.verbose == 1:
136
+ with self._lock if self._lock is not None else suppress():
137
+ pbr_epoch = tqdm.tqdm(
138
+ desc = 'Rank {:02d}'.format(self._rank),
139
+ total = self.num_epochs,
140
+ position = self._rank,
141
+ ascii = True,
142
+ leave = False,
143
+ bar_format='{l_bar}{r_bar}'
144
+ )
145
+
146
+ # training loop
147
+ for epoch in range(self.num_epochs):
148
+ # progress bar for batch loops
149
+ if self.verbose > 1:
150
+ pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch))
151
+
152
+ # set model to train mode
153
+ torch.set_grad_enabled(True)
154
+ self.net_.train()
155
+
156
+ scores_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
157
+ y_true_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
158
+ losses_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
159
+ for n_iter, (x_batch, y_batch, mask_x, mask_y) in enumerate(ldr_trn):
160
+ # mount data to the proper device
161
+ x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
162
+ y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities}
163
+ mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
164
+ mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities}
165
+
166
+ # forward
167
+ with torch.autocast(
168
+ device_type = 'cpu' if self.device == 'cpu' else 'cuda',
169
+ dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
170
+ enabled = self._amp_enabled,
171
+ ):
172
+ outputs = self.net_(x_batch, mask_x, is_embedding)
173
+
174
+ # calculate multitask loss
175
+ loss = 0
176
+ for i, tgt_k in enumerate(self.tgt_modalities):
177
+ loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k])
178
+ loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze()))
179
+ loss += loss_k.mean()
180
+ losses_trn[tgt_k] += loss_k.detach().cpu().numpy().tolist()
181
+
182
+ # if self.lambd != 0:
183
+
184
+ # backward
185
+ if self._amp_enabled:
186
+ scaler.scale(loss).backward()
187
+ else:
188
+ loss.backward()
189
+
190
+ # update parameters
191
+ if n_iter != 0 and n_iter % self.batch_size_multiplier == 0:
192
+ if self._amp_enabled:
193
+ scaler.step(optimizer)
194
+ scaler.update()
195
+ optimizer.zero_grad()
196
+ else:
197
+ optimizer.step()
198
+ optimizer.zero_grad()
199
+
200
+ # save outputs to evaluate performance later
201
+ for tgt_k in self.tgt_modalities:
202
+ tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
203
+ scores_trn[tgt_k] += tmp.detach().cpu().numpy().tolist()
204
+ tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
205
+ y_true_trn[tgt_k] += tmp.cpu().numpy().tolist()
206
+
207
+ # update progress bar
208
+ if self.verbose > 1:
209
+ batch_size = len(next(iter(x_batch.values())))
210
+ pbr_batch.update(batch_size, {})
211
+ pbr_batch.refresh()
212
+
213
+ # for better tqdm progress bar display
214
+ if self.verbose > 1:
215
+ pbr_batch.close()
216
+
217
+ # set scheduler
218
+ scheduler.step()
219
+
220
+ # calculate and print training performance metrics
221
+ y_pred_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
222
+ y_prob_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
223
+ for tgt_k in self.tgt_modalities:
224
+ for i in range(len(scores_trn[tgt_k])):
225
+ y_pred_trn[tgt_k].append(1 if scores_trn[tgt_k][i] > 0 else 0)
226
+ y_prob_trn[tgt_k].append(expit(scores_trn[tgt_k][i]))
227
+ met_trn = get_metrics_multitask(y_true_trn, y_pred_trn, y_prob_trn)
228
+
229
+ # add loss to metrics
230
+ for tgt_k in self.tgt_modalities:
231
+ met_trn[tgt_k]['Loss'] = np.mean(losses_trn[tgt_k])
232
+
233
+ if self.verbose > 2:
234
+ print_metrics_multitask(met_trn)
235
+
236
+ # progress bar for validation
237
+ if self.verbose > 1:
238
+ pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch))
239
+
240
+ # set model to validation mode
241
+ torch.set_grad_enabled(False)
242
+ self.net_.eval()
243
+
244
+ scores_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
245
+ y_true_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
246
+ losses_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
247
+ for x_batch, y_batch, mask_x, mask_y in ldr_vld:
248
+ # mount data to the proper device
249
+ x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
250
+ y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities}
251
+ mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
252
+ mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities}
253
+
254
+ # forward
255
+ with torch.autocast(
256
+ device_type = 'cpu' if self.device == 'cpu' else 'cuda',
257
+ dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
258
+ enabled = self._amp_enabled
259
+ ):
260
+ outputs = self.net_(x_batch, mask_x, is_embedding)
261
+
262
+ # calculate multitask loss
263
+ for i, tgt_k in enumerate(self.tgt_modalities):
264
+ loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k])
265
+ loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze()))
266
+ losses_vld[tgt_k] += loss_k.detach().cpu().numpy().tolist()
267
+
268
+ # save outputs to evaluate performance later
269
+ for tgt_k in self.tgt_modalities:
270
+ tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
271
+ scores_vld[tgt_k] += tmp.detach().cpu().numpy().tolist()
272
+ tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
273
+ y_true_vld[tgt_k] += tmp.cpu().numpy().tolist()
274
+
275
+ # update progress bar
276
+ if self.verbose > 1:
277
+ batch_size = len(next(iter(x_batch.values())))
278
+ pbr_batch.update(batch_size, {})
279
+ pbr_batch.refresh()
280
+
281
+ # for better tqdm progress bar display
282
+ if self.verbose > 1:
283
+ pbr_batch.close()
284
+
285
+ # calculate and print validation performance metrics
286
+ y_pred_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
287
+ y_prob_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
288
+ for tgt_k in self.tgt_modalities:
289
+ for i in range(len(scores_vld[tgt_k])):
290
+ y_pred_vld[tgt_k].append(1 if scores_vld[tgt_k][i] > 0 else 0)
291
+ y_prob_vld[tgt_k].append(expit(scores_vld[tgt_k][i]))
292
+ met_vld = get_metrics_multitask(y_true_vld, y_pred_vld, y_prob_vld)
293
+
294
+ # add loss to metrics
295
+ for tgt_k in self.tgt_modalities:
296
+ met_vld[tgt_k]['Loss'] = np.mean(losses_vld[tgt_k])
297
+
298
+ if self.verbose > 2:
299
+ print_metrics_multitask(met_vld)
300
+
301
+ # save the model if it has the best validation performance criterion by far
302
+ if self.criterion is None: continue
303
+
304
+ # is current criterion better than previous best?
305
+ curr_crit = np.mean([met_vld[k][self.criterion] for k in self.tgt_modalities])
306
+ if best_crit is None or np.isnan(best_crit):
307
+ is_better = True
308
+ elif self.criterion == 'Loss' and best_crit >= curr_crit:
309
+ is_better = True
310
+ elif self.criterion != 'Loss' and best_crit <= curr_crit:
311
+ is_better = True
312
+ else:
313
+ is_better = False
314
+
315
+ # update best criterion
316
+ if is_better:
317
+ best_crit = curr_crit
318
+ best_state_dict = deepcopy(self.net_.state_dict())
319
+
320
+ if self.verbose > 2:
321
+ print('Best {}: {}'.format(self.criterion, best_crit))
322
+
323
+ if self.verbose == 1:
324
+ with self._lock if self._lock is not None else suppress():
325
+ pbr_epoch.update(1)
326
+ pbr_epoch.refresh()
327
+
328
+ if self.verbose == 1:
329
+ with self._lock if self._lock is not None else suppress():
330
+ pbr_epoch.close()
331
+
332
+ # restore the model with the best validation performance across all epochs
333
+ if ldr_vld is not None and self.criterion is not None:
334
+ self.net_.load_state_dict(best_state_dict)
335
+
336
+ return self
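Note that loss.backward() runs every iteration while optimizer.step() only runs every batch_size_multiplier iterations, so gradients accumulate and the effective batch size is batch_size * batch_size_multiplier, e.g.:

batch_size, batch_size_multiplier = 8, 4
effective_batch_size = batch_size * batch_size_multiplier   # 32 samples contribute to each update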
337
+
338
+ def predict_logits(self,
339
+ x: list[dict[str, Any]],
340
+ is_embedding: dict[str, bool] | None = None,
341
+ _batch_size: int | None = None,
342
+ ) -> list[dict[str, float]]:
343
+ '''
344
+ The input x can be a single sample or a list of samples.
345
+ '''
346
+ # input validation
347
+ check_is_fitted(self)
348
+
349
+ # for PyTorch computational efficiency
350
+ torch.set_num_threads(1)
351
+
352
+ # set model to eval mode
353
+ torch.set_grad_enabled(False)
354
+ self.net_.eval()
355
+
356
+ # initialize dataset and dataloader objects
357
+ dat = TransformerTestingDataset(x, self.src_modalities, is_embedding)
358
+ ldr = DataLoader(
359
+ dataset = dat,
360
+ batch_size = _batch_size if _batch_size is not None else len(x),
361
+ shuffle = False,
362
+ drop_last = False,
363
+ num_workers = 0,
364
+ collate_fn = TransformerTestingDataset.collate_fn,
365
+ )
366
+
367
+ # run model and collect results
368
+ logits: list[dict[str, float]] = []
369
+ for x_batch, mask_x in ldr:
370
+ # mount data to the proper device
371
+ x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
372
+ mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
373
+
374
+ # forward
375
+ output: dict[str, Tensor] = self.net_(x_batch, mask_x, is_embedding)
376
+
377
+ # convert output from dict-of-list to list of dict, then append
378
+ tmp = {k: output[k].tolist() for k in self.tgt_modalities}
379
+ tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))]
380
+ logits += tmp
381
+
382
+ return logits
383
+
384
+ def predict_proba(self,
385
+ x: list[dict[str, Any]],
386
+ is_embedding: dict[str, bool] | None = None,
387
+ temperature: float = 1.0,
388
+ _batch_size: int | None = None,
389
+ ) -> list[dict[str, float]]:
390
+ ''' ... '''
391
+ logits = self.predict_logits(x, is_embedding, _batch_size)
392
+ return [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]
393
+
394
+ def predict(self,
395
+ x: list[dict[str, Any]],
396
+ is_embedding: dict[str, bool] | None = None,
397
+ _batch_size: int | None = None,
398
+ ) -> list[dict[str, int]]:
399
+ ''' ... '''
400
+ logits = self.predict_logits(x, is_embedding, _batch_size)
401
+ return [{k: int(smp[k] > 0.0) for k in self.tgt_modalities} for smp in logits]
402
+
403
+ def save(self, filepath: str) -> None:
404
+ ''' ... '''
405
+ check_is_fitted(self)
406
+ state_dict = self.net_.state_dict()
407
+
408
+ # attach model hyper parameters
409
+ state_dict['src_modalities'] = self.src_modalities
410
+ state_dict['tgt_modalities'] = self.tgt_modalities
411
+ state_dict['d_model'] = self.d_model
412
+ state_dict['nhead'] = self.nhead
413
+ state_dict['num_layers'] = self.num_layers
414
+ torch.save(state_dict, filepath)
415
+
416
+ def load(self, filepath: str) -> None:
417
+ ''' ... '''
418
+ # load state_dict
419
+ state_dict = torch.load(filepath, map_location='cpu')
420
+
421
+ # load essential parameters
422
+ self.src_modalities: dict[str, dict[str, Any]] = state_dict.pop('src_modalities')
423
+ self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities')
424
+ self.d_model = state_dict.pop('d_model')
425
+ self.nhead = state_dict.pop('nhead')
426
+ self.num_layers = state_dict.pop('num_layers')
427
+
428
+ # initialize model
429
+ self.net_ = nn.Transformer(
430
+ self.src_modalities,
431
+ self.tgt_modalities,
432
+ self.d_model,
433
+ self.nhead,
434
+ self.num_layers,
435
+ )
436
+
437
+ # load model parameters
438
+ self.net_.load_state_dict(state_dict)
439
+ self.to(self.device)
440
+
441
+ def to(self, device: str) -> Self:
442
+ ''' Mount model to the given device. '''
443
+ self.device = device
444
+ if hasattr(self, 'net_'): self.net_ = self.net_.to(device)
445
+ return self
446
+
447
+ @classmethod
448
+ def from_ckpt(cls, filepath: str) -> Self:
449
+ ''' ... '''
450
+ obj = cls(None, None)
451
+ obj.load(filepath)
452
+ return obj
453
+
454
+ def _init_net(self):
455
+ """ ... """
456
+ net = nn.Transformer(
457
+ self.src_modalities,
458
+ self.tgt_modalities,
459
+ self.d_model,
460
+ self.nhead,
461
+ self.num_layers,
462
+ ).to(self.device)
463
+
464
+ # train on multiple GPUs using torch.nn.DataParallel
465
+ if self._device_ids is not None:
466
+ net = torch.nn.DataParallel(net, device_ids=self._device_ids)
467
+
468
+ # initialize model parameters using xavier_uniform
469
+ for p in net.parameters():
470
+ if p.dim() > 1:
471
+ torch.nn.init.xavier_uniform_(p)
472
+
473
+ return net
474
+
475
+ def _init_dataloader(self, x, y, is_embedding):
476
+ """ ... """
477
+ # split dataset
478
+ x_trn, x_vld, y_trn, y_vld = train_test_split(
479
+ x, y, test_size = 0.2, random_state = 0,
480
+ )
481
+
482
+ # initialize dataset and dataloader
483
+ # dat_trn = TransformerTrainingDataset(
484
+ # dat_trn = TransformerBalancedTrainingDataset(
485
+ dat_trn = Transformer2ndOrderBalancedTrainingDataset(
486
+ x_trn, y_trn,
487
+ self.src_modalities,
488
+ self.tgt_modalities,
489
+ dropout_rate = .5,
490
+ # dropout_strategy = 'compensated',
491
+ dropout_strategy = 'permutation',
492
+ )
493
+
494
+ dat_vld = TransformerValidationDataset(
495
+ x_vld, y_vld,
496
+ self.src_modalities,
497
+ self.tgt_modalities,
498
+ is_embedding,
499
+ )
500
+
501
+ ldr_trn = DataLoader(
502
+ dataset = dat_trn,
503
+ batch_size = self.batch_size,
504
+ shuffle = True,
505
+ drop_last = False,
506
+ num_workers = self._dataloader_num_workers,
507
+ collate_fn = TransformerTrainingDataset.collate_fn,
508
+ # pin_memory = True
509
+ )
510
+
511
+ ldr_vld = DataLoader(
512
+ dataset = dat_vld,
513
+ batch_size = self.batch_size,
514
+ shuffle = False,
515
+ drop_last = False,
516
+ num_workers = self._dataloader_num_workers,
517
+ collate_fn = TransformerValidationDataset.collate_fn,
518
+ # pin_memory = True
519
+ )
520
+
521
+ return ldr_trn, ldr_vld
522
+
523
+ def _init_optimizer(self):
524
+ """ ... """
525
+ return torch.optim.AdamW(
526
+ self.net_.parameters(),
527
+ lr = self.lr,
528
+ betas = (0.9, 0.98),
529
+ weight_decay = self.weight_decay
530
+ )
531
+
532
+ def _init_scheduler(self, optimizer):
533
+ """ ... """
534
+ return torch.optim.lr_scheduler.OneCycleLR(
535
+ optimizer = optimizer,
536
+ max_lr = self.lr,
537
+ total_steps = self.num_epochs,
538
+ verbose = (self.verbose > 2)
539
+ )
540
+
541
+ def _init_loss_func(self,
542
+ num_per_cls: dict[str, tuple[int, int]],
543
+ ) -> dict[str, Module]:
544
+ """ ... """
545
+ return {k: nn.SigmoidFocalLoss(
546
+ beta = self.beta,
547
+ gamma = self.gamma,
548
+ scale = self.scale,
549
+ num_per_cls = num_per_cls[k],
550
+ reduction = 'none',
551
+ ) for k in self.tgt_modalities}
552
+
553
+ def _extract_embedding(self,
554
+ x: list[dict[str, Any]],
555
+ is_embedding: dict[str, bool] | None = None,
556
+ _batch_size: int | None = None,
557
+ ) -> list[dict[str, Any]]:
558
+ """ ... """
559
+ # input validation
560
+ check_is_fitted(self)
561
+
562
+ # for PyTorch computational efficiency
563
+ torch.set_num_threads(1)
564
+
565
+ # set model to eval mode
566
+ torch.set_grad_enabled(False)
567
+ self.net_.eval()
568
+
569
+ # initialize dataset and dataloader objects
570
+ dat = TransformerTestingDataset(x, self.src_modalities, is_embedding)
571
+ ldr = DataLoader(
572
+ dataset = dat,
573
+ batch_size = _batch_size if _batch_size is not None else len(x),
574
+ shuffle = False,
575
+ drop_last = False,
576
+ num_workers = 0,
577
+ collate_fn = TransformerTestingDataset.collate_fn,
578
+ )
579
+
580
+ # run model and extract embeddings
581
+ embeddings: list[dict[str, Any]] = []
582
+ for x_batch, _ in ldr:
583
+ # mount data to the proper device
584
+ x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
585
+
586
+ # forward
587
+ out: dict[str, Tensor] = self.net_.forward_emb(x_batch, is_embedding)
588
+
589
+ # convert output from dict-of-list to list of dict, then append
590
+ tmp = {k: out[k].detach().cpu().numpy() for k in self.src_modalities}
591
+ tmp = [{k: tmp[k][i] for k in self.src_modalities} for i in range(len(next(iter(tmp.values()))))]
592
+ embeddings += tmp
593
+
594
+ # remove imputed embeddings
595
+ for i in range(len(x)):
596
+ avail = [k for k, v in x[i].items() if v is not None]
597
+ embeddings[i] = {k: embeddings[i][k] for k in avail}
598
+
599
+ return embeddings
600
+
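A hedged usage sketch for this wrapper once a checkpoint exists; the path is a placeholder, and passing None for a feature is how a missing value appears to be represented (compare _extract_embedding above):

from adrd.model import Transformer                     # the wrapper class defined in this file

mdl = Transformer.from_ckpt('./ckpt.pt')               # placeholder checkpoint path
samples = [{k: None for k in mdl.src_modalities}]      # one all-missing sample, purely for illustration
proba = mdl.predict_proba(samples)                     # list of {target: sigmoid probability}
preds = mdl.predict(samples)                           # list of {target: 0/1}, thresholded at logit > 0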
adrd/nn/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ from .transformer import Transformer
2
+ from .vitautoenc import ViTAutoEnc
3
+ from .unet import UNet3D
4
+ from .unet_3d import UNet3DBase
5
+ from .focal_loss import SigmoidFocalLoss
6
+ from .unet_img_model import ImageModel
7
+ from .img_model_wrapper import ImagingModelWrapper
8
+ from .resnet_img_model import ResNetModel
9
+ from .c3d import C3D
10
+ from .dense_net import DenseNet
11
+ from .cnn_resnet3d import CNNResNet3D
12
+ from .cnn_resnet3d_with_linear_classifier import CNNResNet3DWithLinearClassifier
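With these re-exports the building blocks can be imported directly from adrd.nn, for example:

from adrd import nn

loss_fn = nn.SigmoidFocalLoss(alpha=-1, gamma=2.0, reduction='none')
backbone_cls = nn.CNNResNet3D    # the imaging backbones are exposed at package level as well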
adrd/nn/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (961 Bytes). View file
 
adrd/nn/__pycache__/blocks.cpython-311.pyc ADDED
Binary file (2.88 kB). View file
 
adrd/nn/__pycache__/c3d.cpython-311.pyc ADDED
Binary file (5.17 kB). View file
 
adrd/nn/__pycache__/cnn_resnet3d.cpython-311.pyc ADDED
Binary file (4.37 kB). View file
 
adrd/nn/__pycache__/cnn_resnet3d_with_linear_classifier.cpython-311.pyc ADDED
Binary file (4.06 kB). View file
 
adrd/nn/__pycache__/dense_net.cpython-311.pyc ADDED
Binary file (13.8 kB). View file
 
adrd/nn/__pycache__/focal_loss.cpython-311.pyc ADDED
Binary file (6.22 kB). View file
 
adrd/nn/__pycache__/img_model_wrapper.cpython-311.pyc ADDED
Binary file (8.7 kB). View file
 
adrd/nn/__pycache__/net_resnet3d.cpython-311.pyc ADDED
Binary file (17.2 kB). View file
 
adrd/nn/__pycache__/resnet3d.cpython-311.pyc ADDED
Binary file (13.3 kB). View file
 
adrd/nn/__pycache__/resnet_img_model.cpython-311.pyc ADDED
Binary file (2.85 kB). View file
 
adrd/nn/__pycache__/selfattention.cpython-311.pyc ADDED
Binary file (3.58 kB). View file
 
adrd/nn/__pycache__/transformer.cpython-311.pyc ADDED
Binary file (14.1 kB). View file
 
adrd/nn/__pycache__/unet.cpython-311.pyc ADDED
Binary file (15.8 kB). View file
 
adrd/nn/__pycache__/unet_3d.cpython-311.pyc ADDED
Binary file (3.03 kB). View file
 
adrd/nn/__pycache__/unet_img_model.cpython-311.pyc ADDED
Binary file (14.1 kB). View file
 
adrd/nn/__pycache__/vitautoenc.cpython-311.pyc ADDED
Binary file (8.61 kB). View file
 
adrd/nn/blocks.py ADDED
@@ -0,0 +1,57 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from monai.networks.blocks.mlp import MLPBlock
13
+ from typing import Sequence, Union
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from ..nn.selfattention import SABlock
18
+
19
+ class TransformerBlock(nn.Module):
20
+ """
21
+ A transformer block, based on: "Dosovitskiy et al.,
22
+ An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
23
+ """
24
+
25
+ def __init__(
26
+ self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False
27
+ ) -> None:
28
+ """
29
+ Args:
30
+ hidden_size: dimension of hidden layer.
31
+ mlp_dim: dimension of feedforward layer.
32
+ num_heads: number of attention heads.
33
+ dropout_rate: fraction of the input units to drop.
34
+ qkv_bias: apply bias term for the qkv linear layer
35
+
36
+ """
37
+
38
+ super().__init__()
39
+
40
+ if not (0 <= dropout_rate <= 1):
41
+ raise ValueError("dropout_rate should be between 0 and 1.")
42
+
43
+ if hidden_size % num_heads != 0:
44
+ raise ValueError("hidden_size should be divisible by num_heads.")
45
+
46
+ self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
47
+ self.norm1 = nn.LayerNorm(hidden_size)
48
+ self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias)
49
+ self.norm2 = nn.LayerNorm(hidden_size)
50
+
51
+ def forward(self, x, return_attention=False):
52
+ y, attn = self.attn(self.norm1(x))
53
+ if return_attention:
54
+ return attn
55
+ x = x + y
56
+ x = x + self.mlp(self.norm2(x))
57
+ return x
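A minimal forward-pass sketch for this block, assuming MONAI is installed; the dimensions are made up:

import torch
from adrd.nn.blocks import TransformerBlock

blk = TransformerBlock(hidden_size=64, mlp_dim=128, num_heads=4, dropout_rate=0.1, qkv_bias=True)
tokens = torch.randn(2, 16, 64)             # (batch, sequence length, hidden_size)
out = blk(tokens)                           # residual attention + MLP path, same shape as the input
attn = blk(tokens, return_attention=True)   # attention weights returned by the SABlock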
adrd/nn/c3d.py ADDED
@@ -0,0 +1,99 @@
1
+ # From https://github.com/xmuyzz/3D-CNN-PyTorch/blob/master/models/C3DNet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import sys
6
+ # from icecream import ic
7
+ import math
8
+
9
+ class C3D(torch.nn.Module):
10
+
11
+ def __init__(self, tgt_modalities, in_channels=1, load_from_ckpt=None):
12
+
13
+ super(C3D, self).__init__()
14
+ self.conv_group1 = nn.Sequential(
15
+ nn.Conv3d(in_channels, 64, kernel_size=3, padding=1),
16
+ nn.BatchNorm3d(64),
17
+ nn.ReLU(),
18
+ nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 2, 2)))
19
+ self.conv_group2 = nn.Sequential(
20
+ nn.Conv3d(64, 128, kernel_size=3, padding=1),
21
+ nn.BatchNorm3d(128),
22
+ nn.ReLU(),
23
+ nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)))
24
+ self.conv_group3 = nn.Sequential(
25
+ nn.Conv3d(128, 256, kernel_size=3, padding=1),
26
+ nn.BatchNorm3d(256),
27
+ nn.ReLU(),
28
+ nn.Conv3d(256, 256, kernel_size=3, padding=1),
29
+ nn.BatchNorm3d(256),
30
+ nn.ReLU(),
31
+ nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
32
+ )
33
+ self.conv_group4 = nn.Sequential(
34
+ nn.Conv3d(256, 512, kernel_size=3, padding=1),
35
+ nn.BatchNorm3d(512),
36
+ nn.ReLU(),
37
+ nn.Conv3d(512, 512, kernel_size=3, padding=1),
38
+ nn.BatchNorm3d(512),
39
+ nn.ReLU(),
40
+ nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1))
41
+ )
42
+
43
+ # last_duration = int(math.floor(128 / 16))
44
+ # last_size = int(math.ceil(128 / 32))
45
+ self.fc1 = nn.Sequential(
46
+ nn.Linear(512 * 15 * 9 * 9, 512),
47
+ nn.ReLU(),
48
+ nn.Dropout(0.5))
49
+ self.fc2 = nn.Sequential(
50
+ nn.Linear(512, 256),
51
+ nn.ReLU(),
52
+ nn.Dropout(0.5))
53
+ # self.fc = nn.Sequential(
54
+ # nn.Linear(4096, num_classes))
55
+
56
+ self.fc = torch.nn.ModuleDict()
57
+ for k in tgt_modalities:
58
+ self.fc[k] = torch.nn.Linear(256, 1)
59
+
60
+ def forward(self, x):
61
+ # for k in x.keys():
62
+ # x[k] = x[k].to(torch.float32)
63
+
64
+ # x = torch.stack([o for o in x.values()], dim=0)[0]
65
+ # print(x.shape)
66
+
67
+ out = self.conv_group1(x)
68
+ out = self.conv_group2(out)
69
+ out = self.conv_group3(out)
70
+ out = self.conv_group4(out)
71
+ out = out.view(out.size(0), -1)
72
+ # print(out.shape)
73
+ out = self.fc1(out)
74
+ out = self.fc2(out)
75
+ # out = self.fc(out)
76
+
77
+ tgt_iter = self.fc.keys()
78
+ out_tgt = {k: self.fc[k](out).squeeze(1) for k in tgt_iter}
79
+ return out_tgt
80
+
81
+
82
+ if __name__ == "__main__":
83
+ model = C3D(tgt_modalities=['NC', 'MCI', 'DE'])
84
+ print(model)
85
+ x = torch.rand((1, 1, 128, 128, 128))
86
+ # layers = list(model.features.named_children())
87
+ # features = nn.Sequential(*list(model.features.children()))(x)
88
+ # print(features.shape)
89
+ print(sum(p.numel() for p in model.parameters()))
90
+ # layer_found = False
91
+ # features = None
92
+ # desired_layer_name = 'transition3'
93
+
94
+ # for name, layer in layers:
95
+ # if name == desired_layer_name:
96
+ # x = layer(x)
97
+ # print(x)
98
+ # model(x)
99
+ # print(features)
adrd/nn/cnn_resnet3d.py ADDED
@@ -0,0 +1,81 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Any, Type
4
+ Tensor = Type[torch.Tensor]
5
+
6
+ from .resnet3d import r3d_18
7
+
8
+
9
+ class CNNResNet3D(nn.Module):
10
+
11
+ def __init__(self,
12
+ src_modalities: dict[str, dict[str, Any]],
13
+ tgt_modalities: dict[str, dict[str, Any]]
14
+ ) -> None:
15
+ """ ... """
16
+ super().__init__()
17
+
18
+ # resnet
19
+ # embedding modules for source
20
+ self.modules_emb_src = nn.ModuleDict()
21
+ for k, info in src_modalities.items():
22
+ if info['type'] == 'imaging' and len(info['img_shape']) == 4:
23
+ self.modules_emb_src[k] = nn.Sequential(
24
+ r3d_18(),
25
+ nn.Dropout(0.5)
26
+ )
27
+ else:
28
+ # unrecognized
29
+ raise ValueError('{} is an unrecognized data modality'.format(k))
30
+
31
+ # classifiers (binary only)
32
+ self.modules_cls = nn.ModuleDict()
33
+ for k, info in tgt_modalities.items():
34
+ if info['type'] == 'categorical' and info['num_categories'] == 2:
35
+ # categorical
36
+ self.modules_cls[k] = nn.Linear(256, 1)
37
+ else:
38
+ # unrecognized
39
+ raise ValueError
40
+
41
+ def forward(self,
42
+ x: dict[str, Tensor],
43
+ ) -> dict[str, Tensor]:
44
+ """ ... """
45
+ out_emb = self.forward_emb(x)
46
+ out_emb = out_emb[list(out_emb.keys())[0]]
47
+ out_cls = self.forward_cls(out_emb)
48
+ return out_cls
49
+
50
+ def forward_emb(self,
51
+ x: dict[str, Tensor],
52
+ ) -> dict[str, Tensor]:
53
+ """ ... """
54
+ out_emb = dict()
55
+ for k in self.modules_emb_src.keys():
56
+ out_emb[k] = self.modules_emb_src[k](x[k])
57
+ return out_emb
58
+
59
+ def forward_cls(self,
60
+ out_emb: dict[str, Tensor]
61
+ ) -> dict[str, Tensor]:
62
+ """ ... """
63
+ out_cls = dict()
64
+ for k in self.modules_cls.keys():
65
+ out_cls[k] = self.modules_cls[k](out_emb).squeeze(1)
66
+ return out_cls
67
+
68
+
69
+ # for testing purpose only
70
+ if __name__ == '__main__':
71
+ src_modalities = {
72
+ 'img_MRI_T1': {'type': 'imaging', 'img_shape': [1, 182, 218, 182]}
73
+ }
74
+ tgt_modalities = {
75
+ 'AD': {'type': 'categorical', 'num_categories': 2},
76
+ 'PD': {'type': 'categorical', 'num_categories': 2}
77
+ }
78
+ net = CNNResNet3D(src_modalities, tgt_modalities)
79
+ net.eval()
80
+ x = {'img_MRI_T1': torch.zeros(2, 1, 182, 218, 182)}
81
+ print(net(x))
adrd/nn/cnn_resnet3d_with_linear_classifier.py ADDED
@@ -0,0 +1,56 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Any, Type
4
+ Tensor = Type[torch.Tensor]
5
+
6
+ from .resnet3d import r3d_18
7
+
8
+ class CNNResNet3DWithLinearClassifier(nn.Module):
9
+
10
+ def __init__(self,
11
+ src_modalities: dict[str, dict[str, Any]],
12
+ tgt_modalities: dict[str, dict[str, Any]]
13
+ ) -> None:
14
+ """ ... """
15
+ super().__init__()
16
+ self.core = _CNNResNet3DWithLinearClassifier(len(tgt_modalities))
17
+ self.src_modalities = src_modalities
18
+ self.tgt_modalities = tgt_modalities
19
+
20
+ def forward(self,
21
+ x: dict[str, Tensor],
22
+ ) -> dict[str, Tensor]:
23
+ """ x is expected to be a singleton dictionary """
24
+ src_k = list(x.keys())[0]
25
+ x = x[src_k]
26
+ out = self.core(x)
27
+ out = {tgt_k: out[:, i] for i, tgt_k in enumerate(self.tgt_modalities)}
28
+ return out
29
+
30
+
31
+ class _CNNResNet3DWithLinearClassifier(nn.Module):
32
+
33
+ def __init__(self,
34
+ len_tgt_modalities: int,
35
+ ) -> None:
36
+ """ ... """
37
+ super().__init__()
38
+ self.cnn = r3d_18()
39
+ self.cls = nn.Sequential(
40
+ nn.Dropout(0.5),
41
+ nn.Linear(256, len_tgt_modalities),
42
+ )
43
+
44
+ def forward(self, x: Tensor) -> Tensor:
45
+ """ ... """
46
+ out_emb = self.forward_emb(x)
47
+ out_cls = self.forward_cls(out_emb)
48
+ return out_cls
49
+
50
+ def forward_emb(self, x: Tensor) -> Tensor:
51
+ """ ... """
52
+ return self.cnn(x)
53
+
54
+ def forward_cls(self, out_emb: Tensor) -> Tensor:
55
+ """ ... """
56
+ return self.cls(out_emb)
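A hedged shape-level sketch of the singleton-dictionary interface; the modality metadata mirrors the test stub in cnn_resnet3d.py and the input size is illustrative:

import torch
from adrd.nn import CNNResNet3DWithLinearClassifier

src = {'img_MRI_T1': {'type': 'imaging', 'img_shape': [1, 182, 218, 182]}}
tgt = {'AD': {'type': 'categorical', 'num_categories': 2},
       'PD': {'type': 'categorical', 'num_categories': 2}}
net = CNNResNet3DWithLinearClassifier(src, tgt).eval()
with torch.no_grad():
    out = net({'img_MRI_T1': torch.zeros(2, 1, 182, 218, 182)})
# out maps 'AD' and 'PD' to tensors of shape (2,), one logit per sample and target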
adrd/nn/dense_net.py ADDED
@@ -0,0 +1,211 @@
1
+ # This implementation is based on the DenseNet-BC implementation in torchvision
2
+ # https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
3
+ # https://github.com/gpleiss/efficient_densenet_pytorch/blob/master/models/densenet.py
4
+
5
+
6
+ import math
7
+ import torch
8
+ import numpy as np
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint as cp
12
+ from collections import OrderedDict
13
+
14
+
15
+ def _bn_function_factory(norm, relu, conv):
16
+ def bn_function(*inputs):
17
+ concated_features = torch.cat(inputs, 1)
18
+ bottleneck_output = conv(relu(norm(concated_features)))
19
+ return bottleneck_output
20
+
21
+ return bn_function
22
+
23
+
24
+ class _DenseLayer(nn.Module):
25
+ def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False):
26
+ super(_DenseLayer, self).__init__()
27
+ self.add_module('norm1', nn.BatchNorm3d(num_input_features)),
28
+ self.add_module('relu1', nn.ReLU(inplace=True)),
29
+ self.add_module('conv1', nn.Conv3d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)),
30
+ self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate)),
31
+ self.add_module('relu2', nn.ReLU(inplace=True)),
32
+ self.add_module('conv2', nn.Conv3d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)),
33
+ self.drop_rate = drop_rate
34
+ self.efficient = efficient
35
+
36
+ def forward(self, *prev_features):
37
+ bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
38
+ if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
39
+ bottleneck_output = cp.checkpoint(bn_function, *prev_features)
40
+ else:
41
+ bottleneck_output = bn_function(*prev_features)
42
+ new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
43
+ if self.drop_rate > 0:
44
+ new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
45
+ return new_features
46
+
47
+
48
+ class _Transition(nn.Sequential):
49
+ def __init__(self, num_input_features, num_output_features):
50
+ super(_Transition, self).__init__()
51
+ self.add_module('norm', nn.BatchNorm3d(num_input_features))
52
+ self.add_module('relu', nn.ReLU(inplace=True))
53
+ self.add_module('conv', nn.Conv3d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
54
+ self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2))
55
+
56
+
57
+ class _DenseBlock(nn.Module):
58
+ def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, efficient=False):
59
+ super(_DenseBlock, self).__init__()
60
+ for i in range(num_layers):
61
+ layer = _DenseLayer(
62
+ num_input_features + i * growth_rate,
63
+ growth_rate=growth_rate,
64
+ bn_size=bn_size,
65
+ drop_rate=drop_rate,
66
+ efficient=efficient,
67
+ )
68
+ self.add_module('denselayer%d' % (i + 1), layer)
69
+
70
+ def forward(self, init_features):
71
+ features = [init_features]
72
+ for name, layer in self.named_children():
73
+ new_features = layer(*features)
74
+ features.append(new_features)
75
+ return torch.cat(features, 1)
76
+
77
+
78
+ class DenseNet(nn.Module):
79
+ r"""Densenet-BC model class, based on
80
+ `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
81
+ Args:
82
+ growth_rate (int) - how many filters to add each layer (`k` in paper)
83
+ block_config (list of 3 or 4 ints) - how many layers in each pooling block
84
+ num_init_features (int) - the number of filters to learn in the first convolution layer
85
+ bn_size (int) - multiplicative factor for number of bottle neck layers
86
+ (i.e. bn_size * k features in the bottleneck layer)
87
+ drop_rate (float) - dropout rate after each dense layer
88
+ tgt_modalities (list) - list of target modalities
89
+ efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower.
90
+ """
91
+ # def __init__(self, tgt_modalities, growth_rate=12, block_config=(3, 3, 3), compression=0.5,
92
+ # num_init_features=16, bn_size=4, drop_rate=0, efficient=False, load_from_ckpt=False): # config 1
93
+
94
+ def __init__(self, tgt_modalities, growth_rate=12, block_config=(3, 3, 3), compression=0.5,
95
+ num_init_features=16, bn_size=4, drop_rate=0, efficient=False, load_from_ckpt=False): # config 2
96
+
97
+ super(DenseNet, self).__init__()
98
+
99
+ # First convolution
100
+ self.features = nn.Sequential(OrderedDict([('conv0', nn.Conv3d(1, num_init_features, kernel_size=7, stride=2, padding=0, bias=False)),]))
101
+ self.features.add_module('norm0', nn.BatchNorm3d(num_init_features))
102
+ self.features.add_module('relu0', nn.ReLU(inplace=True))
103
+ self.features.add_module('pool0', nn.MaxPool3d(kernel_size=3, stride=2, padding=0, ceil_mode=False))
104
+ self.tgt_modalities = tgt_modalities
105
+
106
+ # Each denseblock
107
+ num_features = num_init_features
108
+ for i, num_layers in enumerate(block_config):
109
+ block = _DenseBlock(
110
+ num_layers=num_layers,
111
+ num_input_features=num_features,
112
+ bn_size=bn_size,
113
+ growth_rate=growth_rate,
114
+ drop_rate=drop_rate,
115
+ efficient=efficient,
116
+ )
117
+ self.features.add_module('denseblock%d' % (i + 1), block)
118
+ num_features = num_features + num_layers * growth_rate
119
+ if i != len(block_config):
120
+ trans = _Transition(num_input_features=num_features,
121
+ num_output_features=int(num_features * compression))
122
+ self.features.add_module('transition%d' % (i + 1), trans)
123
+ num_features = int(num_features * compression)
124
+
125
+ # Final batch norm
126
+ self.features.add_module('norm_final', nn.BatchNorm3d(num_features))
127
+
128
+ # Classification heads
129
+ self.tgt = torch.nn.ModuleDict()
130
+ for k in tgt_modalities:
131
+ # self.tgt[k] = torch.nn.Linear(621, 1) # config 2
132
+ self.tgt[k] = torch.nn.Sequential(
133
+ torch.nn.Linear(self.test_size(), 256),
134
+ torch.nn.ReLU(),
135
+ torch.nn.Linear(256, 1)
136
+ )
137
+
138
+ print(f'load_from_ckpt: {load_from_ckpt}')
139
+ # Initialization
140
+ if not load_from_ckpt:
141
+ for name, param in self.named_parameters():
142
+ if 'conv' in name and 'weight' in name:
143
+ n = param.size(0) * param.size(2) * param.size(3) * param.size(4)
144
+ param.data.normal_().mul_(math.sqrt(2. / n))
145
+ elif 'norm' in name and 'weight' in name:
146
+ param.data.fill_(1)
147
+ elif 'norm' in name and 'bias' in name:
148
+ param.data.fill_(0)
149
+ elif ('classifier' in name or 'tgt' in name) and 'bias' in name:
150
+ param.data.fill_(0)
151
+
152
+ # self.size = self.test_size()
153
+
154
+ def forward(self, x, shap=True):
155
+ # print(x.shape)
156
+ features = self.features(x)
157
+ # print(features.shape)
158
+ out = F.relu(features, inplace=True)
159
+ # out = F.adaptive_avg_pool3d(out, (1, 1, 1))
160
+ out = torch.flatten(out, 1)
161
+
162
+ # print(out.shape)
163
+
164
+ # out_tgt = self.tgt(out).squeeze(1)
165
+ # print(out_tgt)
166
+ # return F.softmax(out_tgt)
167
+
168
+ tgt_iter = self.tgt.keys()
169
+ out_tgt = {k: self.tgt[k](out).squeeze(1) for k in tgt_iter}
170
+ if shap:
171
+ out_tgt = torch.stack(list(out_tgt.values()))
172
+ return out_tgt.T
173
+ else:
174
+ return out_tgt
175
+
176
+ def test_size(self):
177
+ case = torch.ones((1, 1, 182, 218, 182))
178
+ output = self.features(case).view(-1).size(0)
179
+ return output
180
+
181
+
182
+ if __name__ == "__main__":
183
+ model = DenseNet(
184
+ tgt_modalities=['NC', 'MCI', 'DE'],
185
+ growth_rate=12,
186
+ block_config=(2, 3, 2),
187
+ compression=0.5,
188
+ num_init_features=16,
189
+ drop_rate=0.2)
190
+ print(model)
191
+ torch.manual_seed(42)
192
+ x = torch.rand((1, 1, 182, 218, 182))
193
+ # layers = list(model.features.named_children())
194
+ features = nn.Sequential(*list(model.features.children()))(x)
195
+ print(features.shape)
196
+ print(sum(p.numel() for p in model.parameters()))
197
+ # out = mdl.net_(x, shap=False)
198
+ # print(out)
199
+
200
+ out = model(x, shap=False)
201
+ print(out)
202
+ # layer_found = False
203
+ # features = None
204
+ # desired_layer_name = 'transition3'
205
+
206
+ # for name, layer in layers:
207
+ # if name == desired_layer_name:
208
+ # x = layer(x)
209
+ # print(x)
210
+ # model(x)
211
+ # print(features)
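A minimal usage sketch (editorial, not part of the committed file): with shap=False the forward pass returns a dict of raw logits, one entry per target label, so a sigmoid on each entry gives independent binary probabilities.

    probs = {k: torch.sigmoid(v) for k, v in out.items()}   # out = model(x, shap=False) from the demo above
    print({k: float(p) for k, p in probs.items()})          # one probability per target ('NC', 'MCI', 'DE')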
adrd/nn/focal_loss.py ADDED
@@ -0,0 +1,120 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import sys
5
+
6
+ class SigmoidFocalLoss(nn.Module):
7
+ ''' ... '''
8
+ def __init__(
9
+ self,
10
+ alpha: float = -1,
11
+ gamma: float = 2.0,
12
+ reduction: str = 'mean',
13
+ ):
14
+ ''' ... '''
15
+ super().__init__()
16
+ self.alpha = alpha
17
+ self.gamma = gamma
18
+ self.reduction = reduction
19
+
20
+ def forward(self, input, target):
21
+ ''' ... '''
22
+ p = torch.sigmoid(input)
23
+ ce_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')
24
+ p_t = p * target + (1 - p) * (1 - target)
25
+ loss = ce_loss * ((1 - p_t) ** self.gamma)
26
+
27
+ if self.alpha >= 0:
28
+ alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
29
+ loss = alpha_t * loss
30
+
31
+ if self.reduction == 'mean':
32
+ loss = loss.mean()
33
+ elif self.reduction == 'sum':
34
+ loss = loss.sum()
35
+
36
+ return loss
37
+
38
+
39
+ class SigmoidFocalLossBeta(nn.Module):
40
+ ''' ... '''
41
+ def __init__(
42
+ self,
43
+ beta: float = 0.9999,
44
+ gamma: float = 2.0,
45
+ num_per_cls = (1, 1),
46
+ reduction: str = 'mean',
47
+ ):
48
+ ''' ... '''
49
+ super().__init__()
50
+ eps = sys.float_info.epsilon
51
+ self.gamma = gamma
52
+ self.reduction = reduction
53
+
54
+ # weights to balance loss
55
+ self.weight_neg = ((1 - beta) / (1 - beta ** num_per_cls[0] + eps))
56
+ self.weight_pos = ((1 - beta) / (1 - beta ** num_per_cls[1] + eps))
57
+ # weight_neg = (1 - beta) / (1 - beta ** num_per_cls[0])
58
+ # weight_pos = (1 - beta) / (1 - beta ** num_per_cls[1])
59
+ # self.weight_neg = weight_neg / (weight_neg + weight_pos)
60
+ # self.weight_pos = weight_pos / (weight_neg + weight_pos)
61
+
62
+ def forward(self, input, target):
63
+ ''' ... '''
64
+ p = torch.sigmoid(input)
65
+ p_t = p * target + (1 - p) * (1 - target)
66
+ ce_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')
67
+ loss = ce_loss * ((1 - p_t) ** self.gamma)
68
+
69
+ alpha_t = self.weight_pos * target + self.weight_neg * (1 - target)
70
+ loss = alpha_t * loss
71
+
72
+ if self.reduction == 'mean':
73
+ loss = loss.mean()
74
+ elif self.reduction == 'sum':
75
+ loss = loss.sum()
76
+
77
+ return loss
78
+
79
+ class AsymmetricLoss(nn.Module):
80
+ def __init__(self, gamma_neg=4, gamma_pos=1, alpha=0.5, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
81
+ super(AsymmetricLoss, self).__init__()
82
+ self.alpha = alpha
83
+ self.gamma_neg = gamma_neg
84
+ self.gamma_pos = gamma_pos
85
+ self.clip = clip
86
+ self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
87
+ self.eps = eps
88
+
89
+
90
+ def forward(self, x, y):
91
+ """"
92
+ Parameters
93
+ ----------
94
+ x: input logits
95
+ y: targets (multi-label binarized vector)
96
+ """
97
+ # Calculating Probabilities
98
+ x_sigmoid = torch.sigmoid(x)
99
+ xs_pos = x_sigmoid
100
+ xs_neg = 1 - x_sigmoid
101
+ # Asymmetric Clipping
102
+ if self.clip is not None and self.clip > 0:
103
+ xs_neg = (xs_neg + self.clip).clamp(max=1)
104
+ # Basic CE calculation
105
+ los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
106
+ los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
107
+ loss = self.alpha*los_pos + (1-self.alpha)*los_neg
108
+ # Asymmetric Focusing
109
+ if self.gamma_neg > 0 or self.gamma_pos > 0:
110
+ if self.disable_torch_grad_focal_loss:
111
+ torch.set_grad_enabled(False)
112
+ pt0 = xs_pos * y
113
+ pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
114
+ pt = pt0 + pt1
115
+ one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
116
+ one_sided_w = torch.pow(1 - pt, one_sided_gamma)
117
+ if self.disable_torch_grad_focal_loss:
118
+ torch.set_grad_enabled(True)
119
+ loss *= one_sided_w
120
+ return -loss#.sum()
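A short usage sketch (editorial, not part of the committed file) for SigmoidFocalLoss above; the batch size, number of targets, and alpha/gamma values are hypothetical.

    import torch
    criterion = SigmoidFocalLoss(alpha=0.25, gamma=2.0, reduction='mean')
    logits  = torch.randn(4, 3)                    # batch of 4 samples, 3 binary targets
    targets = torch.randint(0, 2, (4, 3)).float()  # multi-hot labels
    loss = criterion(logits, targets)              # scalar; (1 - p_t) ** gamma down-weights easy examples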
adrd/nn/img_model_wrapper.py ADDED
@@ -0,0 +1,174 @@
1
+ import torch
2
+ from .. import nn
3
+ from .. import model
4
+ import numpy as np
5
+ from icecream import ic
6
+ from monai.networks.nets.swin_unetr import SwinUNETR
7
+ from typing import Any
8
+
9
+ class ImagingModelWrapper(torch.nn.Module):
10
+ def __init__(
11
+ self,
12
+ arch: str = 'ViTAutoEnc',
13
+ tgt_modalities: dict | None = {},
14
+ img_size: int | None = 128,
15
+ patch_size: int | None = 16,
16
+ ckpt_path: str | None = None,
17
+ train_backbone: bool = False,
18
+ out_dim: int = 128,
19
+ layers: int | None = 1,
20
+ device: str = 'cpu',
21
+ fusion_stage: str = 'middle',
22
+ ):
23
+ super(ImagingModelWrapper, self).__init__()
24
+
25
+ self.arch = arch
26
+ self.tgt_modalities = tgt_modalities
27
+ self.img_size = img_size
28
+ self.patch_size = patch_size
29
+ self.train_backbone = train_backbone
30
+ self.ckpt_path = ckpt_path
31
+ self.device = device
32
+ self.out_dim = out_dim
33
+ self.layers = layers
34
+ self.fusion_stage = fusion_stage
35
+
36
+
37
+ if "swinunetr" in self.arch.lower():
38
+ if "emb" not in self.arch.lower():
39
+ ckpt_path = '/projectnb/ivc-ml/dlteif/pretrained_models/model_swinvit.pt'
40
+ ckpt = torch.load(ckpt_path, map_location='cpu')
41
+ self.img_model = SwinUNETR(
42
+ in_channels=1,
43
+ out_channels=1,
44
+ img_size=128,
45
+ feature_size=48,
46
+ use_checkpoint=True,
47
+ )
48
+ ckpt["state_dict"] = {k.replace("swinViT.", "module."): v for k, v in ckpt["state_dict"].items()}
49
+ ic(ckpt["state_dict"].keys())
50
+ self.img_model.load_from(ckpt)
51
+ self.dim = 768
52
+
53
+ elif "vit" in self.arch.lower():
54
+ if "emb" not in self.arch.lower():
55
+ # Initialize image model
56
+ self.img_model = nn.__dict__[self.arch](
57
+ in_channels = 1,
58
+ img_size = self.img_size,
59
+ patch_size = self.patch_size,
60
+ )
61
+
62
+ if self.ckpt_path:
63
+ self.img_model.load(self.ckpt_path, map_location=self.device)
64
+ self.dim = self.img_model.hidden_size
65
+ else:
66
+ self.dim = 768
67
+
68
+ if "vit" in self.arch.lower() or "swinunetr" in self.arch.lower():
69
+ dim = self.dim
70
+ if self.fusion_stage == 'middle':
71
+ downsample = torch.nn.ModuleList()
72
+ # print('Number of layers: ', self.layers)
73
+ for i in range(self.layers):
74
+ if i == self.layers - 1:
75
+ dim_out = self.out_dim
76
+ # print(layers)
77
+ ks = 2
78
+ stride = 2
79
+ else:
80
+ dim_out = dim // 2
81
+ ks = 2
82
+ stride = 2
83
+
84
+ downsample.append(
85
+ torch.nn.Conv1d(in_channels=dim, out_channels=dim_out, kernel_size=ks, stride=stride)
86
+ )
87
+
88
+ dim = dim_out
89
+
90
+ downsample.append(
91
+ torch.nn.BatchNorm1d(dim)
92
+ )
93
+ downsample.append(
94
+ torch.nn.ReLU()
95
+ )
96
+
97
+
98
+ self.downsample = torch.nn.Sequential(*downsample)
99
+ elif self.fusion_stage == 'late':
100
+ self.downsample = torch.nn.Identity()
101
+ else:
102
+ pass
103
+
104
+ # print('Downsample layers: ', self.downsample)
105
+
106
+ elif "densenet" in self.arch.lower():
107
+ if "emb" not in self.arch.lower():
108
+ self.img_model = model.ImagingModel.from_ckpt(self.ckpt_path, device=self.device, img_backend=self.arch, load_from_ckpt=True).net_
109
+
110
+ self.downsample = torch.nn.Linear(3900, self.out_dim)
111
+
112
+ # randomly initialize weights for downsample block
113
+ for p in self.downsample.parameters():
114
+ if p.dim() > 1:
115
+ torch.nn.init.xavier_uniform_(p)
116
+ p.requires_grad = True
117
+
118
+ if "emb" not in self.arch.lower():
119
+ # freeze imaging model parameters
120
+ if "densenet" in self.arch.lower():
121
+ for n, p in self.img_model.features.named_parameters():
122
+ if not self.train_backbone:
123
+ p.requires_grad = False
124
+ else:
125
+ p.requires_grad = True
126
+ for n, p in self.img_model.tgt.named_parameters():
127
+ p.requires_grad = False
128
+ else:
129
+ for n, p in self.img_model.named_parameters():
130
+ # print(n, p.requires_grad)
131
+ if not self.train_backbone:
132
+ p.requires_grad = False
133
+ else:
134
+ p.requires_grad = True
135
+
136
+ def forward(self, x):
137
+ # print("--------ImagingModelWrapper forward--------")
138
+ if "emb" not in self.arch.lower():
139
+ if "swinunetr" in self.arch.lower():
140
+ # print(x.size())
141
+ out = self.img_model(x)
142
+ # print(out.size())
143
+ out = self.downsample(out)
144
+ # print(out.size())
145
+ out = torch.mean(out, dim=-1)
146
+ # print(out.size())
147
+ elif "vit" in self.arch.lower():
148
+ out = self.img_model(x, return_emb=True)
149
+ ic(out.size())
150
+ out = self.downsample(out)
151
+ out = torch.mean(out, dim=-1)
152
+ elif "densenet" in self.arch.lower():
153
+ out = torch.nn.Sequential(*list(self.img_model.features.children()))(x)
154
+ # print(out.size())
155
+ out = torch.flatten(out, 1)
156
+ out = self.downsample(out)
157
+ else:
158
+ # print(x.size())
159
+ if "swinunetr" in self.arch.lower():
160
+ x = torch.squeeze(x, dim=1)
161
+ x = x.view(x.size(0),self.dim, -1)
162
+ # print('x: ', x.size())
163
+ out = self.downsample(x)
164
+ # print('out: ', out.size())
165
+ if self.fusion_stage == 'middle':
166
+ if "vit" in self.arch.lower() or "swinunetr" in self.arch.lower():
167
+ out = torch.mean(out, dim=-1)
168
+ else:
169
+ out = torch.squeeze(out, dim=1)
170
+ elif self.fusion_stage == 'late':
171
+ pass
172
+
173
+ return out
174
+
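A hedged sketch (editorial, not part of the committed file) of the pre-computed-embedding path: any arch string containing both "densenet" and "emb" skips the backbone and only applies the linear downsample head. The arch string and input shape below are assumptions for illustration.

    import torch
    wrapper = ImagingModelWrapper(arch='densenet_emb', out_dim=128)
    emb_in  = torch.rand(4, 1, 3900)   # pre-extracted DenseNet feature vectors, one per scan
    emb_out = wrapper(emb_in)          # Linear(3900 -> 128), then squeeze -> shape (4, 128)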
adrd/nn/net_resnet3d.py ADDED
@@ -0,0 +1,338 @@
1
+ """
2
+ Created on Sat Nov 21 10:49:39 2021
3
+
4
+ @author: cxue2
5
+ """
6
+
7
+ import torch.nn as nn
8
+
9
+
10
+ __all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18']
11
+
12
+
13
+ class Conv3DSimple(nn.Conv3d):
14
+ def __init__(self,
15
+ in_planes,
16
+ out_planes,
17
+ midplanes=None,
18
+ stride=1,
19
+ padding=1):
20
+
21
+ super(Conv3DSimple, self).__init__(
22
+ in_channels=in_planes,
23
+ out_channels=out_planes,
24
+ kernel_size=(3, 3, 3),
25
+ stride=stride,
26
+ padding=padding,
27
+ bias=False)
28
+
29
+ @staticmethod
30
+ def get_downsample_stride(stride):
31
+ return stride, stride, stride
32
+
33
+
34
+ class Conv2Plus1D(nn.Sequential):
35
+
36
+ def __init__(self,
37
+ in_planes,
38
+ out_planes,
39
+ midplanes,
40
+ stride=1,
41
+ padding=1):
42
+ super(Conv2Plus1D, self).__init__(
43
+ nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
44
+ stride=(1, stride, stride), padding=(0, padding, padding),
45
+ bias=False),
46
+ nn.BatchNorm3d(midplanes),
47
+ nn.ReLU(inplace=True),
48
+ nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1),
49
+ stride=(stride, 1, 1), padding=(padding, 0, 0),
50
+ bias=False))
51
+
52
+ @staticmethod
53
+ def get_downsample_stride(stride):
54
+ return stride, stride, stride
55
+
56
+
57
+ class Conv3DNoTemporal(nn.Conv3d):
58
+
59
+ def __init__(self,
60
+ in_planes,
61
+ out_planes,
62
+ midplanes=None,
63
+ stride=1,
64
+ padding=1):
65
+
66
+ super(Conv3DNoTemporal, self).__init__(
67
+ in_channels=in_planes,
68
+ out_channels=out_planes,
69
+ kernel_size=(1, 3, 3),
70
+ stride=(1, stride, stride),
71
+ padding=(0, padding, padding),
72
+ bias=False)
73
+
74
+ @staticmethod
75
+ def get_downsample_stride(stride):
76
+ return 1, stride, stride
77
+
78
+
79
+ class BasicBlock(nn.Module):
80
+
81
+ expansion = 1
82
+
83
+ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
84
+ midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
85
+
86
+ super(BasicBlock, self).__init__()
87
+ self.conv1 = nn.Sequential(
88
+ conv_builder(inplanes, planes, midplanes, stride),
89
+ nn.BatchNorm3d(planes),
90
+ nn.ReLU(inplace=True)
91
+ )
92
+ self.conv2 = nn.Sequential(
93
+ conv_builder(planes, planes, midplanes),
94
+ nn.BatchNorm3d(planes)
95
+ )
96
+ self.relu = nn.ReLU(inplace=True)
97
+ self.downsample = downsample
98
+ self.stride = stride
99
+
100
+ def forward(self, x):
101
+ residual = x
102
+
103
+ out = self.conv1(x)
104
+ out = self.conv2(out)
105
+ if self.downsample is not None:
106
+ residual = self.downsample(x)
107
+
108
+ out += residual
109
+ out = self.relu(out)
110
+
111
+ return out
112
+
113
+
114
+ class Bottleneck(nn.Module):
115
+ expansion = 4
116
+
117
+ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
118
+
119
+ super(Bottleneck, self).__init__()
120
+ midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
121
+
122
+ # 1x1x1
123
+ self.conv1 = nn.Sequential(
124
+ nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
125
+ nn.BatchNorm3d(planes),
126
+ nn.ReLU(inplace=True)
127
+ )
128
+ # Second kernel
129
+ self.conv2 = nn.Sequential(
130
+ conv_builder(planes, planes, midplanes, stride),
131
+ nn.BatchNorm3d(planes),
132
+ nn.ReLU(inplace=True)
133
+ )
134
+
135
+ # 1x1x1
136
+ self.conv3 = nn.Sequential(
137
+ nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
138
+ nn.BatchNorm3d(planes * self.expansion)
139
+ )
140
+ self.relu = nn.ReLU(inplace=True)
141
+ self.downsample = downsample
142
+ self.stride = stride
143
+
144
+ def forward(self, x):
145
+ residual = x
146
+
147
+ out = self.conv1(x)
148
+ out = self.conv2(out)
149
+ out = self.conv3(out)
150
+
151
+ if self.downsample is not None:
152
+ residual = self.downsample(x)
153
+
154
+ out += residual
155
+ out = self.relu(out)
156
+
157
+ return out
158
+
159
+
160
+ class BasicStem(nn.Sequential):
161
+ """The default conv-batchnorm-relu stem
162
+ """
163
+ def __init__(self):
164
+ super(BasicStem, self).__init__(
165
+ nn.Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2),
166
+ padding=(3, 3, 3), bias=False),
167
+ nn.BatchNorm3d(64),
168
+ nn.ReLU(inplace=True))
169
+
170
+
171
+ class R2Plus1dStem(nn.Sequential):
172
+ """R(2+1)D stem is different than the default one as it uses separated 3D convolution
173
+ """
174
+ def __init__(self):
175
+ super(R2Plus1dStem, self).__init__(
176
+ nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
177
+ stride=(1, 2, 2), padding=(0, 3, 3),
178
+ bias=False),
179
+ nn.BatchNorm3d(45),
180
+ nn.ReLU(inplace=True),
181
+ nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
182
+ stride=(1, 1, 1), padding=(1, 0, 0),
183
+ bias=False),
184
+ nn.BatchNorm3d(64),
185
+ nn.ReLU(inplace=True))
186
+
187
+
188
+ class VideoResNet(nn.Module):
189
+
190
+ def __init__(self, block, conv_makers, layers,
191
+ stem, num_classes=16,
192
+ zero_init_residual=False):
193
+ """Generic resnet video generator.
194
+ Args:
195
+ block (nn.Module): resnet building block
196
+ conv_makers (list(functions)): generator function for each layer
197
+ layers (List[int]): number of blocks per layer
198
+ stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
199
+ num_classes (int, optional): Dimension of the final FC layer. Defaults to 16.
200
+ zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
201
+ """
202
+ super(VideoResNet, self).__init__()
203
+ self.inplanes = 64
204
+
205
+ self.stem = stem()
206
+
207
+ self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
208
+ self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
209
+ self.layer3 = self._make_layer(block, conv_makers[2], 192, layers[2], stride=2)
210
+ self.layer4 = self._make_layer(block, conv_makers[3], 256, layers[3], stride=2)
211
+
212
+ self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
213
+ self.fc = nn.Linear(256 * block.expansion, num_classes)
214
+
215
+ # init weights
216
+ self._initialize_weights()
217
+
218
+ if zero_init_residual:
219
+ for m in self.modules():
220
+ if isinstance(m, Bottleneck):
221
+ nn.init.constant_(m.bn3.weight, 0)
222
+
223
+ def forward(self, x):
224
+ x = self.stem(x)
225
+
226
+ x = self.layer1(x)
227
+ x = self.layer2(x)
228
+ x = self.layer3(x)
229
+ x = self.layer4(x)
230
+
231
+ x = self.avgpool(x)
232
+ # Flatten the layer to fc
233
+ x = x.flatten(1)
234
+ x = self.fc(x)
235
+
236
+ return x
237
+
238
+ def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
239
+ downsample = None
240
+
241
+ if stride != 1 or self.inplanes != planes * block.expansion:
242
+ ds_stride = conv_builder.get_downsample_stride(stride)
243
+ downsample = nn.Sequential(
244
+ nn.Conv3d(self.inplanes, planes * block.expansion,
245
+ kernel_size=1, stride=ds_stride, bias=False),
246
+ nn.BatchNorm3d(planes * block.expansion)
247
+ )
248
+ layers = []
249
+ layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))
250
+
251
+ self.inplanes = planes * block.expansion
252
+ for i in range(1, blocks):
253
+ layers.append(block(self.inplanes, planes, conv_builder))
254
+
255
+ return nn.Sequential(*layers)
256
+
257
+ def _initialize_weights(self):
258
+ for m in self.modules():
259
+ if isinstance(m, nn.Conv3d):
260
+ nn.init.kaiming_normal_(m.weight, mode='fan_out',
261
+ nonlinearity='relu')
262
+ if m.bias is not None:
263
+ nn.init.constant_(m.bias, 0)
264
+ elif isinstance(m, nn.BatchNorm3d):
265
+ nn.init.constant_(m.weight, 1)
266
+ nn.init.constant_(m.bias, 0)
267
+ elif isinstance(m, nn.Linear):
268
+ nn.init.normal_(m.weight, 0, 0.01)
269
+ nn.init.constant_(m.bias, 0)
270
+
271
+
272
+ def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
273
+ model = VideoResNet(**kwargs)
274
+
275
+ return model
276
+
277
+
278
+ def r3d_18(pretrained=False, progress=True, **kwargs):
279
+ """Construct 18 layer Resnet3D model as in
280
+ https://arxiv.org/abs/1711.11248
281
+ Args:
282
+ pretrained (bool): If True, returns a model pre-trained on Kinetics-400
283
+ progress (bool): If True, displays a progress bar of the download to stderr
284
+ Returns:
285
+ nn.Module: R3D-18 network
286
+ """
287
+
288
+ return _video_resnet('r3d_18',
289
+ pretrained, progress,
290
+ block=BasicBlock,
291
+ conv_makers=[Conv3DSimple] * 4,
292
+ layers=[2, 2, 2, 2],
293
+ stem=BasicStem, **kwargs)
294
+
295
+
296
+ def mc3_18(pretrained=False, progress=True, **kwargs):
297
+ """Constructor for 18 layer Mixed Convolution network as in
298
+ https://arxiv.org/abs/1711.11248
299
+ Args:
300
+ pretrained (bool): If True, returns a model pre-trained on Kinetics-400
301
+ progress (bool): If True, displays a progress bar of the download to stderr
302
+ Returns:
303
+ nn.Module: MC3 Network definition
304
+ """
305
+ return _video_resnet('mc3_18',
306
+ pretrained, progress,
307
+ block=BasicBlock,
308
+ conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,
309
+ layers=[2, 2, 2, 2],
310
+ stem=BasicStem, **kwargs)
311
+
312
+
313
+ def r2plus1d_18(pretrained=False, progress=True, **kwargs):
314
+ """Constructor for the 18 layer deep R(2+1)D network as in
315
+ https://arxiv.org/abs/1711.11248
316
+ Args:
317
+ pretrained (bool): If True, returns a model pre-trained on Kinetics-400
318
+ progress (bool): If True, displays a progress bar of the download to stderr
319
+ Returns:
320
+ nn.Module: R(2+1)D-18 network
321
+ """
322
+ return _video_resnet('r2plus1d_18',
323
+ pretrained, progress,
324
+ block=BasicBlock,
325
+ conv_makers=[Conv2Plus1D] * 4,
326
+ layers=[2, 2, 2, 2],
327
+ stem=R2Plus1dStem, **kwargs)
328
+
329
+
330
+ if __name__ == '__main__':
331
+
332
+ import torch
333
+
334
+ net = r3d_18().to(0)
335
+ x = torch.zeros(3, 1, 182, 218, 182).to(0)
336
+
337
+ print(net(x).shape)
338
+ print(net)
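A CPU-friendly check (editorial, not part of the committed file) of the factory functions above; the input size is arbitrary because the network ends in adaptive average pooling, and the shapes here are only illustrative.

    import torch
    net = mc3_18()                               # mixed 3D/(2+1)D-style 18-layer variant
    out = net(torch.zeros(1, 1, 96, 96, 96))     # single-channel volume
    print(out.shape)                             # torch.Size([1, 16]) with the default num_classes=16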
adrd/nn/resnet3d.py ADDED
@@ -0,0 +1,256 @@
1
+ """
2
+ Simplified from torchvision.models.video.r3d_18. The citation information is
3
+ shown below.
4
+
5
+ @article{DBLP:journals/corr/abs-1711-11248,
6
+ author = {Du Tran and
7
+ Heng Wang and
8
+ Lorenzo Torresani and
9
+ Jamie Ray and
10
+ Yann LeCun and
11
+ Manohar Paluri},
12
+ title = {A Closer Look at Spatiotemporal Convolutions for Action Recognition},
13
+ journal = {CoRR},
14
+ volume = {abs/1711.11248},
15
+ year = {2017},
16
+ url = {http://arxiv.org/abs/1711.11248},
17
+ archivePrefix = {arXiv},
18
+ eprint = {1711.11248},
19
+ timestamp = {Mon, 13 Aug 2018 16:46:39 +0200},
20
+ biburl = {https://dblp.org/rec/journals/corr/abs-1711-11248.bib},
21
+ bibsource = {dblp computer science bibliography, https://dblp.org}
22
+ }
23
+ """
24
+
25
+ import torch.nn as nn
26
+
27
+
28
+ class Conv3DSimple(nn.Conv3d):
29
+ def __init__(self,
30
+ in_planes,
31
+ out_planes,
32
+ midplanes=None,
33
+ stride=1,
34
+ padding=1):
35
+
36
+ super().__init__(
37
+ in_channels=in_planes,
38
+ out_channels=out_planes,
39
+ kernel_size=(3, 3, 3),
40
+ stride=stride,
41
+ padding=padding,
42
+ bias=False)
43
+
44
+ @staticmethod
45
+ def get_downsample_stride(stride):
46
+ return stride, stride, stride
47
+
48
+
49
+ class BasicBlock(nn.Module):
50
+
51
+ expansion = 1
52
+
53
+ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
54
+ midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
55
+
56
+ super(BasicBlock, self).__init__()
57
+ self.conv1 = nn.Sequential(
58
+ conv_builder(inplanes, planes, midplanes, stride),
59
+ nn.BatchNorm3d(planes),
60
+ nn.ReLU(inplace=True)
61
+ )
62
+ self.conv2 = nn.Sequential(
63
+ conv_builder(planes, planes, midplanes),
64
+ nn.BatchNorm3d(planes)
65
+ )
66
+ self.relu = nn.ReLU(inplace=True)
67
+ self.downsample = downsample
68
+ self.stride = stride
69
+
70
+ def forward(self, x):
71
+ residual = x
72
+
73
+ out = self.conv1(x)
74
+ out = self.conv2(out)
75
+ if self.downsample is not None:
76
+ residual = self.downsample(x)
77
+
78
+ out += residual
79
+ out = self.relu(out)
80
+
81
+ return out
82
+
83
+
84
+ class Bottleneck(nn.Module):
85
+ expansion = 4
86
+
87
+ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
88
+
89
+ super(Bottleneck, self).__init__()
90
+ midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
91
+
92
+ # 1x1x1
93
+ self.conv1 = nn.Sequential(
94
+ nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
95
+ nn.BatchNorm3d(planes),
96
+ nn.ReLU(inplace=True)
97
+ )
98
+ # Second kernel
99
+ self.conv2 = nn.Sequential(
100
+ conv_builder(planes, planes, midplanes, stride),
101
+ nn.BatchNorm3d(planes),
102
+ nn.ReLU(inplace=True)
103
+ )
104
+
105
+ # 1x1x1
106
+ self.conv3 = nn.Sequential(
107
+ nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
108
+ nn.BatchNorm3d(planes * self.expansion)
109
+ )
110
+ self.relu = nn.ReLU(inplace=True)
111
+ self.downsample = downsample
112
+ self.stride = stride
113
+
114
+ def forward(self, x):
115
+ residual = x
116
+
117
+ out = self.conv1(x)
118
+ out = self.conv2(out)
119
+ out = self.conv3(out)
120
+
121
+ if self.downsample is not None:
122
+ residual = self.downsample(x)
123
+
124
+ out += residual
125
+ out = self.relu(out)
126
+
127
+ return out
128
+
129
+
130
+ class BasicStem(nn.Sequential):
131
+ """The default conv-batchnorm-relu stem
132
+ """
133
+ def __init__(self):
134
+ super(BasicStem, self).__init__(
135
+ nn.Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2),
136
+ padding=(3, 3, 3), bias=False),
137
+ nn.BatchNorm3d(64),
138
+ nn.ReLU(inplace=True))
139
+
140
+
141
+ class VideoResNet(nn.Module):
142
+
143
+ def __init__(self, block, conv_makers, layers,
144
+ stem, num_classes=2,
145
+ zero_init_residual=False):
146
+ """Generic resnet video generator.
147
+ Args:
148
+ block (nn.Module): resnet building block
149
+ conv_makers (list(functions)): generator function for each layer
150
+ layers (List[int]): number of blocks per layer
151
+ stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
152
+ num_classes (int, optional): Dimension of the final FC layer. Defaults to 2.
153
+ zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
154
+ """
155
+ super(VideoResNet, self).__init__()
156
+ self.inplanes = 64
157
+
158
+ self.stem = stem()
159
+
160
+ self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
161
+ self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
162
+ self.layer3 = self._make_layer(block, conv_makers[2], 192, layers[2], stride=2)
163
+ self.layer4 = self._make_layer(block, conv_makers[3], 256, layers[3], stride=2)
164
+
165
+ self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
166
+ # self.fc = nn.Linear(256 * block.expansion, num_classes)
167
+
168
+ # init weights
169
+ self._initialize_weights()
170
+
171
+ if zero_init_residual:
172
+ for m in self.modules():
173
+ if isinstance(m, Bottleneck):
174
+ nn.init.constant_(m.bn3.weight, 0)
175
+
176
+ def forward(self, x):
177
+ x = self.stem(x)
178
+
179
+ x = self.layer1(x)
180
+ x = self.layer2(x)
181
+ x = self.layer3(x)
182
+ x = self.layer4(x)
183
+
184
+ x = self.avgpool(x)
185
+ # Flatten the layer to fc
186
+ x = x.flatten(1)
187
+ # x = self.fc(x)
188
+
189
+ return x
190
+
191
+ def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
192
+ downsample = None
193
+
194
+ if stride != 1 or self.inplanes != planes * block.expansion:
195
+ ds_stride = conv_builder.get_downsample_stride(stride)
196
+ downsample = nn.Sequential(
197
+ nn.Conv3d(self.inplanes, planes * block.expansion,
198
+ kernel_size=1, stride=ds_stride, bias=False),
199
+ nn.BatchNorm3d(planes * block.expansion)
200
+ )
201
+ layers = []
202
+ layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))
203
+
204
+ self.inplanes = planes * block.expansion
205
+ for i in range(1, blocks):
206
+ layers.append(block(self.inplanes, planes, conv_builder))
207
+
208
+ return nn.Sequential(*layers)
209
+
210
+ def _initialize_weights(self):
211
+ for m in self.modules():
212
+ if isinstance(m, nn.Conv3d):
213
+ nn.init.kaiming_normal_(m.weight, mode='fan_out',
214
+ nonlinearity='relu')
215
+ if m.bias is not None:
216
+ nn.init.constant_(m.bias, 0)
217
+ elif isinstance(m, nn.BatchNorm3d):
218
+ nn.init.constant_(m.weight, 1)
219
+ nn.init.constant_(m.bias, 0)
220
+ elif isinstance(m, nn.Linear):
221
+ nn.init.normal_(m.weight, 0, 0.01)
222
+ nn.init.constant_(m.bias, 0)
223
+
224
+
225
+ def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
226
+ model = VideoResNet(**kwargs)
227
+
228
+ return model
229
+
230
+
231
+ def r3d_18(pretrained=False, progress=True, **kwargs):
232
+ """Construct 18 layer Resnet3D model as in
233
+ https://arxiv.org/abs/1711.11248
234
+ Args:
235
+ pretrained (bool): If True, returns a model pre-trained on Kinetics-400
236
+ progress (bool): If True, displays a progress bar of the download to stderr
237
+ Returns:
238
+ nn.Module: R3D-18 network
239
+ """
240
+
241
+ return _video_resnet('r3d_18',
242
+ pretrained, progress,
243
+ block=BasicBlock,
244
+ conv_makers=[Conv3DSimple] * 4,
245
+ layers=[2, 2, 2, 2],
246
+ stem=BasicStem, **kwargs)
247
+
248
+
249
+ if __name__ == '__main__':
250
+ """ ... """
251
+ import torch
252
+
253
+ net = r3d_18().to('cuda:1')
254
+ x = torch.zeros(8, 1, 182, 218, 182).to('cuda:1')
255
+
256
+ print(net(x).shape)
adrd/nn/resnet_img_model.py ADDED
@@ -0,0 +1,81 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import sys
4
+ from icecream import ic
5
+ # sys.path.append('/home/skowshik/ADRD_repo/adrd_tool/adrd/')
6
+ from .net_resnet3d import r3d_18
7
+ # from dev.data.dataset_csv import CSVDataset
8
+
9
+
10
+ class ResNetModel(nn.Module):
11
+ ''' ... '''
12
+ def __init__(
13
+ self,
14
+ tgt_modalities,
15
+ mri_feature = 'img_MRI_T1',
16
+ ):
17
+ ''' ... '''
18
+ super().__init__()
19
+
20
+ self.mri_feature = mri_feature
21
+
22
+ self.img_net_ = r3d_18()
23
+
24
+ # self.modules_emb_src = nn.Sequential(
25
+ # nn.BatchNorm1d(9),
26
+ # nn.Linear(9, d_model)
27
+ # )
28
+
29
+ # classifiers (binary only)
30
+ self.modules_cls = nn.ModuleDict()
31
+ for k, info in tgt_modalities.items():
32
+ if info['type'] == 'categorical' and info['num_categories'] == 2:
33
+ # categorical
34
+ self.modules_cls[k] = nn.Linear(64, 1)
35
+
36
+ else:
37
+ # unrecognized
38
+ raise ValueError
39
+
40
+ def forward(self, x):
41
+ ''' ... '''
42
+ tgt_iter = self.modules_cls.keys()
43
+
44
+ img_x_batch = x[self.mri_feature]
45
+ img_out = self.img_net_(img_x_batch)
46
+
47
+ # ic(img_out.shape)
48
+
49
+ # run linear classifiers
50
+ out = [self.modules_cls[k](img_out).squeeze(1) for i, k in enumerate(tgt_iter)]
51
+ out = torch.stack(out, dim=1)
52
+
53
+ # ic(out.shape)
54
+
55
+ # out to dict
56
+ out = {k: out[:, i] for i, k in enumerate(tgt_iter)}
57
+
58
+ return out
59
+
60
+
61
+ if __name__ == '__main__':
62
+ ''' for testing purpose only '''
63
+ # import torch
64
+ # import numpy as np
65
+
66
+ # seed = 0
67
+ # print('Loading training dataset ... ')
68
+ # dat_trn = CSVDataset(mode=0, split=[1, 700], seed=seed)
69
+ # print(len(dat_trn))
70
+ # tgt_modalities = dat_trn.label_modalities
71
+ # net = ResNetModel(tgt_modalities).to('cuda')
72
+ # x = dat_trn.features
73
+ # x = {k: torch.as_tensor(np.array([x[i][k] for i in range(len(x))])).to('cuda') for k in x[0]}
74
+ # ic(x)
75
+
76
+
77
+ # # print(net(x).shape)
78
+ # print(net(x))
79
+
80
+
81
+
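A minimal sketch (editorial, not part of the committed file) of the tgt_modalities format the ResNetModel constructor above expects, inferred from its categorical/num_categories check; the label names are hypothetical, and the forward pass expects the MRI tensor under the mri_feature key.

    tgt_modalities = {k: {'type': 'categorical', 'num_categories': 2} for k in ('NC', 'MCI', 'DE')}
    net = ResNetModel(tgt_modalities, mri_feature='img_MRI_T1')   # one binary linear head per label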
adrd/nn/selfattention.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+ from monai.utils import optional_import
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+
16
+ Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
17
+
18
+
19
+ class SABlock(nn.Module):
20
+ """
21
+ A self-attention block, based on: "Dosovitskiy et al.,
22
+ An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
23
+ """
24
+
25
+ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None:
26
+ """
27
+ Args:
28
+ hidden_size: dimension of hidden layer.
29
+ num_heads: number of attention heads.
30
+ dropout_rate: fraction of the input units to drop.
31
+ qkv_bias: bias term for the qkv linear layer.
32
+
33
+ """
34
+
35
+ super().__init__()
36
+
37
+ if not (0 <= dropout_rate <= 1):
38
+ raise ValueError("dropout_rate should be between 0 and 1.")
39
+
40
+ if hidden_size % num_heads != 0:
41
+ raise ValueError("hidden size should be divisible by num_heads.")
42
+
43
+ self.num_heads = num_heads
44
+ self.out_proj = nn.Linear(hidden_size, hidden_size)
45
+ self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
46
+ self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
47
+ self.out_rearrange = Rearrange("b h l d -> b l (h d)")
48
+ self.drop_output = nn.Dropout(dropout_rate)
49
+ self.drop_weights = nn.Dropout(dropout_rate)
50
+ self.head_dim = hidden_size // num_heads
51
+ self.scale = self.head_dim**-0.5
52
+
53
+ def forward(self, x):
54
+ output = self.input_rearrange(self.qkv(x))
55
+ q, k, v = output[0], output[1], output[2]
56
+ att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
57
+ att_mat = self.drop_weights(att_mat)
58
+ x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
59
+ x = self.out_rearrange(x)
60
+ x = self.out_proj(x)
61
+ x = self.drop_output(x)
62
+ return x, att_mat
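A small usage sketch (editorial, not part of the committed file) of the SABlock above; shapes are hypothetical and einops must be installed for the Rearrange import to resolve.

    import torch
    block = SABlock(hidden_size=128, num_heads=8, dropout_rate=0.1)
    x = torch.rand(2, 16, 128)     # (batch, sequence, hidden)
    y, att = block(x)              # y: (2, 16, 128); att: (2, 8, 16, 16) attention weights per head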
adrd/nn/transformer.py ADDED
@@ -0,0 +1,268 @@
1
+ import torch
2
+ import numpy as np
3
+ from .. import nn
4
+ # from ..nn import ImagingModelWrapper
5
+ from .net_resnet3d import r3d_18
6
+ from typing import Any, Type
7
+ import math
8
+ Tensor = Type[torch.Tensor]
9
+ from icecream import ic
10
+ ic.disable()
11
+
12
+ class Transformer(torch.nn.Module):
13
+ ''' ... '''
14
+ def __init__(self,
15
+ src_modalities: dict[str, dict[str, Any]],
16
+ tgt_modalities: dict[str, dict[str, Any]],
17
+ d_model: int,
18
+ nhead: int,
19
+ num_encoder_layers: int = 1,
20
+ num_decoder_layers: int = 1,
21
+ device: str = 'cpu',
22
+ cuda_devices: list = [3],
23
+ img_net: str | None = None,
24
+ layers: int = 3,
25
+ img_size: int | None = 128,
26
+ patch_size: int | None = 16,
27
+ imgnet_ckpt: str | None = None,
28
+ train_imgnet: bool = False,
29
+ fusion_stage: str = 'middle',
30
+ ) -> None:
31
+ ''' ... '''
32
+ super().__init__()
33
+
34
+ self.d_model = d_model
35
+ self.nhead = nhead
36
+ self.num_encoder_layers = num_encoder_layers
37
+ self.num_decoder_layers = num_decoder_layers
38
+ self.img_net = img_net
39
+ self.img_size = img_size
40
+ self.patch_size = patch_size
41
+ self.imgnet_ckpt = imgnet_ckpt
42
+ self.train_imgnet = train_imgnet
43
+ self.layers = layers
44
+ self.src_modalities = src_modalities
45
+ self.tgt_modalities = tgt_modalities
46
+ self.device = device
47
+ self.fusion_stage = fusion_stage
48
+
49
+ # embedding modules for source
50
+
51
+ self.modules_emb_src = torch.nn.ModuleDict()
52
+ print('Downsample layers: ', self.layers)
53
+ self.img_model = nn.ImagingModelWrapper(arch=self.img_net, img_size=self.img_size, patch_size=self.patch_size, ckpt_path=self.imgnet_ckpt, train_backbone=self.train_imgnet, layers=self.layers, out_dim=self.d_model, device=self.device, fusion_stage=self.fusion_stage)
54
+
55
+ for k, info in src_modalities.items():
56
+ # ic(k)
57
+ # for key, val in info.items():
58
+ # ic(key, val)
59
+ if info['type'] == 'categorical':
60
+ self.modules_emb_src[k] = torch.nn.Embedding(info['num_categories'], d_model)
61
+ elif info['type'] == 'numerical':
62
+ self.modules_emb_src[k] = torch.nn.Sequential(
63
+ torch.nn.BatchNorm1d(info['shape'][0]),
64
+ torch.nn.Linear(info['shape'][0], d_model)
65
+ )
66
+ elif info['type'] == 'imaging':
67
+ # print(info['shape'], info['img_shape'])
68
+ if self.img_net:
69
+ self.modules_emb_src[k] = self.img_model
70
+
71
+ else:
72
+ # unrecognized
73
+ raise ValueError('{} is an unrecognized data modality'.format(k))
74
+
75
+ # positional encoding
76
+ self.pe = PositionalEncoding(d_model)
77
+
78
+ # auxiliary embedding vectors for targets
79
+ self.emb_aux = torch.nn.Parameter(
80
+ torch.zeros(len(tgt_modalities), 1, d_model),
81
+ requires_grad = True,
82
+ )
83
+
84
+ # transformer
85
+ enc = torch.nn.TransformerEncoderLayer(
86
+ self.d_model, self.nhead,
87
+ dim_feedforward = self.d_model,
88
+ activation = 'gelu',
89
+ dropout = 0.3,
90
+ )
91
+ self.transformer = torch.nn.TransformerEncoder(enc, self.num_encoder_layers)
92
+
93
+
94
+ # classifiers (binary only)
95
+ self.modules_cls = torch.nn.ModuleDict()
96
+ for k, info in tgt_modalities.items():
97
+ if info['type'] == 'categorical' and info['num_categories'] == 2:
98
+ self.modules_cls[k] = torch.nn.Linear(d_model, 1)
99
+ else:
100
+ # unrecognized
101
+ raise ValueError
102
+
103
+ # for n,p in self.named_parameters():
104
+ # print(n, p.requires_grad)
105
+
106
+ def forward(self,
107
+ x: dict[str, Tensor],
108
+ mask: dict[str, Tensor],
109
+ # x_img: dict[str, Tensor] | Any = None,
110
+ skip_embedding: dict[str, bool] | None = None,
111
+ return_out_emb: bool = False,
112
+ ) -> dict[str, Tensor]:
113
+ """ ... """
114
+
115
+ out_emb = self.forward_emb(x, mask, skip_embedding)
116
+ if self.fusion_stage == "late":
117
+ out_emb = {k: v for k,v in out_emb.items() if "img_MRI" not in k}
118
+ img_out_emb = {k: v for k,v in out_emb.items() if "img_MRI" in k}
119
+ # for k,v in out_emb.items():
120
+ # print(k, v.size())
121
+ mask_nonimg = {k: v for k,v in mask.items() if "img_MRI" not in k}
122
+ out_trf = self.forward_trf(out_emb, mask_nonimg) # (8,128) + (8,50,128)
123
+ # print("out_trf: ", out_trf.size())
124
+ out_trf = torch.concatenate()
125
+ else:
126
+ out_trf = self.forward_trf(out_emb, mask)
127
+
128
+ out_cls = self.forward_cls(out_trf)
129
+
130
+ if return_out_emb:
131
+ return out_emb, out_cls
132
+ return out_cls
133
+
134
+ def forward_emb(self,
135
+ x: dict[str, Tensor],
136
+ mask: dict[str, Tensor],
137
+ skip_embedding: dict[str, bool] | None = None,
138
+ # x_img: dict[str, Tensor] | Any = None,
139
+ ) -> dict[str, Tensor]:
140
+ """ ... """
141
+ # print("-------forward_emb--------")
142
+ out_emb = dict()
143
+ for k in self.modules_emb_src.keys():
144
+ if skip_embedding is None or k not in skip_embedding or not skip_embedding[k]:
145
+ if "img_MRI" in k:
146
+ # print("img_MRI in ", k)
147
+ if torch.all(mask[k]):
148
+ if "swinunetr" in self.img_net.lower() and self.fusion_stage == 'late':
149
+ out_emb[k] = torch.zeros((1,768,4,4,4))
150
+ else:
151
+ if 'cuda' in self.device:
152
+ device = x[k].device
153
+ # print(device)
154
+ else:
155
+ device = self.device
156
+ out_emb[k] = torch.zeros((mask[k].shape[0], self.d_model)).to(device, non_blocking=True)
157
+ # print("mask is True, out_emb[k]: ", out_emb[k].size())
158
+ else:
159
+ # print("calling modules_emb_src...")
160
+ out_emb[k] = self.modules_emb_src[k](x[k])
161
+ # print("mask is False, out_emb[k]: ", out_emb[k].size())
162
+
163
+ else:
164
+ out_emb[k] = self.modules_emb_src[k](x[k])
165
+
166
+ # out_emb[k] = self.modules_emb_src[k](x[k])
167
+ else:
168
+ out_emb[k] = x[k]
169
+ return out_emb
170
+
171
+ def forward_trf(self,
172
+ out_emb: dict[str, Tensor],
173
+ mask: dict[str, Tensor],
174
+ ) -> dict[str, Tensor]:
175
+ """ ... """
176
+ # print('-----------forward_trf----------')
177
+ N = len(next(iter(out_emb.values()))) # batch size
178
+ S = len(self.modules_emb_src) # number of sources
179
+ T = len(self.modules_cls) # number of targets
180
+ if self.fusion_stage == 'late':
181
+ src_iter = [k for k in self.modules_emb_src.keys() if "img_MRI" not in k]
182
+ S = len(src_iter) # number of sources
183
+
184
+ else:
185
+ src_iter = self.modules_emb_src.keys()
186
+ tgt_iter = self.modules_cls.keys()
187
+
188
+ emb_src = torch.stack([o for o in out_emb.values()], dim=0)
189
+ # print('emb_src: ', emb_src.size())
190
+
191
+ self.pe.index = -1
192
+ emb_src = self.pe(emb_src)
193
+ # print('emb_src + pe: ', emb_src.size())
194
+
195
+ # target embedding
196
+ # print('emb_aux: ', self.emb_aux.size())
197
+ emb_tgt = self.emb_aux.repeat(1, N, 1)
198
+ # print('emb_tgt: ', emb_tgt.size())
199
+
200
+ # concatenate source embeddings and target embeddings
201
+ emb_all = torch.concatenate((emb_tgt, emb_src), dim=0)
202
+
203
+ # combine masks
204
+ mask_src = [mask[k] for k in src_iter]
205
+ mask_src = torch.stack(mask_src, dim=1)
206
+
207
+ # target masks
208
+ mask_tgt = torch.zeros((N, T), dtype=torch.bool, device=self.emb_aux.device)
209
+
210
+ # concatenate source masks and target masks
211
+ mask_all = torch.concatenate((mask_tgt, mask_src), dim=1)
212
+
213
+ # repeat mask_all to fit transformer
214
+ mask_all = mask_all.unsqueeze(1).expand(-1, S + T, -1).repeat(self.nhead, 1, 1)
215
+
216
+ # run transformer
217
+ out_trf = self.transformer(
218
+ src = emb_all,
219
+ mask = mask_all,
220
+ )[0]
221
+ # print('out_trf: ', out_trf.size())
222
+ # out_trf = {k: out_trf[i] for i, k in enumerate(tgt_iter)}
223
+ return out_trf
224
+
225
+ def forward_cls(self,
226
+ out_trf: dict[str, Tensor],
227
+ ) -> dict[str, Tensor]:
228
+ """ ... """
229
+ tgt_iter = self.modules_cls.keys()
230
+ out_cls = {k: self.modules_cls[k](out_trf).squeeze(1) for k in tgt_iter}
231
+ return out_cls
232
+
233
+ class PositionalEncoding(torch.nn.Module):
234
+
235
+ def __init__(self,
236
+ d_model: int,
237
+ max_len: int = 512
238
+ ):
239
+ """ ... """
240
+ super().__init__()
241
+ position = torch.arange(max_len).unsqueeze(1)
242
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
243
+ pe = torch.zeros(max_len, 1, d_model)
244
+ pe[:, 0, 0::2] = torch.sin(position * div_term)
245
+ pe[:, 0, 1::2] = torch.cos(position * div_term)
246
+ self.register_buffer('pe', pe)
247
+ self.index = -1
248
+
249
+ def forward(self, x: Tensor, pe_type: str = 'non_img') -> Tensor:
250
+ """
251
+ Arguments:
252
+ x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
253
+ """
254
+ # print('pe: ', self.pe.size())
255
+ # print('x: ', x.size())
256
+ if pe_type == 'img':
257
+ self.index += 1
258
+ return x + self.pe[self.index]
259
+ else:
260
+ self.index += 1
261
+ return x + self.pe[self.index:x.size(0)+self.index]
262
+
263
+
264
+ if __name__ == '__main__':
265
+ ''' for testing purpose only '''
266
+ pass
267
+
268
+
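A brief sketch (editorial, not part of the committed file) of the sinusoidal PositionalEncoding above: it adds pe[index : index + seq_len] to a (seq_len, batch, d_model) input and advances its internal index on every call, so the index is reset before encoding, as Transformer.forward_trf does.

    import torch
    pe = PositionalEncoding(d_model=64)
    x = torch.zeros(10, 4, 64)     # (sequence, batch, embedding)
    pe.index = -1                  # reset before a new sequence
    y = pe(x)                      # x + pe[0:10], broadcast over the batch dimension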
adrd/nn/unet.py ADDED
@@ -0,0 +1,232 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision
5
+ from torchvision import models
6
+ from torch.nn import init
7
+ import torch.nn.functional as F
8
+ from icecream import ic
9
+
10
+
11
+ class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
12
+ def _check_input_dim(self, input):
13
+
14
+ if input.dim() != 5:
15
+ raise ValueError('expected 5D input (got {}D input)'.format(input.dim()))
16
+ #super(ContBatchNorm3d, self)._check_input_dim(input)
17
+
18
+ def forward(self, input):
19
+ self._check_input_dim(input)
20
+ return F.batch_norm(
21
+ input, self.running_mean, self.running_var, self.weight, self.bias,
22
+ True, self.momentum, self.eps)
23
+
24
+
25
+ class LUConv(nn.Module):
26
+ def __init__(self, in_chan, out_chan, act):
27
+ super(LUConv, self).__init__()
28
+ self.conv1 = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1)
29
+ self.bn1 = ContBatchNorm3d(out_chan)
30
+
31
+ if act == 'relu':
32
+ self.activation = nn.ReLU(out_chan)
33
+ elif act == 'prelu':
34
+ self.activation = nn.PReLU(out_chan)
35
+ elif act == 'elu':
36
+ self.activation = nn.ELU(inplace=True)
37
+ else:
38
+ raise
39
+
40
+ def forward(self, x):
41
+ out = self.activation(self.bn1(self.conv1(x)))
42
+ return out
43
+
44
+
45
+ def _make_nConv(in_channel, depth, act, double_chnnel=False):
46
+ if double_chnnel:
47
+ layer1 = LUConv(in_channel, 32 * (2 ** (depth+1)),act)
48
+ layer2 = LUConv(32 * (2 ** (depth+1)), 32 * (2 ** (depth+1)),act)
49
+ else:
50
+ layer1 = LUConv(in_channel, 32*(2**depth),act)
51
+ layer2 = LUConv(32*(2**depth), 32*(2**depth)*2,act)
52
+
53
+ return nn.Sequential(layer1,layer2)
54
+
55
+
56
+ class DownTransition(nn.Module):
57
+ def __init__(self, in_channel,depth, act):
58
+ super(DownTransition, self).__init__()
59
+ self.ops = _make_nConv(in_channel, depth,act)
60
+ self.maxpool = nn.MaxPool3d(2)
61
+ self.current_depth = depth
62
+
63
+ def forward(self, x):
64
+ if self.current_depth == 3:
65
+ out = self.ops(x)
66
+ out_before_pool = out
67
+ else:
68
+ out_before_pool = self.ops(x)
69
+ out = self.maxpool(out_before_pool)
70
+ return out, out_before_pool
71
+
72
+ class UpTransition(nn.Module):
73
+ def __init__(self, inChans, outChans, depth,act):
74
+ super(UpTransition, self).__init__()
75
+ self.depth = depth
76
+ self.up_conv = nn.ConvTranspose3d(inChans, outChans, kernel_size=2, stride=2)
77
+ self.ops = _make_nConv(inChans+ outChans//2,depth, act, double_chnnel=True)
78
+
79
+ def forward(self, x, skip_x):
80
+ out_up_conv = self.up_conv(x)
81
+ concat = torch.cat((out_up_conv,skip_x),1)
82
+ out = self.ops(concat)
83
+ return out
84
+
85
+ class OutputTransition(nn.Module):
86
+ def __init__(self, inChans, n_labels):
87
+
88
+ super(OutputTransition, self).__init__()
89
+ self.final_conv = nn.Conv3d(inChans, n_labels, kernel_size=1)
90
+ self.sigmoid = nn.Sigmoid()
91
+
92
+ def forward(self, x):
93
+ out = self.sigmoid(self.final_conv(x))
94
+ return out
95
+
96
+ class ConvLayer(nn.Module):
97
+ def __init__(self, in_channels, out_channels, drop_rate, kernel, pooling, BN=True, relu_type='leaky'):
98
+ super().__init__()
99
+ kernel_size, kernel_stride, kernel_padding = kernel
100
+ pool_kernel, pool_stride, pool_padding = pooling
101
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, kernel_stride, kernel_padding, bias=False)
102
+ self.pooling = nn.MaxPool3d(pool_kernel, pool_stride, pool_padding)
103
+ self.BN = nn.BatchNorm3d(out_channels)
104
+ self.relu = nn.LeakyReLU(inplace=False) if relu_type=='leaky' else nn.ReLU(inplace=False)
105
+ self.dropout = nn.Dropout(drop_rate, inplace=False)
106
+
107
+ def forward(self, x):
108
+ x = self.conv(x)
109
+ x = self.pooling(x)
110
+ x = self.BN(x)
111
+ x = self.relu(x)
112
+ x = self.dropout(x)
113
+ return x
114
+
115
+ class AttentionModule(nn.Module):
116
+ def __init__(self, in_channels, out_channels, drop_rate=0.1):
117
+ super(AttentionModule, self).__init__()
118
+ self.conv = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=False)
119
+ self.attention = ConvLayer(in_channels, out_channels, drop_rate, (1, 1, 0), (1, 1, 0))
120
+
121
+ def forward(self, x, return_attention=True):
122
+ feats = self.conv(x)
123
+ att = F.softmax(self.attention(x))
124
+
125
+ out = feats * att
126
+
127
+ if return_attention:
128
+ return att, out
129
+
130
+ return out
131
+
132
+ class UNet3D(nn.Module):
133
+ # the number of convolutions in each layer corresponds
134
+ # to what is in the actual prototxt, not the intent
135
+ def __init__(self, n_class=1, act='relu', pretrained=False, input_size=(1,1,182,218,182), attention=False, drop_rate=0.1, blocks=4):
136
+ super(UNet3D, self).__init__()
137
+
138
+ self.blocks = blocks
139
+ self.down_tr64 = DownTransition(1,0,act)
140
+ self.down_tr128 = DownTransition(64,1,act)
141
+ self.down_tr256 = DownTransition(128,2,act)
142
+ self.down_tr512 = DownTransition(256,3,act)
143
+
144
+ self.up_tr256 = UpTransition(512, 512,2,act)
145
+ self.up_tr128 = UpTransition(256,256, 1,act)
146
+ self.up_tr64 = UpTransition(128,128,0,act)
147
+ self.out_tr = OutputTransition(64, 1)
148
+
149
+ self.pretrained = pretrained
150
+ self.attention = attention
151
+ if pretrained:
152
+ print("Using image pretrained model checkpoint")
153
+ weight_dir = '/home/skowshik/ADRD_repo/img_pretrained_ckpt/Genesis_Chest_CT.pt'
154
+ checkpoint = torch.load(weight_dir)
155
+ state_dict = checkpoint['state_dict']
156
+ unParalled_state_dict = {}
157
+ for key in state_dict.keys():
158
+ unParalled_state_dict[key.replace("module.", "")] = state_dict[key]
159
+ self.load_state_dict(unParalled_state_dict)
160
+ del self.up_tr256
161
+ del self.up_tr128
162
+ del self.up_tr64
163
+ del self.out_tr
164
+
165
+ if self.blocks == 5:
166
+ self.down_tr1024 = DownTransition(512,4,act)
167
+
168
+
169
+ # self.conv1 = nn.Conv3d(512, 256, 1, 1, 0, bias=False)
170
+ # self.conv2 = nn.Conv3d(256, 128, 1, 1, 0, bias=False)
171
+ # self.conv3 = nn.Conv3d(128, 64, 1, 1, 0, bias=False)
172
+
173
+ if attention:
174
+ self.attention_module = AttentionModule(1024 if self.blocks==5 else 512, n_class, drop_rate=drop_rate)
175
+ # Output.
176
+ self.avgpool = nn.AvgPool3d((6,7,6), stride=(6,6,6))
177
+
178
+ dummy_inp = torch.rand(input_size)
179
+ dummy_feats = self.forward(dummy_inp, stage='get_features')
180
+ dummy_feats = dummy_feats[0]
181
+ self.in_features = list(dummy_feats.shape)
182
+ ic(self.in_features)
183
+
184
+ self._init_weights()
185
+
186
+ def _init_weights(self):
187
+ if not self.pretrained:
188
+ for m in self.modules():
189
+ if isinstance(m, nn.Conv3d):
190
+ init.kaiming_normal_(m.weight)
191
+ elif isinstance(m, ContBatchNorm3d):
192
+ init.constant_(m.weight, 1)
193
+ init.constant_(m.bias, 0)
194
+ elif isinstance(m, nn.Linear):
195
+ init.kaiming_normal_(m.weight)
196
+ init.constant_(m.bias, 0)
197
+ elif self.attention:
198
+ for m in self.attention_module.modules():
199
+ if isinstance(m, nn.Conv3d):
200
+ init.kaiming_normal_(m.weight)
201
+ elif isinstance(m, nn.BatchNorm3d):
202
+ init.constant_(m.weight, 1)
203
+ init.constant_(m.bias, 0)
204
+ else:
205
+ pass
206
+ # Zero initialize the last batchnorm in each residual branch.
207
+ # for m in self.modules():
208
+ # if isinstance(m, BottleneckBlock):
209
+ # init.constant_(m.out_conv.bn.weight, 0)
210
+
211
+ def forward(self, x, stage='normal', attention=False):
212
+ ic('backbone forward')
213
+ self.out64, self.skip_out64 = self.down_tr64(x)
214
+ self.out128,self.skip_out128 = self.down_tr128(self.out64)
215
+ self.out256,self.skip_out256 = self.down_tr256(self.out128)
216
+ self.out512,self.skip_out512 = self.down_tr512(self.out256)
217
+ if self.blocks == 5:
218
+ self.out1024,self.skip_out1024 = self.down_tr1024(self.out512)
219
+ ic(self.out1024.shape)
220
+ # self.out = self.conv1(self.out512)
221
+ # self.out = self.conv2(self.out)
222
+ # self.out = self.conv3(self.out)
223
+ # self.out = self.conv(self.out)
224
+ ic(hasattr(self, 'attention_module'))
225
+ if hasattr(self, 'attention_module'):
226
+ att, feats = self.attention_module(self.out1024 if self.blocks==5 else self.out512)
227
+ else:
228
+ feats = self.out1024 if self.blocks==5 else self.out512
229
+ ic(feats.shape)
230
+ if attention:
231
+ return att, feats
232
+ return feats
adrd/nn/unet_3d.py ADDED
@@ -0,0 +1,63 @@
1
+ import sys
2
+ sys.path.append('..')
3
+ # from feature_extractor.for_image_data.backbone import CNN_GAP, ResNet3D, UNet3D
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchvision import models
7
+ import torch.nn.functional as F
8
+ # from . import UNet3D
9
+ from .unet import UNet3D
10
+ from icecream import ic
11
+
12
+
13
+ class UNet3DBase(nn.Module):
14
+ def __init__(self, n_class=1, act='relu', attention=False, pretrained=False, drop_rate=0.1, blocks=4):
15
+ super(UNet3DBase, self).__init__()
16
+ model = UNet3D(n_class=n_class, attention=attention, pretrained=pretrained, blocks=blocks)
17
+
18
+ self.blocks = blocks
19
+
20
+ self.down_tr64 = model.down_tr64
21
+ self.down_tr128 = model.down_tr128
22
+ self.down_tr256 = model.down_tr256
23
+ self.down_tr512 = model.down_tr512
24
+ if self.blocks == 5:
25
+ self.down_tr1024 = model.down_tr1024
26
+ # self.block_modules = nn.ModuleList([self.down_tr64, self.down_tr128, self.down_tr256, self.down_tr512])
27
+
28
+ self.in_features = model.in_features
29
+ # ic(attention)
30
+ if attention:
31
+ self.attention_module = model.attention_module
32
+ # self.attention_module = AttentionModule(512, n_class, drop_rate=drop_rate)
33
+ # self.avgpool = nn.AvgPool3d((6,7,6), stride=(6,6,6))
34
+
35
+ def forward(self, x, stage='normal', attention=False):
36
+ # ic('UNet3DBase forward')
37
+ self.out64, self.skip_out64 = self.down_tr64(x)
38
+ # ic(self.out64.shape, self.skip_out64.shape)
39
+ self.out128,self.skip_out128 = self.down_tr128(self.out64)
40
+ # ic(self.out128.shape, self.skip_out128.shape)
41
+ self.out256,self.skip_out256 = self.down_tr256(self.out128)
42
+ # ic(self.out256.shape, self.skip_out256.shape)
43
+ self.out512,self.skip_out512 = self.down_tr512(self.out256)
44
+ # ic(self.out512.shape, self.skip_out512.shape)
45
+ if self.blocks == 5:
46
+ self.out1024,self.skip_out1024 = self.down_tr1024(self.out512)
47
+ # ic(self.out1024.shape, self.skip_out1024.shape)
48
+ # ic(hasattr(self, 'attention_module'))
49
+ if hasattr(self, 'attention_module'):
50
+ att, feats = self.attention_module(self.out1024 if self.blocks == 5 else self.out512)
51
+ else:
52
+ feats = self.out1024 if self.blocks == 5 else self.out512
53
+ # ic(feats.shape)
54
+ if attention:
55
+ return att, feats
56
+ return feats
57
+
58
+ # self.out_up_256 = self.up_tr256(self.out512,self.skip_out256)
59
+ # self.out_up_128 = self.up_tr128(self.out_up_256, self.skip_out128)
60
+ # self.out_up_64 = self.up_tr64(self.out_up_128, self.skip_out64)
61
+ # self.out = self.out_tr(self.out_up_64)
62
+
63
+ # return self.out
adrd/nn/unet_img_model.py ADDED
@@ -0,0 +1,211 @@
1
+ from pyexpat import features
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.cuda.amp import autocast
6
+ import numpy as np
7
+ import re
8
+ from icecream import ic
9
+ import math
10
+ import torch.nn.utils.weight_norm as weightNorm
11
+
12
+ # from . import UNet3DBase
13
+ from .unet_3d import UNet3DBase
14
+
15
+
16
+ def init_weights(m):
17
+ classname = m.__class__.__name__
18
+ if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
19
+ nn.init.kaiming_uniform_(m.weight)
20
+ nn.init.zeros_(m.bias)
21
+ elif classname.find('BatchNorm') != -1:
22
+ nn.init.normal_(m.weight, 1.0, 0.02)
23
+ nn.init.zeros_(m.bias)
24
+ elif classname.find('Linear') != -1:
25
+ nn.init.xavier_normal_(m.weight)
26
+ nn.init.zeros_(m.bias)
27
+
28
+ class feat_classifier(nn.Module):
29
+ def __init__(self, class_num, bottleneck_dim=256, type="linear"):
30
+ super(feat_classifier, self).__init__()
31
+ self.type = type
32
+ # if type in ['conv', 'gap'] and len(bottleneck_dim) > 3:
33
+ # bottleneck_dim = bottleneck_dim[-3:]
34
+ ic(bottleneck_dim)
35
+ if type == 'wn':
36
+ self.layer = weightNorm(
37
+ nn.Linear(bottleneck_dim[1:], class_num), name="weight")
38
+ # self.fc.apply(init_weights)
39
+ elif type == 'gap':
40
+ if len(bottleneck_dim) > 3:
41
+ bottleneck_dim = bottleneck_dim[-3:]
42
+ self.layer = nn.AvgPool3d(bottleneck_dim, stride=(1,1,1))
43
+ elif type == 'conv':
44
+ if len(bottleneck_dim) > 3:
45
+ bottleneck_dim = bottleneck_dim[-4:]
46
+ ic(bottleneck_dim)
47
+ self.layer = nn.Conv3d(bottleneck_dim[0], class_num, kernel_size=bottleneck_dim[1:])
48
+ ic(self.layer)
49
+ else:
50
+ print('bottleneck dim: ', bottleneck_dim)
51
+ self.layer = nn.Sequential(
52
+ torch.nn.Flatten(start_dim=1, end_dim=-1),
53
+ nn.Linear(math.prod(bottleneck_dim), class_num)
54
+ )
55
+ self.layer.apply(init_weights)
56
+
57
+ def forward(self, x):
58
+ # print('=> feat_classifier forward')
59
+ # ic(x.size())
60
+ x = self.layer(x)
61
+ # ic(x.size())
62
+ if self.type in ['gap','conv']:
63
+ x = torch.squeeze(x)
64
+ if len(x.shape) < 2:
65
+ x = torch.unsqueeze(x,0)
66
+ # print('returning x: ', x.size())
67
+ return x
68
+
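A quick shape-check sketch for feat_classifier (editorial, not part of the committed file). The bottleneck shape (512, 6, 7, 6) is an assumed example of what the UNet3DBase featurizer might emit; the 'conv' head collapses the spatial dimensions into class logits.

    import torch
    from adrd.nn.unet_img_model import feat_classifier

    clf = feat_classifier(class_num=2, bottleneck_dim=(512, 6, 7, 6), type="conv")
    feats = torch.randn(4, 512, 6, 7, 6)   # (batch, channels, D, H, W)
    logits = clf(feats)                    # Conv3d with kernel = spatial size, then squeeze
    print(logits.shape)                    # torch.Size([4, 2])

Note that the 'wn' branch passes bottleneck_dim[1:] (a shape slice) straight to nn.Linear, so it appears to expect an integer feature count there; ImageModel below defaults to the 'gap' head.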
+ class ImageModel(nn.Module):
+     """
+     Empirical Risk Minimization (ERM)
+     """
+
+     def __init__(
+         self,
+         counts=None,
+         classifier='gap',
+         accum_iter=8,
+         save_emb=False,
+         # ssl,
+         num_classes=1,
+         load_img_ckpt=False,
+     ):
+         super(ImageModel, self).__init__()
+         if counts is not None:
+             if isinstance(counts[0], list):
+                 counts = np.stack(counts, axis=0).sum(axis=0)
+                 print('counts: ', counts)
+                 total = np.sum(counts)
+                 print(total/counts)
+                 self.weight = total/torch.FloatTensor(counts)
+             else:
+                 total = sum(counts)
+                 self.weight = torch.FloatTensor([total/c for c in counts])
+         else:
+             self.weight = None
+         print('weight: ', self.weight)
+         # device = torch.device(f'cuda:{args.gpu_id}' if args.gpu_id is not None else 'cpu')
+         self.criterion = nn.CrossEntropyLoss(weight=self.weight)
+         # if ssl:
+         #     # add contrastive loss
+         #     # self.ssl_criterion =
+         #     pass
+
+         self.featurizer = UNet3DBase(n_class=num_classes, attention=True, pretrained=load_img_ckpt)
+         self.classifier = feat_classifier(
+             num_classes, self.featurizer.in_features, classifier)
+
+         self.network = nn.Sequential(
+             self.featurizer, self.classifier)
+         self.accum_iter = accum_iter
+         self.acc_steps = 0
+         self.save_embedding = save_emb
+
+     def update(self, minibatches, opt, sch, scaler):
+         print('--------------def update----------------')
+         device = list(self.parameters())[0].device
+         all_x = torch.cat([data[1].to(device).float() for data in minibatches])
+         all_y = torch.cat([data[2].to(device).long() for data in minibatches])
+         print('all_x: ', all_x.size())
+         # all_p = self.predict(all_x)
+         # all_probs =
+         label_list = all_y.tolist()
+         count = float(len(label_list))
+         ic(count)
+
+         uniques = sorted(list(set(label_list)))
+         ic(uniques)
+         counts = [float(label_list.count(i)) for i in uniques]
+         ic(counts)
+
+         weights = [count / c for c in counts]
+         ic(weights)
+
+         with autocast():
+             loss = self.criterion(self.predict(all_x), all_y)
+             self.acc_steps += 1
+             print('class: ', loss.item())
+
+         scaler.scale(loss / self.accum_iter).backward()
+
+         if self.acc_steps == self.accum_iter:
+             scaler.step(opt)
+             if sch:
+                 sch.step()
+             scaler.update()
+             self.zero_grad()
+             self.acc_steps = 0
+             torch.cuda.empty_cache()
+
+         del all_x
+         del all_y
+         return {'class': loss.item()}, sch
+
+     def forward(self, *args, **kwargs):
+         return self.network(*args, **kwargs)
+
+     def predict(self, x, stage='normal', attention=False):
+         # print('network device: ', list(self.network.parameters())[0].device)
+         # print('x device: ', x.device)
+         if stage == 'get_features' or self.save_embedding:
+             feats = self.network[0](x, attention=attention)
+             output = self.network[1](feats[-1] if attention else feats)
+             return feats, output
+         else:
+             return self.network(x)
+
+     def extract_features(self, x, attention=False):
+         feats = self.network[0](x, attention=attention)
+         return feats
+
+     def load_checkpoint(self, state_dict):
+         try:
+             self.load_checkpoint_helper(state_dict)
+         except:
+             featurizer_dict = {}
+             net_dict = {}
+             for key, val in state_dict.items():
+                 if 'featurizer' in key:
+                     featurizer_dict[key] = val
+                 elif 'network' in key:
+                     net_dict[key] = val
+             self.featurizer.load_state_dict(featurizer_dict)
+             self.classifier.load_state_dict(net_dict)
+
+     def load_checkpoint_helper(self, state_dict):
+         try:
+             self.load_state_dict(state_dict)
+             print('try: loaded')
+         except RuntimeError as e:
+             print('--> except')
+             if 'Missing key(s) in state_dict:' in str(e):
+                 state_dict = {
+                     key.replace('module.', '', 1): value
+                     for key, value in state_dict.items()
+                 }
+                 state_dict = {
+                     key.replace('featurizer.', '', 1).replace('classifier.','',1): value
+                     for key, value in state_dict.items()
+                 }
+                 state_dict = {
+                     re.sub('network.[0-9].', '', key): value
+                     for key, value in state_dict.items()
+                 }
+                 try:
+                     del state_dict['criterion.weight']
+                 except:
+                     pass
+                 self.load_state_dict(state_dict)
+
+             print('except: loaded')
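A construction-and-inference sketch for ImageModel (editorial, not part of the committed file). The class counts, the volume size, and the assumption that UNet3DBase.in_features is a shape tuple compatible with the default 'gap' head are placeholders for illustration.

    import torch
    from adrd.nn.unet_img_model import ImageModel

    model = ImageModel(counts=[300, 700], num_classes=2)   # class-frequency-weighted CE loss
    model.eval()
    x = torch.randn(1, 1, 128, 128, 128)                   # (batch, channel, D, H, W) -- assumed size
    with torch.no_grad():
        logits = model.predict(x)                          # equivalent to model(x)
        feats, logits_from_feats = model.predict(x, stage='get_features')

Training goes through model.update(minibatches, opt, sch, scaler), where each minibatch element is an (id, image, label)-style tuple and scaler is a torch.cuda.amp.GradScaler; gradients are accumulated for accum_iter steps before the optimizer steps.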
adrd/nn/vitautoenc.py ADDED
@@ -0,0 +1,163 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
+ from monai.networks.layers import Conv
+ from monai.utils import ensure_tuple_rep
+
+ from typing import Sequence, Union
+ import torch
+ import torch.nn as nn
+
+ from ..nn.blocks import TransformerBlock
+ from icecream import ic
+ ic.disable()
+
+ __all__ = ["ViTAutoEnc"]
+
+
+ class ViTAutoEnc(nn.Module):
+     """
+     Vision Transformer (ViT), based on: "Dosovitskiy et al.,
+     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+
+     Modified to also give same dimension outputs as the input size of the image
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         img_size: Union[Sequence[int], int],
+         patch_size: Union[Sequence[int], int],
+         out_channels: int = 1,
+         deconv_chns: int = 16,
+         hidden_size: int = 768,
+         mlp_dim: int = 3072,
+         num_layers: int = 12,
+         num_heads: int = 12,
+         pos_embed: str = "conv",
+         dropout_rate: float = 0.0,
+         spatial_dims: int = 3,
+     ) -> None:
+         """
+         Args:
+             in_channels: dimension of input channels or the number of channels for input
+             img_size: dimension of input image.
+             patch_size: dimension of patch size.
+             hidden_size: dimension of hidden layer.
+             out_channels: number of output channels.
+             deconv_chns: number of channels for the deconvolution layers.
+             mlp_dim: dimension of feedforward layer.
+             num_layers: number of transformer blocks.
+             num_heads: number of attention heads.
+             pos_embed: position embedding layer type.
+             dropout_rate: fraction of the input units to drop.
+             spatial_dims: number of spatial dimensions.
+
+         Examples::
+
+             # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
+             # It will provide an output of same size as that of the input
+             >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), pos_embed='conv')
+
+             # for 3-channel with image size of (128,128,128), output will be same size as of input
+             >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), pos_embed='conv')
+
+         """
+
+         super().__init__()
+
+         self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
+         self.spatial_dims = spatial_dims
+         self.hidden_size = hidden_size
+
+         self.patch_embedding = PatchEmbeddingBlock(
+             in_channels=in_channels,
+             img_size=img_size,
+             patch_size=patch_size,
+             hidden_size=hidden_size,
+             num_heads=num_heads,
+             pos_embed=pos_embed,
+             dropout_rate=dropout_rate,
+             spatial_dims=self.spatial_dims,
+         )
+         self.blocks = nn.ModuleList(
+             [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)]
+         )
+         self.norm = nn.LayerNorm(hidden_size)
+
+         new_patch_size = [4] * self.spatial_dims
+         conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims]
+         # self.conv3d_transpose* is to be compatible with existing 3d model weights.
+         self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=new_patch_size, stride=new_patch_size)
+         self.conv3d_transpose_1 = conv_trans(
+             in_channels=deconv_chns, out_channels=out_channels, kernel_size=new_patch_size, stride=new_patch_size
+         )
+
+     def forward(self, x, return_emb=False, return_hiddens=False):
+         """
+         Args:
+             x: input tensor must have isotropic spatial dimensions,
+                 such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.
+         """
+         spatial_size = x.shape[2:]
+         x = self.patch_embedding(x)
+         hidden_states_out = []
+         for blk in self.blocks:
+             x = blk(x)
+             hidden_states_out.append(x)
+         x = self.norm(x)
+         x = x.transpose(1, 2)
+         if return_emb:
+             return x
+         d = [s // p for s, p in zip(spatial_size, self.patch_size)]
+         x = torch.reshape(x, [x.shape[0], x.shape[1], *d])
+         x = self.conv3d_transpose(x)
+         x = self.conv3d_transpose_1(x)
+         if return_hiddens:
+             return x, hidden_states_out
+         return x
+
+     def get_last_selfattention(self, x):
+         """
+         Args:
+             x: input tensor must have isotropic spatial dimensions,
+                 such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.
+         """
+         x = self.patch_embedding(x)
+         ic(x.size())
+         for i, blk in enumerate(self.blocks):
+             if i < len(self.blocks) - 1:
+                 x = blk(x)
+                 x.size()
+             else:
+                 return blk(x, return_attention=True)
+
+     def load(self, ckpt_path, map_location='cpu', checkpoint_key='state_dict'):
+         """
+         Args:
+             ckpt_path: path to the pretrained weights
+             map_location: device to load the checkpoint on
+         """
+         state_dict = torch.load(ckpt_path, map_location=map_location)
+         ic(state_dict['epoch'], state_dict['train_loss'])
+         if checkpoint_key in state_dict:
+             print(f"Take key {checkpoint_key} in provided checkpoint dict")
+             state_dict = state_dict[checkpoint_key]
+         # remove `module.` prefix
+         state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+         # remove `backbone.` prefix induced by multicrop wrapper
+         state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+         msg = self.load_state_dict(state_dict, strict=False)
+         print('Pretrained weights found at {} and loaded with msg: {}'.format(ckpt_path, msg))
+
+
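A forward-pass sketch for ViTAutoEnc (editorial, not part of the committed file), following the 96^3 single-channel example from the docstring. It assumes the local TransformerBlock supports return_attention=True, which get_last_selfattention relies on.

    import torch
    from adrd.nn.vitautoenc import ViTAutoEnc

    net = ViTAutoEnc(in_channels=1, patch_size=(16, 16, 16), img_size=(96, 96, 96), pos_embed='conv')
    x = torch.randn(2, 1, 96, 96, 96)
    recon = net(x)                                  # reconstruction, same spatial size as the input
    emb = net(x, return_emb=True)                   # (batch, hidden_size, num_patches) embedding
    recon2, hiddens = net(x, return_hiddens=True)   # reconstruction plus per-block hidden states
    attn = net.get_last_selfattention(x)            # attention from the final transformer block

load() additionally restores pretrained weights from a checkpoint, stripping 'module.' and 'backbone.' prefixes before calling load_state_dict(strict=False).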