ok
This view is limited to 50 files because it contains too many changes.
- adrd/__init__.py +22 -0
- adrd/__pycache__/__init__.cpython-311.pyc +0 -0
- adrd/_ds/__init__.py +0 -0
- adrd/_ds/lddl.py +71 -0
- adrd/model/__init__.py +6 -0
- adrd/model/__pycache__/__init__.cpython-311.pyc +0 -0
- adrd/model/__pycache__/adrd_model.cpython-311.pyc +0 -0
- adrd/model/__pycache__/calibration.cpython-311.pyc +0 -0
- adrd/model/__pycache__/imaging_model.cpython-311.pyc +0 -0
- adrd/model/__pycache__/train_resnet.cpython-311.pyc +0 -0
- adrd/model/adrd_model.py +976 -0
- adrd/model/calibration.py +450 -0
- adrd/model/cnn_resnet3d_with_linear_classifier.py +533 -0
- adrd/model/imaging_model.py +843 -0
- adrd/model/train_resnet.py +484 -0
- adrd/model/transformer.py +600 -0
- adrd/nn/__init__.py +12 -0
- adrd/nn/__pycache__/__init__.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/blocks.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/c3d.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/cnn_resnet3d.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/cnn_resnet3d_with_linear_classifier.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/dense_net.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/focal_loss.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/img_model_wrapper.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/net_resnet3d.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/resnet3d.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/resnet_img_model.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/selfattention.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/transformer.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/unet.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/unet_3d.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/unet_img_model.cpython-311.pyc +0 -0
- adrd/nn/__pycache__/vitautoenc.cpython-311.pyc +0 -0
- adrd/nn/blocks.py +57 -0
- adrd/nn/c3d.py +99 -0
- adrd/nn/cnn_resnet3d.py +81 -0
- adrd/nn/cnn_resnet3d_with_linear_classifier.py +56 -0
- adrd/nn/dense_net.py +211 -0
- adrd/nn/focal_loss.py +120 -0
- adrd/nn/img_model_wrapper.py +174 -0
- adrd/nn/net_resnet3d.py +338 -0
- adrd/nn/resnet3d.py +256 -0
- adrd/nn/resnet_img_model.py +81 -0
- adrd/nn/selfattention.py +62 -0
- adrd/nn/transformer.py +268 -0
- adrd/nn/unet.py +232 -0
- adrd/nn/unet_3d.py +63 -0
- adrd/nn/unet_img_model.py +211 -0
- adrd/nn/vitautoenc.py +163 -0
adrd/__init__.py
ADDED
@@ -0,0 +1,22 @@
+__version__ = '0.0.1'
+
+from . import nn
+from . import model
+
+# # load pretrained transformer
+# pretrained_transformer = model.Transformer.from_ckpt('{}/ckpt/ckpt.pt'.format(__path__[0]))
+# from . import shap_adrd
+# from .model import DynamicCalibratedClassifier
+# from .model import StaticCalibratedClassifier
+
+# load fitted transformer and calibrated wrapper
+# try:
+#     fitted_resnet3d = model.CNNResNet3DWithLinearClassifier.from_ckpt('{}/ckpt/ckpt_img_072523.pt'.format(__path__[0]))
+#     fitted_calibrated_classifier_nonimg = StaticCalibratedClassifier.from_ckpt(
+#         filepath_state_dict = '{}/ckpt/static_calibrated_classifier_073023.pkl'.format(__path__[0]),
+#         filepath_wrapped_model = '{}/ckpt/ckpt_080823.pt'.format(__path__[0]),
+#     )
+#     fitted_transformer_nonimg = fitted_calibrated_classifier_nonimg.model
+#     shap_explainer = shap_adrd.SamplingExplainer(fitted_transformer_nonimg)
+# except:
+#     print('Fail to load checkpoints.')
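As an aside, the package root only exposes the `nn` and `model` subpackages plus a version string. A minimal usage sketch, assuming the package is installed as `adrd` and that a checkpoint exists at the (hypothetical) path below, would be:

    import adrd

    print(adrd.__version__)   # '0.0.1'

    # restore a fitted model from a checkpoint; from_ckpt is defined on ADRDModel below
    model = adrd.model.ADRDModel.from_ckpt('./ckpt/ckpt.pt', device='cpu')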
adrd/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (288 Bytes).
adrd/_ds/__init__.py
ADDED
File without changes
adrd/_ds/lddl.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Self, overload
|
2 |
+
|
3 |
+
|
4 |
+
class lddl:
|
5 |
+
''' ... '''
|
6 |
+
def __init__(self) -> None:
|
7 |
+
''' ... '''
|
8 |
+
self.dat_ld: list[dict[str, Any]] = None
|
9 |
+
self.dat_dl: dict[str, list[Any]] = None
|
10 |
+
|
11 |
+
@overload
|
12 |
+
def __getitem__(self, idx: int) -> dict[str, Any]: ...
|
13 |
+
|
14 |
+
@overload
|
15 |
+
def __getitem__(self, idx: str) -> list[Any]: ...
|
16 |
+
|
17 |
+
def __getitem__(self, idx: str | int) -> list[Any] | dict[str, Any]:
|
18 |
+
''' ... '''
|
19 |
+
if isinstance(idx, str):
|
20 |
+
return self.dat_dl[idx]
|
21 |
+
elif isinstance(idx, int):
|
22 |
+
return self.dat_ld[idx]
|
23 |
+
else:
|
24 |
+
raise TypeError('Unexpected key type: {}'.format(type(idx)))
|
25 |
+
|
26 |
+
@classmethod
|
27 |
+
def from_ld(cls, dat: list[dict[str, Any]]) -> Self:
|
28 |
+
''' Construct from list of dicts. '''
|
29 |
+
obj = cls()
|
30 |
+
obj.dat_ld = dat
|
31 |
+
obj.dat_dl = {k: [dat[i][k] for i in range(len(dat))] for k in dat[0]}
|
32 |
+
return obj
|
33 |
+
|
34 |
+
@classmethod
|
35 |
+
def from_dl(cls, dat: dict[str, list[Any]]) -> Self:
|
36 |
+
''' Construct from dict of lists. '''
|
37 |
+
obj = cls()
|
38 |
+
obj.dat_ld = [dict(zip(dat, v)) for v in zip(*dat.values())]
|
39 |
+
obj.dat_dl = dat
|
40 |
+
return obj
|
41 |
+
|
42 |
+
|
43 |
+
if __name__ == '__main__':
|
44 |
+
''' for testing purpose only '''
|
45 |
+
dl = {
|
46 |
+
'a': [0, 1, 2],
|
47 |
+
'b': [3, 4, 5],
|
48 |
+
}
|
49 |
+
|
50 |
+
ld = [
|
51 |
+
{'a': 0, 'b': 1, 'c': 2},
|
52 |
+
{'a': 3, 'b': 4, 'c': 5},
|
53 |
+
]
|
54 |
+
|
55 |
+
# test constructing from ld
|
56 |
+
dat_ld = lddl.from_ld(ld)
|
57 |
+
print(dat_ld.dat_ld)
|
58 |
+
print(dat_ld.dat_dl)
|
59 |
+
|
60 |
+
# test constructing from dl
|
61 |
+
dat_dl = lddl.from_dl(dl)
|
62 |
+
print(dat_dl.dat_ld)
|
63 |
+
print(dat_dl.dat_dl)
|
64 |
+
|
65 |
+
# test __getitem__
|
66 |
+
print(dat_dl['a'])
|
67 |
+
print(dat_dl[0])
|
68 |
+
|
69 |
+
# mouse hover to check if type hints are correct
|
70 |
+
v = dat_dl['a']
|
71 |
+
v = dat_dl[0]
|
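The two constructors above keep the list-of-dicts and dict-of-lists views in sync. A standalone sketch of the same round-trip conversion idiom, independent of the class and using made-up data, is:

    records = [{'a': 0, 'b': 3}, {'a': 1, 'b': 4}, {'a': 2, 'b': 5}]    # list of dicts
    columns = {k: [r[k] for r in records] for k in records[0]}          # -> dict of lists
    back    = [dict(zip(columns, v)) for v in zip(*columns.values())]   # -> list of dicts again
    assert back == records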
adrd/model/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from .adrd_model import ADRDModel
+from .imaging_model import ImagingModel
+from .train_resnet import TrainResNet
+# from .transformer import Transformer
+from .calibration import DynamicCalibratedClassifier
+from .calibration import StaticCalibratedClassifier
adrd/model/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (512 Bytes).
adrd/model/__pycache__/adrd_model.cpython-311.pyc
ADDED
Binary file (56.5 kB).
adrd/model/__pycache__/calibration.cpython-311.pyc
ADDED
Binary file (27.8 kB).
adrd/model/__pycache__/imaging_model.cpython-311.pyc
ADDED
Binary file (46.5 kB).
adrd/model/__pycache__/train_resnet.cpython-311.pyc
ADDED
Binary file (26.1 kB).
adrd/model/adrd_model.py
ADDED
@@ -0,0 +1,976 @@
+__all__ = ['ADRDModel']
+
+import wandb
+import torch
+from torch.utils.data import DataLoader
+import numpy as np
+import tqdm
+import multiprocessing
+from sklearn.base import BaseEstimator
+from sklearn.utils.validation import check_is_fitted
+from sklearn.model_selection import train_test_split
+from scipy.special import expit
+from copy import deepcopy
+from contextlib import suppress
+from typing import Any, Self, Type
+from functools import wraps
+from tqdm import tqdm
+Tensor = Type[torch.Tensor]
+Module = Type[torch.nn.Module]
+
+# for DistributedDataParallel
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from .. import nn
+from ..nn import Transformer
+from ..utils import TransformerTrainingDataset, TransformerBalancedTrainingDataset, TransformerValidationDataset, TransformerTestingDataset, Transformer2ndOrderBalancedTrainingDataset
+from ..utils.misc import ProgressBar
+from ..utils.misc import get_metrics_multitask, print_metrics_multitask
+from ..utils.misc import convert_args_kwargs_to_kwargs
+
+
+def _manage_ctx_fit(func):
+    ''' ... '''
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # format arguments
+        kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs)
+
+        if kwargs['self']._device_ids is None:
+            return func(**kwargs)
+        else:
+            # change primary device
+            default_device = kwargs['self'].device
+            kwargs['self'].device = kwargs['self']._device_ids[0]
+            rtn = func(**kwargs)
+            kwargs['self'].to(default_device)
+            return rtn
+    return wrapper
+
+
+class ADRDModel(BaseEstimator):
+    """Primary model class for ADRD framework.
+
+    The ADRDModel encapsulates the core pipeline of the ADRD framework,
+    permitting users to train and validate with the provided data. Designed for
+    user-friendly operation, the ADRDModel is derived from
+    ``sklearn.base.BaseEstimator``, ensuring compliance with the well-established
+    API design conventions of scikit-learn.
+    """
+    def __init__(self,
+        src_modalities: dict[str, dict[str, Any]],
+        tgt_modalities: dict[str, dict[str, Any]],
+        label_fractions: dict[str, float],
+        d_model: int = 32,
+        nhead: int = 1,
+        num_encoder_layers: int = 1,
+        num_decoder_layers: int = 1,
+        num_epochs: int = 32,
+        batch_size: int = 8,
+        batch_size_multiplier: int = 1,
+        lr: float = 1e-2,
+        weight_decay: float = 0.0,
+        beta: float = 0.9999,
+        gamma: float = 2.0,
+        criterion: str | None = None,
+        device: str = 'cpu',
+        cuda_devices: list = [1],
+        img_net: str | None = None,
+        imgnet_layers: int | None = 2,
+        img_size: int | None = 128,
+        fusion_stage: str = 'middle',
+        patch_size: int | None = 16,
+        imgnet_ckpt: str | None = None,
+        train_imgnet: bool = False,
+        ckpt_path: str = './adrd_tool/adrd/dev/ckpt/ckpt.pt',
+        load_from_ckpt: bool = False,
+        save_intermediate_ckpts: bool = False,
+        data_parallel: bool = False,
+        verbose: int = 0,
+        wandb_ = 0,
+        balanced_sampling: bool = False,
+        label_distribution: dict = {},
+        ranking_loss: bool = False,
+        _device_ids: list | None = None,
+
+        _dataloader_num_workers: int = 4,
+        _amp_enabled: bool = False,
+    ) -> None:
+        """Create a new ADRD model.
+
+        :param src_modalities: _description_
+        :type src_modalities: dict[str, dict[str, Any]]
+        :param tgt_modalities: _description_
+        :type tgt_modalities: dict[str, dict[str, Any]]
+        :param label_fractions: _description_
+        :type label_fractions: dict[str, float]
+        :param d_model: _description_, defaults to 32
+        :type d_model: int, optional
+        :param nhead: number of transformer heads, defaults to 1
+        :type nhead: int, optional
+        :param num_encoder_layers: number of encoder layers, defaults to 1
+        :type num_encoder_layers: int, optional
+        :param num_decoder_layers: number of decoder layers, defaults to 1
+        :type num_decoder_layers: int, optional
+        :param num_epochs: number of training epochs, defaults to 32
+        :type num_epochs: int, optional
+        :param batch_size: batch size, defaults to 8
+        :type batch_size: int, optional
+        :param batch_size_multiplier: _description_, defaults to 1
+        :type batch_size_multiplier: int, optional
+        :param lr: learning rate, defaults to 1e-2
+        :type lr: float, optional
+        :param weight_decay: _description_, defaults to 0.0
+        :type weight_decay: float, optional
+        :param beta: _description_, defaults to 0.9999
+        :type beta: float, optional
+        :param gamma: The focusing parameter for the focal loss. Higher values of gamma make easy-to-classify examples contribute less to the loss relative to hard-to-classify examples. Must be non-negative, defaults to 2.0
+        :type gamma: float, optional
+        :param criterion: The criterion to select the best model, defaults to None
+        :type criterion: str | None, optional
+        :param device: 'cuda' or 'cpu', defaults to 'cpu'
+        :type device: str, optional
+        :param cuda_devices: A list of GPU numbers for data parallel training. The device must be set to 'cuda' and data_parallel must be set to True, defaults to [1]
+        :type cuda_devices: list, optional
+        :param img_net: _description_, defaults to None
+        :type img_net: str | None, optional
+        :param imgnet_layers: _description_, defaults to 2
+        :type imgnet_layers: int | None, optional
+        :param img_size: _description_, defaults to 128
+        :type img_size: int | None, optional
+        :param fusion_stage: _description_, defaults to 'middle'
+        :type fusion_stage: str, optional
+        :param patch_size: _description_, defaults to 16
+        :type patch_size: int | None, optional
+        :param imgnet_ckpt: _description_, defaults to None
+        :type imgnet_ckpt: str | None, optional
+        :param train_imgnet: Set to True to finetune the img_net backbone, defaults to False
+        :type train_imgnet: bool, optional
+        :param ckpt_path: The model checkpoint path, defaults to './adrd_tool/adrd/dev/ckpt/ckpt.pt'
+        :type ckpt_path: str, optional
+        :param load_from_ckpt: Set to True to load the model weights from checkpoint ckpt_path, defaults to False
+        :type load_from_ckpt: bool, optional
+        :param save_intermediate_ckpts: Set to True to save intermediate model checkpoints, defaults to False
+        :type save_intermediate_ckpts: bool, optional
+        :param data_parallel: Set to True to enable data parallel training, defaults to False
+        :type data_parallel: bool, optional
+        :param verbose: _description_, defaults to 0
+        :type verbose: int, optional
+        :param wandb_: Set to 1 to track the loss and accuracy curves on wandb, defaults to 0
+        :type wandb_: int, optional
+        :param balanced_sampling: _description_, defaults to False
+        :type balanced_sampling: bool, optional
+        :param label_distribution: _description_, defaults to {}
+        :type label_distribution: dict, optional
+        :param ranking_loss: _description_, defaults to False
+        :type ranking_loss: bool, optional
+        :param _device_ids: _description_, defaults to None
+        :type _device_ids: list | None, optional
+        :param _dataloader_num_workers: _description_, defaults to 4
+        :type _dataloader_num_workers: int, optional
+        :param _amp_enabled: _description_, defaults to False
+        :type _amp_enabled: bool, optional
+        """
+        # for multiprocessing
+        self._rank = 0
+        self._lock = None
+
+        # positional parameters
+        self.src_modalities = src_modalities
+        self.tgt_modalities = tgt_modalities
+
+        # training parameters
+        self.label_fractions = label_fractions
+        self.d_model = d_model
+        self.nhead = nhead
+        self.num_encoder_layers = num_encoder_layers
+        self.num_decoder_layers = num_decoder_layers
+        self.num_epochs = num_epochs
+        self.batch_size = batch_size
+        self.batch_size_multiplier = batch_size_multiplier
+        self.lr = lr
+        self.weight_decay = weight_decay
+        self.beta = beta
+        self.gamma = gamma
+        self.criterion = criterion
+        self.device = device
+        self.cuda_devices = cuda_devices
+        self.img_net = img_net
+        self.patch_size = patch_size
+        self.img_size = img_size
+        self.fusion_stage = fusion_stage
+        self.imgnet_ckpt = imgnet_ckpt
+        self.imgnet_layers = imgnet_layers
+        self.train_imgnet = train_imgnet
+        self.ckpt_path = ckpt_path
+        self.load_from_ckpt = load_from_ckpt
+        self.save_intermediate_ckpts = save_intermediate_ckpts
+        self.data_parallel = data_parallel
+        self.verbose = verbose
+        self.label_distribution = label_distribution
+        self.wandb_ = wandb_
+        self.balanced_sampling = balanced_sampling
+        self.ranking_loss = ranking_loss
+        self._device_ids = _device_ids
+        self._dataloader_num_workers = _dataloader_num_workers
+        self._amp_enabled = _amp_enabled
+        self.scaler = torch.cuda.amp.GradScaler()
+        # self._init_net()
+
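To make the constructor arguments concrete, a hypothetical instantiation follows; the modality dictionaries, their field names, and the criterion string are illustrative assumptions, not taken from this repository:

    src_modalities = {
        'his_SEX':  {'type': 'categorical', 'num_categories': 2},   # assumed schema
        'bat_MMSE': {'type': 'numerical', 'shape': (1,)},           # assumed schema
    }
    tgt_modalities = {'NC': {}, 'MCI': {}, 'DE': {}}
    label_fractions = {'NC': 0.4, 'MCI': 0.3, 'DE': 0.3}

    model = ADRDModel(
        src_modalities, tgt_modalities, label_fractions,
        d_model=64, nhead=2, num_epochs=16, batch_size=16,
        criterion='AUC (ROC)',   # metric name assumed to match the metrics dict
        device='cpu',
    )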
+    @_manage_ctx_fit
+    def fit(self, x_trn, x_vld, y_trn, y_vld, img_train_trans=None, img_vld_trans=None, img_mode=0) -> Self:
+    # def fit(self, x, y) -> Self:
+        ''' ... '''
+
+        # start a new wandb run to track this script
+        if self.wandb_ == 1:
+            wandb.init(
+                # set the wandb project where this run will be logged
+                project="ADRD_main",
+
+                # track hyperparameters and run metadata
+                config={
+                    "Loss": 'Focalloss',
+                    "ranking_loss": self.ranking_loss,
+                    "img architecture": self.img_net,
+                    "EMB": "ALL_SEQ",
+                    "epochs": self.num_epochs,
+                    "d_model": self.d_model,
+                    # 'positional encoding': 'Diff PE',
+                    'Balanced Sampling': self.balanced_sampling,
+                    'Shared CNN': 'Yes',
+                }
+            )
+            wandb.run.log_code(".")
+        else:
+            wandb.init(mode="disabled")
+        # for PyTorch computational efficiency
+        torch.set_num_threads(1)
+        # print(img_train_trans)
+        # initialize neural network
+        print(self.criterion)
+        print(f"Ranking loss: {self.ranking_loss}")
+        print(f"Batch size: {self.batch_size}")
+        print(f"Batch size multiplier: {self.batch_size_multiplier}")
+
+        if img_mode in [0,1,2]:
+            for k, info in self.src_modalities.items():
+                if info['type'] == 'imaging':
+                    if 'densenet' in self.img_net.lower() and 'emb' not in self.img_net.lower():
+                        info['shape'] = (1,) + self.img_size
+                        info['img_shape'] = (1,) + self.img_size
+                    elif 'emb' not in self.img_net.lower():
+                        info['shape'] = (1,) + (self.img_size,) * 3
+                        info['img_shape'] = (1,) + (self.img_size,) * 3
+                    elif 'swinunetr' in self.img_net.lower():
+                        info['shape'] = (1, 768, 4, 4, 4)
+                        info['img_shape'] = (1, 768, 4, 4, 4)
+
+
+
+        self._init_net()
+        ldr_trn, ldr_vld = self._init_dataloader(x_trn, x_vld, y_trn, y_vld, img_train_trans=img_train_trans, img_vld_trans=img_vld_trans)
+
+        # initialize optimizer and scheduler
+        if not self.load_from_ckpt:
+            self.optimizer = self._init_optimizer()
+        self.scheduler = self._init_scheduler(self.optimizer)
+
+        # gradient scaler for AMP
+        if self._amp_enabled:
+            self.scaler = torch.cuda.amp.GradScaler()
+
+        # initialize the focal losses
+        self.loss_fn = {}
+
+        for k in self.tgt_modalities:
+            if self.label_fractions[k] >= 0.3:
+                alpha = -1
+            else:
+                alpha = pow((1 - self.label_fractions[k]), 2)
+            # alpha = -1
+            self.loss_fn[k] = nn.SigmoidFocalLoss(
+                alpha = alpha,
+                gamma = self.gamma,
+                reduction = 'none'
+            )
+
+        # to record the best validation performance criterion
+        if self.criterion is not None:
+            best_crit = None
+            best_crit_AUPR = None
+
+        # progress bar for epoch loops
+        if self.verbose == 1:
+            with self._lock if self._lock is not None else suppress():
+                pbr_epoch = tqdm.tqdm(
+                    desc = 'Rank {:02d}'.format(self._rank),
+                    total = self.num_epochs,
+                    position = self._rank,
+                    ascii = True,
+                    leave = False,
+                    bar_format='{l_bar}{r_bar}'
+                )
+
+        self.skip_embedding = {}
+        for k, info in self.src_modalities.items():
+            # if info['type'] == 'imaging':
+            #     if not self.img_net:
+            #         self.skip_embedding[k] = True
+            #     else:
+            self.skip_embedding[k] = False
+
+        self.grad_list = []
+        # Define a hook function to print and store the gradient of a layer
+        def print_and_store_grad(grad):
+            self.grad_list.append(grad)
+            # print(grad)
+
+
+        # initialize the ranking loss
+        self.lambda_coeff = 0.005
+        self.margin = 0.25
+        self.margin_loss = torch.nn.MarginRankingLoss(reduction='sum', margin=self.margin)
+
+        # training loop
+        for epoch in range(self.start_epoch, self.num_epochs):
+            met_trn = self.train_one_epoch(ldr_trn, epoch)
+            met_vld = self.validate_one_epoch(ldr_vld, epoch)
+
+            print(self.ckpt_path.split('/')[-1])
+
+            # save the model if it has the best validation performance criterion by far
+            if self.criterion is None: continue
+
+            # is current criterion better than previous best?
+            curr_crit = np.mean([met_vld[i][self.criterion] for i in range(len(self.tgt_modalities))])
+            curr_crit_AUPR = np.mean([met_vld[i]["AUC (PR)"] for i in range(len(self.tgt_modalities))])
+            # AUROC
+            if best_crit is None or np.isnan(best_crit):
+                is_better = True
+            elif self.criterion == 'Loss' and best_crit >= curr_crit:
+                is_better = True
+            elif self.criterion != 'Loss' and best_crit <= curr_crit:
+                is_better = True
+            else:
+                is_better = False
+
+            # AUPR
+            if best_crit_AUPR is None or np.isnan(best_crit_AUPR):
+                is_better_AUPR = True
+            elif best_crit_AUPR <= curr_crit_AUPR:
+                is_better_AUPR = True
+            else:
+                is_better_AUPR = False
+            # update best criterion
+            if is_better_AUPR:
+                best_crit_AUPR = curr_crit_AUPR
+                if self.save_intermediate_ckpts:
+                    print(f"Saving the model to {self.ckpt_path[:-3]}_AUPR.pt...")
+                    self.save(self.ckpt_path[:-3]+"_AUPR.pt", epoch)
+            if is_better:
+                best_crit = curr_crit
+                best_state_dict = deepcopy(self.net_.state_dict())
+                if self.save_intermediate_ckpts:
+                    print(f"Saving the model to {self.ckpt_path}...")
+                    self.save(self.ckpt_path, epoch)
+
+            if self.verbose > 2:
+                print('Best {}: {}'.format(self.criterion, best_crit))
+                print('Best {}: {}'.format('AUC (PR)', best_crit_AUPR))
+
+            if self.verbose == 1:
+                with self._lock if self._lock is not None else suppress():
+                    pbr_epoch.update(1)
+                    pbr_epoch.refresh()
+
+        if self.verbose == 1:
+            with self._lock if self._lock is not None else suppress():
+                pbr_epoch.close()
+
+        return self
+
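The per-task focal-loss weighting in `fit` above disables the alpha term for common labels and sets it to the squared complement of the label fraction for rare ones. A quick standalone numeric check of that rule (values are illustrative):

    def focal_alpha(label_fraction: float) -> float:
        # mirrors the rule in fit(): alpha is disabled for labels present in >= 30% of samples
        return -1 if label_fraction >= 0.3 else (1 - label_fraction) ** 2

    print(focal_alpha(0.45))   # -1      (weighting disabled for a common label)
    print(focal_alpha(0.05))   # 0.9025  (a rare label is weighted more heavily)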
+    def train_one_epoch(self, ldr_trn, epoch):
+        # progress bar for batch loops
+        if self.verbose > 1:
+            pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch))
+
+        # set model to train mode
+        torch.set_grad_enabled(True)
+        self.net_.train()
+
+        scores_trn, y_true_trn, y_mask_trn = [], [], []
+        losses_trn = [[] for _ in self.tgt_modalities]
+        iters = len(ldr_trn)
+        for n_iter, (x_batch, y_batch, mask, y_mask) in enumerate(ldr_trn):
+
+            # mount data to the proper device
+            x_batch = {k: x_batch[k].to(self.device) for k in x_batch}
+            y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch}
+            mask = {k: mask[k].to(self.device) for k in mask}
+            y_mask = {k: y_mask[k].to(self.device) for k in y_mask}
+
+            with torch.autocast(
+                device_type = 'cpu' if self.device == 'cpu' else 'cuda',
+                dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
+                enabled = self._amp_enabled,
+            ):
+
+                outputs = self.net_(x_batch, mask, skip_embedding=self.skip_embedding)
+
+                # calculate multitask loss
+                loss = 0
+
+                # for initial 10 epochs, only the focal loss is used for stable training
+                if self.ranking_loss:
+                    if epoch < 10:
+                        loss = 0
+                    else:
+                        for i, k in enumerate(self.tgt_modalities):
+                            for ii, kk in enumerate(self.tgt_modalities):
+                                if ii > i:
+                                    pairs = (y_mask[k] == 1) & (y_mask[kk] == 1)
+                                    total_elements = (torch.abs(y_batch[k][pairs] - y_batch[kk][pairs])).sum()
+                                    if total_elements != 0:
+                                        loss += self.lambda_coeff * (self.margin_loss(torch.sigmoid(outputs[k])[pairs], torch.sigmoid(outputs[kk][pairs]), y_batch[k][pairs] - y_batch[kk][pairs])) / total_elements
+
+                for i, k in enumerate(self.tgt_modalities):
+                    loss_task = self.loss_fn[k](outputs[k], y_batch[k])
+                    msk_loss_task = loss_task * y_mask[k]
+                    msk_loss_mean = msk_loss_task.sum() / y_mask[k].sum()
+                    # msk_loss_mean = msk_loss_task.sum()
+                    loss += msk_loss_mean
+                    losses_trn[i] += msk_loss_task.detach().cpu().numpy().tolist()
+
+            # backward
+            loss = loss / self.batch_size_multiplier
+            if self._amp_enabled:
+                self.scaler.scale(loss).backward()
+            else:
+                loss.backward()
+
+            if len(self.grad_list) > 0:
+                print(len(self.grad_list), len(self.grad_list[-1]))
+                print(f"Gradient at {n_iter}: {self.grad_list[-1][0]}")
+
+            # print("img_MRI_T1_1 ", self.net_.modules_emb_src.img_MRI_T1_1.img_model.features[0].weight)
+            # print("img_MRI_T1_1 ", self.net_.modules_emb_src.img_MRI_T1_1.downsample[0].weight)
+
+            # update parameters
+            if n_iter != 0 and n_iter % self.batch_size_multiplier == 0:
+                if self._amp_enabled:
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+                    self.optimizer.zero_grad()
+                else:
+                    self.optimizer.step()
+                    self.optimizer.zero_grad()
+
+            # set self.scheduler
+            self.scheduler.step(epoch + n_iter / iters)
+
+            ''' TODO: change array to dictionary later '''
+            outputs = torch.stack(list(outputs.values()), dim=1)
+            y_batch = torch.stack(list(y_batch.values()), dim=1)
+            y_mask = torch.stack(list(y_mask.values()), dim=1)
+
+            # save outputs to evaluate performance later
+            scores_trn.append(outputs.detach().to(torch.float).cpu())
+            y_true_trn.append(y_batch.cpu())
+            y_mask_trn.append(y_mask.cpu())
+
+            # update progress bar
+            if self.verbose > 1:
+                batch_size = len(next(iter(x_batch.values())))
+                pbr_batch.update(batch_size, {})
+                pbr_batch.refresh()
+
+            # clear cuda cache
+            if "cuda" in self.device:
+                torch.cuda.empty_cache()
+
+        # for better tqdm progress bar display
+        if self.verbose > 1:
+            pbr_batch.close()
+
+        # calculate and print training performance metrics
+        scores_trn = torch.cat(scores_trn)
+        y_true_trn = torch.cat(y_true_trn)
+        y_mask_trn = torch.cat(y_mask_trn)
+        y_pred_trn = (scores_trn > 0).to(torch.int)
+        y_prob_trn = torch.sigmoid(scores_trn)
+        met_trn = get_metrics_multitask(
+            y_true_trn.numpy(),
+            y_pred_trn.numpy(),
+            y_prob_trn.numpy(),
+            y_mask_trn.numpy()
+        )
+
+        # add loss to metrics
+        for i in range(len(self.tgt_modalities)):
+            met_trn[i]['Loss'] = np.mean(losses_trn[i])
+
+        # log metrics to wandb
+        wandb.log({f"Train loss {list(self.tgt_modalities)[i]}": met_trn[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch)
+        wandb.log({f"Train Balanced Accuracy {list(self.tgt_modalities)[i]}": met_trn[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch)
+
+        wandb.log({f"Train AUC (ROC) {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch)
+        wandb.log({f"Train AUPR {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch)
+
+        if self.verbose > 2:
+            print_metrics_multitask(met_trn)
+
+        return met_trn
+
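The pairwise ranking term in the loop above compares every pair of target heads on samples where both labels are observed, using torch's MarginRankingLoss. A self-contained sketch of a single pair, with made-up tensors but the same coefficient and margin as in `fit`, is:

    import torch

    margin_loss = torch.nn.MarginRankingLoss(reduction='sum', margin=0.25)
    p_k  = torch.sigmoid(torch.tensor([2.0, -1.0, 0.5]))   # probabilities from head k
    p_kk = torch.sigmoid(torch.tensor([0.0,  1.0, 0.5]))   # probabilities from head kk
    y    = torch.tensor([1.0, -1.0, 0.0])                   # y_batch[k] - y_batch[kk] per sample
    denom = y.abs().sum()                                    # number of informative (disagreeing) pairs
    loss = 0.005 * margin_loss(p_k, p_kk, y) / denom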
+    def validate_one_epoch(self, ldr_vld, epoch):
+        # # progress bar for validation
+        if self.verbose > 1:
+            pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch))
+
+        # set model to validation mode
+        torch.set_grad_enabled(False)
+        self.net_.eval()
+
+        scores_vld, y_true_vld, y_mask_vld = [], [], []
+        losses_vld = [[] for _ in self.tgt_modalities]
+        for x_batch, y_batch, mask, y_mask in ldr_vld:
+            # if len(next(iter(x_batch.values()))) < self.batch_size:
+            #     break
+            # mount data to the proper device
+            x_batch = {k: x_batch[k].to(self.device) for k in x_batch}  # if 'img' not in k}
+            # x_img_batch = {k: x_img_batch[k].to(self.device) for k in x_img_batch}
+            y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch}
+            mask = {k: mask[k].to(self.device) for k in mask}
+            y_mask = {k: y_mask[k].to(self.device) for k in y_mask}
+
+            # forward
+            with torch.autocast(
+                device_type = 'cpu' if self.device == 'cpu' else 'cuda',
+                dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
+                enabled = self._amp_enabled
+            ):
+
+                outputs = self.net_(x_batch, mask, skip_embedding=self.skip_embedding)
+
+                # calculate multitask loss
+                for i, k in enumerate(self.tgt_modalities):
+                    loss_task = self.loss_fn[k](outputs[k], y_batch[k])
+                    msk_loss_task = loss_task * y_mask[k]
+                    losses_vld[i] += msk_loss_task.detach().cpu().numpy().tolist()
+
+            ''' TODO: change array to dictionary later '''
+            outputs = torch.stack(list(outputs.values()), dim=1)
+            y_batch = torch.stack(list(y_batch.values()), dim=1)
+            y_mask = torch.stack(list(y_mask.values()), dim=1)
+
+            # save outputs to evaluate performance later
+            scores_vld.append(outputs.detach().to(torch.float).cpu())
+            y_true_vld.append(y_batch.cpu())
+            y_mask_vld.append(y_mask.cpu())
+
+            # update progress bar
+            if self.verbose > 1:
+                batch_size = len(next(iter(x_batch.values())))
+                pbr_batch.update(batch_size, {})
+                pbr_batch.refresh()
+
+            # clear cuda cache
+            if "cuda" in self.device:
+                torch.cuda.empty_cache()
+
+        # for better tqdm progress bar display
+        if self.verbose > 1:
+            pbr_batch.close()
+
+        # calculate and print validation performance metrics
+        scores_vld = torch.cat(scores_vld)
+        y_true_vld = torch.cat(y_true_vld)
+        y_mask_vld = torch.cat(y_mask_vld)
+        y_pred_vld = (scores_vld > 0).to(torch.int)
+        y_prob_vld = torch.sigmoid(scores_vld)
+        met_vld = get_metrics_multitask(
+            y_true_vld.numpy(),
+            y_pred_vld.numpy(),
+            y_prob_vld.numpy(),
+            y_mask_vld.numpy()
+        )
+
+        # add loss to metrics
+        for i in range(len(self.tgt_modalities)):
+            met_vld[i]['Loss'] = np.mean(losses_vld[i])
+
+        wandb.log({f"Validation loss {list(self.tgt_modalities)[i]}": met_vld[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch)
+        wandb.log({f"Validation Balanced Accuracy {list(self.tgt_modalities)[i]}": met_vld[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch)
+
+        wandb.log({f"Validation AUC (ROC) {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch)
+        wandb.log({f"Validation AUPR {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch)
+
+        if self.verbose > 2:
+            print_metrics_multitask(met_vld)
+
+        return met_vld
+
+
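Both the training and validation loops mask out unlabeled targets before averaging the loss. The masking pattern, isolated on toy tensors, is simply:

    import torch

    loss_task = torch.tensor([0.7, 0.2, 0.9])   # per-sample loss for one target head
    y_mask    = torch.tensor([1.0, 0.0, 1.0])   # 0 where the label is missing
    msk_loss_mean = (loss_task * y_mask).sum() / y_mask.sum()   # average over labeled samples only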
+    def predict_logits(self,
+        x: list[dict[str, Any]],
+        _batch_size: int | None = None,
+        skip_embedding: dict | None = None,
+        img_transform: Any | None = None,
+    ) -> list[dict[str, float]]:
+        '''
+        The input x can be a single sample or a list of samples.
+        '''
+        # input validation
+        check_is_fitted(self)
+        print(self.device)
+
+        # for PyTorch computational efficiency
+        torch.set_num_threads(1)
+
+        # set model to eval mode
+        torch.set_grad_enabled(False)
+        self.net_.eval()
+
+        # initialize dataset and dataloader object
+        dat = TransformerTestingDataset(x, self.src_modalities, img_transform=img_transform)
+        ldr = DataLoader(
+            dataset = dat,
+            batch_size = _batch_size if _batch_size is not None else len(x),
+            shuffle = False,
+            drop_last = False,
+            num_workers = 0,
+            collate_fn = TransformerTestingDataset.collate_fn,
+        )
+        # print("dataloader done")
+
+        # run model and collect results
+        logits: list[dict[str, float]] = []
+        for x_batch, mask in tqdm(ldr):
+            # mount data to the proper device
+            # print(x_batch['his_SEX'])
+            x_batch = {k: x_batch[k].to(self.device) for k in x_batch}
+            mask = {k: mask[k].to(self.device) for k in mask}
+
+            # forward
+            output: dict[str, Tensor] = self.net_(x_batch, mask, skip_embedding)
+
+            # convert output from dict-of-list to list of dict, then append
+            tmp = {k: output[k].tolist() for k in self.tgt_modalities}
+            tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))]
+            logits += tmp
+
+        return logits
+
+    def predict_proba(self,
+        x: list[dict[str, Any]],
+        skip_embedding: dict | None = None,
+        temperature: float = 1.0,
+        _batch_size: int | None = None,
+        img_transform: Any | None = None,
+    ) -> list[dict[str, float]]:
+        ''' ... '''
+        logits = self.predict_logits(x=x, _batch_size=_batch_size, img_transform=img_transform, skip_embedding=skip_embedding)
+        print("got logits")
+        return logits, [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]
+
+    def predict(self,
+        x: list[dict[str, Any]],
+        skip_embedding: dict | None = None,
+        fpr: dict[str, Any] | None = None,
+        tpr: dict[str, Any] | None = None,
+        thresholds: dict[str, Any] | None = None,
+        _batch_size: int | None = None,
+        img_transform: Any | None = None,
+    ) -> list[dict[str, int]]:
+        ''' ... '''
+        if fpr is None or tpr is None or thresholds is None:
+            logits, proba = self.predict_proba(x, _batch_size=_batch_size, img_transform=img_transform, skip_embedding=skip_embedding)
+            print("got proba")
+            return logits, proba, [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba]
+        else:
+            logits, proba = self.predict_proba(x, _batch_size=_batch_size, img_transform=img_transform, skip_embedding=skip_embedding)
+            print("got proba")
+            youden_index = {}
+            thr = {}
+            for i, k in enumerate(self.tgt_modalities):
+                youden_index[k] = tpr[i] - fpr[i]
+                thr[k] = thresholds[i][np.argmax(youden_index[k])]
+                # print(thr[k])
+            # print(thr)
+            return logits, proba, [{k: int(smp[k] > thr[k]) for k in self.tgt_modalities} for smp in proba]
+
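When ROC statistics are supplied, `predict` picks a per-task cut-off at the maximum Youden index (TPR minus FPR) instead of the default 0.5. The selection step, isolated with illustrative arrays, is:

    import numpy as np

    fpr        = np.array([0.0, 0.1, 0.3, 1.0])
    tpr        = np.array([0.0, 0.6, 0.8, 1.0])
    thresholds = np.array([1.0, 0.7, 0.4, 0.0])

    youden = tpr - fpr                       # [0.0, 0.5, 0.5, 0.0]
    thr = thresholds[np.argmax(youden)]      # 0.7 -> probabilities above this map to 1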
+    def save(self, filepath: str, epoch: int) -> None:
+        """Save the model to the given file stream.
+
+        :param filepath: _description_
+        :type filepath: str
+        :param epoch: _description_
+        :type epoch: int
+        """
+        check_is_fitted(self)
+        if self.data_parallel:
+            state_dict = self.net_.module.state_dict()
+        else:
+            state_dict = self.net_.state_dict()
+
+        # attach model hyper parameters
+        state_dict['src_modalities'] = self.src_modalities
+        state_dict['tgt_modalities'] = self.tgt_modalities
+        state_dict['d_model'] = self.d_model
+        state_dict['nhead'] = self.nhead
+        state_dict['num_encoder_layers'] = self.num_encoder_layers
+        state_dict['num_decoder_layers'] = self.num_decoder_layers
+        state_dict['optimizer'] = self.optimizer
+        state_dict['img_net'] = self.img_net
+        state_dict['imgnet_layers'] = self.imgnet_layers
+        state_dict['img_size'] = self.img_size
+        state_dict['patch_size'] = self.patch_size
+        state_dict['imgnet_ckpt'] = self.imgnet_ckpt
+        state_dict['train_imgnet'] = self.train_imgnet
+        state_dict['epoch'] = epoch
+
+        if self.scaler is not None:
+            state_dict['scaler'] = self.scaler.state_dict()
+        if self.label_distribution:
+            state_dict['label_distribution'] = self.label_distribution
+
+        torch.save(state_dict, filepath)
+
+    def load(self, filepath: str, map_location: str = 'cpu', img_dict=None) -> None:
+        """Load a model from the given file stream.
+
+        :param filepath: _description_
+        :type filepath: str
+        :param map_location: _description_, defaults to 'cpu'
+        :type map_location: str, optional
+        :param img_dict: _description_, defaults to None
+        :type img_dict: _type_, optional
+        """
+        # load state_dict
+        state_dict = torch.load(filepath, map_location=map_location)
+
+        # load data modalities
+        self.src_modalities: dict[str, dict[str, Any]] = state_dict.pop('src_modalities')
+        self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities')
+        if 'label_distribution' in state_dict:
+            self.label_distribution: dict[str, dict[int, int]] = state_dict.pop('label_distribution')
+        if 'optimizer' in state_dict:
+            self.optimizer = state_dict.pop('optimizer')
+
+        # initialize model
+        self.d_model = state_dict.pop('d_model')
+        self.nhead = state_dict.pop('nhead')
+        self.num_encoder_layers = state_dict.pop('num_encoder_layers')
+        self.num_decoder_layers = state_dict.pop('num_decoder_layers')
+        if 'epoch' in state_dict.keys():
+            self.start_epoch = state_dict.pop('epoch')
+        if img_dict is None:
+            self.img_net = state_dict.pop('img_net')
+            self.imgnet_layers = state_dict.pop('imgnet_layers')
+            self.img_size = state_dict.pop('img_size')
+            self.patch_size = state_dict.pop('patch_size')
+            self.imgnet_ckpt = state_dict.pop('imgnet_ckpt')
+            self.train_imgnet = state_dict.pop('train_imgnet')
+        else:
+            self.img_net = img_dict['img_net']
+            self.imgnet_layers = img_dict['imgnet_layers']
+            self.img_size = img_dict['img_size']
+            self.patch_size = img_dict['patch_size']
+            self.imgnet_ckpt = img_dict['imgnet_ckpt']
+            self.train_imgnet = img_dict['train_imgnet']
+            state_dict.pop('img_net')
+            state_dict.pop('imgnet_layers')
+            state_dict.pop('img_size')
+            state_dict.pop('patch_size')
+            state_dict.pop('imgnet_ckpt')
+            state_dict.pop('train_imgnet')
+
+        for k, info in self.src_modalities.items():
+            if info['type'] == 'imaging':
+                if 'emb' not in self.img_net.lower():
+                    info['shape'] = (1,) + (self.img_size,) * 3
+                    info['img_shape'] = (1,) + (self.img_size,) * 3
+                elif 'swinunetr' in self.img_net.lower():
+                    info['shape'] = (1, 768, 4, 4, 4)
+                    info['img_shape'] = (1, 768, 4, 4, 4)
+                # print(info['shape'])
+
+        self.net_ = Transformer(self.src_modalities, self.tgt_modalities, self.d_model, self.nhead, self.num_encoder_layers, self.num_decoder_layers, self.device, self.cuda_devices, self.img_net, self.imgnet_layers, self.img_size, self.patch_size, self.imgnet_ckpt, self.train_imgnet, self.fusion_stage)
+
+
+        if 'scaler' in state_dict and state_dict['scaler']:
+            self.scaler.load_state_dict(state_dict.pop('scaler'))
+        self.net_.load_state_dict(state_dict)
+        check_is_fitted(self)
+        self.net_.to(self.device)
+
+    def to(self, device: str) -> Self:
+        """Mount the model to the given device.
+
+        :param device: _description_
+        :type device: str
+        :return: _description_
+        :rtype: Self
+        """
+        self.device = device
+        if hasattr(self, 'model'): self.net_ = self.net_.to(device)
+        if hasattr(self, 'img_model'): self.img_model = self.img_model.to(device)
+        return self
+
+    @classmethod
+    def from_ckpt(cls, filepath: str, device='cpu', img_dict=None) -> Self:
+        """Create a new ADRD model and load parameters from the checkpoint.
+
+        This is an alternative constructor.
+
+        :param filepath: _description_
+        :type filepath: str
+        :param device: _description_, defaults to 'cpu'
+        :type device: str, optional
+        :param img_dict: _description_, defaults to None
+        :type img_dict: _type_, optional
+        :return: _description_
+        :rtype: Self
+        """
+        obj = cls(None, None, None, device=device)
+        if device == 'cuda':
+            obj.device = "{}:{}".format(obj.device, str(obj.cuda_devices[0]))
+        print(obj.device)
+        obj.load(filepath, map_location=obj.device, img_dict=img_dict)
+        return obj
+
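Because `save` stores the hyperparameters alongside the weights, a fitted model can be restored without re-specifying the constructor arguments. A minimal sketch, with a hypothetical checkpoint path and feature dicts:

    model = ADRDModel.from_ckpt('./ckpt/ckpt.pt', device='cpu')
    logits, proba, preds = model.predict(x_test)   # x_test: list of per-sample feature dicts (hypothetical)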
+    def _init_net(self):
+        """ ... """
+        # set the device for use
+        if self.device == 'cuda':
+            self.device = "{}:{}".format(self.device, str(self.cuda_devices[0]))
+        print("Device: " + self.device)
+
+        self.start_epoch = 0
+        if self.load_from_ckpt:
+            try:
+                print("Loading model from checkpoint...")
+                self.load(self.ckpt_path, map_location=self.device)
+            except:
+                print("Cannot load from checkpoint. Initializing new model...")
+                self.load_from_ckpt = False
+
+        if not self.load_from_ckpt:
+            self.net_ = nn.Transformer(
+                src_modalities = self.src_modalities,
+                tgt_modalities = self.tgt_modalities,
+                d_model = self.d_model,
+                nhead = self.nhead,
+                num_encoder_layers = self.num_encoder_layers,
+                num_decoder_layers = self.num_decoder_layers,
+                device = self.device,
+                cuda_devices = self.cuda_devices,
+                img_net = self.img_net,
+                layers = self.imgnet_layers,
+                img_size = self.img_size,
+                patch_size = self.patch_size,
+                imgnet_ckpt = self.imgnet_ckpt,
+                train_imgnet = self.train_imgnet,
+                fusion_stage = self.fusion_stage,
+            )
+
+            # initialize model parameters using xavier_uniform
+            for name, p in self.net_.named_parameters():
+                if p.dim() > 1:
+                    torch.nn.init.xavier_uniform_(p)
+
+        self.net_.to(self.device)
+
+        # Initialize the number of GPUs
+        if self.data_parallel and torch.cuda.device_count() > 1:
+            print("Available", torch.cuda.device_count(), "GPUs!")
+            self.net_ = torch.nn.DataParallel(self.net_, device_ids=self.cuda_devices)
+
+        # return net
+
+    def _init_dataloader(self, x_trn, x_vld, y_trn, y_vld, img_train_trans=None, img_vld_trans=None):
+        # initialize dataset and dataloader
+        if self.balanced_sampling:
+            dat_trn = Transformer2ndOrderBalancedTrainingDataset(
+                x_trn, y_trn,
+                self.src_modalities,
+                self.tgt_modalities,
+                dropout_rate = .5,
+                dropout_strategy = 'permutation',
+                img_transform=img_train_trans,
+            )
+        else:
+            dat_trn = TransformerTrainingDataset(
+                x_trn, y_trn,
+                self.src_modalities,
+                self.tgt_modalities,
+                dropout_rate = .5,
+                dropout_strategy = 'permutation',
+                img_transform=img_train_trans,
+            )
+
+        dat_vld = TransformerValidationDataset(
+            x_vld, y_vld,
+            self.src_modalities,
+            self.tgt_modalities,
+            img_transform=img_vld_trans,
+        )
+
+        ldr_trn = DataLoader(
+            dataset = dat_trn,
+            batch_size = self.batch_size,
+            shuffle = True,
+            drop_last = False,
+            num_workers = self._dataloader_num_workers,
+            collate_fn = TransformerTrainingDataset.collate_fn,
+            # pin_memory = True
+        )
+
+        ldr_vld = DataLoader(
+            dataset = dat_vld,
+            batch_size = self.batch_size,
+            shuffle = False,
+            drop_last = False,
+            num_workers = self._dataloader_num_workers,
+            collate_fn = TransformerValidationDataset.collate_fn,
+            # pin_memory = True
+        )
+
+        return ldr_trn, ldr_vld
+
+    def _init_optimizer(self):
+        """ ... """
+        params = list(self.net_.parameters())
+        return torch.optim.AdamW(
+            params,
+            lr = self.lr,
+            betas = (0.9, 0.98),
+            weight_decay = self.weight_decay
+        )
+
+    def _init_scheduler(self, optimizer):
+        """ ... """
+
+        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+            optimizer=optimizer,
+            T_0=64,
+            T_mult=2,
+            eta_min = 0,
+            verbose=(self.verbose > 2)
+        )
+
+    def _init_loss_func(self,
+        num_per_cls: dict[str, tuple[int, int]],
+    ) -> dict[str, Module]:
+        """ ... """
+        return {k: nn.SigmoidFocalLossBeta(
+            beta = self.beta,
+            gamma = self.gamma,
+            num_per_cls = num_per_cls[k],
+            reduction = 'none',
+        ) for k in self.tgt_modalities}
+
+    def _proc_fit(self):
+        """ ... """
adrd/model/calibration.py
ADDED
@@ -0,0 +1,450 @@
+import numpy as np
+from sklearn.base import BaseEstimator
+from sklearn.utils.validation import check_is_fitted
+from sklearn.linear_model import LogisticRegression
+from sklearn.isotonic import IsotonicRegression
+from functools import lru_cache
+from functools import cached_property
+from typing import Self, Any
+from pickle import dump
+from pickle import load
+from abc import ABC, abstractmethod
+
+from . import ADRDModel
+from ..utils import Formatter
+from ..utils import MissingMasker
+
+
+def calibration_curve(
+    y_true: list[int],
+    y_pred: list[float],
+    n_bins: int = 10,
+    ratio: float = 1.0,
+) -> tuple[list[float], list[float]]:
+    """
+    Compute true and predicted probabilities for a calibration curve. The method
+    assumes the inputs come from a binary classifier, and discretizes the [0, 1]
+    interval into bins.
+
+    Note that this function is an alternative to
+    sklearn.calibration.calibration_curve() which can only estimate the absolute
+    proportion of positive cases in each bin.
+
+    Parameters
+    ----------
+    y_true : list[int]
+        True targets.
+    y_pred : list[float]
+        Probabilities of the positive class.
+    n_bins : int, default=10
+        Number of bins to discretize the [0, 1] interval. A bigger number
+        requires more data. Bins with no samples (i.e. without corresponding
+        values in y_pred) will not be returned, thus the returned arrays may
+        have fewer than n_bins values.
+    ratio : float, default=1.0
+        Used to adjust the class balance.
+
+    Returns
+    -------
+    prob_true : list[float]
+        The proportion of positive samples in each bin.
+    prob_pred : list[float]
+        The mean predicted probability in each bin.
+    """
+    # generate "n_bin" intervals
+    tmp = np.around(np.linspace(0, 1, n_bins + 1), decimals=6)
+    intvs = [(tmp[i - 1], tmp[i]) for i in range(1, len(tmp))]
+
+    # pair up (pred, true) and group them by intervals
+    tmp = list(zip(y_pred, y_true))
+    intv_pairs = {(l, r): [p for p in tmp if l <= p[0] < r] for l, r in intvs}
+
+    # calculate balanced proportion of POSITIVE cases for each interval
+    # along with the balanced averaged predictions
+    intv_prob_true: dict[tuple, float] = dict()
+    intv_prob_pred: dict[tuple, float] = dict()
+    for intv, pairs in intv_pairs.items():
+        # number of cases that fall into the interval
+        n_pairs = len(pairs)
+
+        # it's likely that no predictions fall into the interval
+        if n_pairs == 0: continue
+
+        # count number of positives and negatives in the interval
+        n_pos = sum([p[1] for p in pairs])
+        n_neg = n_pairs - n_pos
+
+        # calculate adjusted proportion of positives
+        intv_prob_true[intv] = n_pos / (n_pos + n_neg * ratio)
+
+        # calculate adjusted avg. predictions
+        sum_pred_pos = sum([p[0] for p in pairs if p[1] == 1])
+        sum_pred_neg = sum([p[0] for p in pairs if p[1] == 0])
+        intv_prob_pred[intv] = (sum_pred_pos + sum_pred_neg * ratio)
+        intv_prob_pred[intv] /= (n_pos + n_neg * ratio)
+
+    prob_true = list(intv_prob_true.values())
+    prob_pred = list(intv_prob_pred.values())
+    return prob_true, prob_pred
+
+
91 |
+
class CalibrationCore(BaseEstimator):
|
92 |
+
"""
|
93 |
+
A wrapper class of multiple regressors to predict the proportions of
|
94 |
+
positive samples from the predicted probabilities. The method for
|
95 |
+
calibration can be 'sigmoid' which corresponds to Platt's method (i.e. a
|
96 |
+
logistic regression model) or 'isotonic' which is a non-parametric approach.
|
97 |
+
It is not advised to use isotonic calibration with too few calibration
|
98 |
+
samples (<<1000) since it tends to overfit.
|
99 |
+
|
100 |
+
TODO
|
101 |
+
----
|
102 |
+
- 'sigmoid' method is not trivial to implement.
|
103 |
+
"""
|
104 |
+
def __init__(self,
|
105 |
+
method: str = 'isotonic',
|
106 |
+
) -> None:
|
107 |
+
"""
|
108 |
+
Initialization function of CalibrationCore class.
|
109 |
+
|
110 |
+
Parameters
|
111 |
+
----------
|
112 |
+
method : {'sigmoid', 'isotonic'}, default='isotonic'
|
113 |
+
The method to use for calibration. can be 'sigmoid' which
|
114 |
+
corresponds to Platt's method (i.e. a logistic regression model) or
|
115 |
+
'isotonic' which is a non-parametric approach. It is not advised to
|
116 |
+
use isotonic calibration with too few calibration samples (<<1000)
|
117 |
+
since it tends to overfit.
|
118 |
+
|
119 |
+
Raises
|
120 |
+
------
|
121 |
+
ValueError
|
122 |
+
Sigmoid approach has not been implemented.
|
123 |
+
"""
|
124 |
+
assert method in ('sigmoid', 'isotonic')
|
125 |
+
if method == 'sigmoid':
|
126 |
+
raise ValueError('Sigmoid approach has not been implemented.')
|
127 |
+
self.method = method
|
128 |
+
|
129 |
+
def fit(self,
|
130 |
+
prob_pred: list[float],
|
131 |
+
prob_true: list[float],
|
132 |
+
) -> Self:
|
133 |
+
"""
|
134 |
+
Fit the underlying regressor using prob_pred, prob_true as training
|
135 |
+
data.
|
136 |
+
|
137 |
+
Parameters
|
138 |
+
----------
|
139 |
+
prob_pred : list[float]
|
140 |
+
Probabilities predicted directly by a model.
|
141 |
+
prob_true : list[float]
|
142 |
+
Target probabilities to calibrate to.
|
143 |
+
|
144 |
+
Returns
|
145 |
+
-------
|
146 |
+
Self
|
147 |
+
CalibrationCore object.
|
148 |
+
"""
|
149 |
+
# using Platt's method for calibration
|
150 |
+
if self.method == 'sigmoid':
|
151 |
+
self.model_ = LogisticRegression()
|
152 |
+
self.model_.fit(prob_pred, prob_true)
|
153 |
+
|
154 |
+
# using isotonic calibration
|
155 |
+
elif self.method == 'isotonic':
|
156 |
+
self.model_ = IsotonicRegression(y_min=0, y_max=1, out_of_bounds='clip')
|
157 |
+
self.model_.fit(prob_pred, prob_true)
|
158 |
+
|
159 |
+
return self
|
160 |
+
|
161 |
+
def predict(self,
|
162 |
+
prob_pred: list[float],
|
163 |
+
) -> list[float]:
|
164 |
+
"""
|
165 |
+
Calibrate the input probabilities using the fitted regressor.
|
166 |
+
|
167 |
+
Parameters
|
168 |
+
----------
|
169 |
+
prob_pred : list[float]
|
170 |
+
Probabilities predicted directly by a model.
|
171 |
+
|
172 |
+
Returns
|
173 |
+
-------
|
174 |
+
prob_cali : list[float]
|
175 |
+
Calibrated probabilities.
|
176 |
+
"""
|
177 |
+
# as usual, the core needs to be fitted
|
178 |
+
check_is_fitted(self)
|
179 |
+
|
180 |
+
# note that logistic regression is classification model, we need to call
|
181 |
+
# 'predict_proba' instead of 'predict' to get the calibrated results
|
182 |
+
if self.method == 'sigmoid':
|
183 |
+
prob_cali = self.model_.predict_proba(prob_pred)
|
184 |
+
elif self.method == 'isotonic':
|
185 |
+
prob_cali = self.model_.predict(prob_pred)
|
186 |
+
|
187 |
+
return prob_cali
|
188 |
+
|
189 |
+
|
190 |
+
class CalibratedClassifier(ABC):
|
191 |
+
"""
|
192 |
+
Abstract class of calibrated classifier.
|
193 |
+
"""
|
194 |
+
def __init__(self,
|
195 |
+
model: ADRDModel,
|
196 |
+
background_src: list[dict[str, Any]],
|
197 |
+
background_tgt: list[dict[str, Any]],
|
198 |
+
background_is_embedding: dict[str, bool] | None = None,
|
199 |
+
method: str = 'isotonic',
|
200 |
+
) -> None:
|
201 |
+
"""
|
202 |
+
Constructor of Calibrator class.
|
203 |
+
|
204 |
+
Parameters
|
205 |
+
----------
|
206 |
+
model : ADRDModel
|
207 |
+
Fitted model to calibrate.
|
208 |
+
background_src : list[dict[str, Any]]
|
209 |
+
Features of the background dataset.
|
210 |
+
background_tgt : list[dict[str, Any]]
|
211 |
+
Labels of the background dataset.
|
212 |
+
method : {'sigmoid', 'isotonic'}, default='isotonic'
|
213 |
+
Method used by the underlying regressor.
|
214 |
+
"""
|
215 |
+
self.method = method
|
216 |
+
self.model = model
|
217 |
+
self.src_modalities = model.src_modalities
|
218 |
+
self.tgt_modalities = model.tgt_modalities
|
219 |
+
self.background_is_embedding = background_is_embedding
|
220 |
+
|
221 |
+
# format background data
|
222 |
+
fmt_src = Formatter(self.src_modalities)
|
223 |
+
fmt_tgt = Formatter(self.tgt_modalities)
|
224 |
+
self.background_src = [fmt_src(smp) for smp in background_src]
|
225 |
+
self.background_tgt = [fmt_tgt(smp) for smp in background_tgt]
|
226 |
+
|
227 |
+
@abstractmethod
|
228 |
+
def predict_proba(self,
|
229 |
+
src: list[dict[str, Any]],
|
230 |
+
is_embedding: dict[str, bool] | None = None,
|
231 |
+
) -> list[dict[str, float]]:
|
232 |
+
"""
|
233 |
+
This method returns calibrated probabilities of classification.
|
234 |
+
|
235 |
+
Parameters
|
236 |
+
----------
|
237 |
+
src : list[dict[str, Any]]
|
238 |
+
Features of the input samples.
|
239 |
+
|
240 |
+
Returns
|
241 |
+
-------
|
242 |
+
list[dict[str, float]]
|
243 |
+
Calibrated probabilities.
|
244 |
+
"""
|
245 |
+
pass
|
246 |
+
|
247 |
+
def predict(self,
|
248 |
+
src: list[dict[str, Any]],
|
249 |
+
is_embedding: dict[str, bool] | None = None,
|
250 |
+
) -> list[dict[str, int]]:
|
251 |
+
"""
|
252 |
+
Make predictions based on the results of predict_proba().
|
253 |
+
|
254 |
+
Parameters
|
255 |
+
----------
|
256 |
+
x : list[dict[str, Any]]
|
257 |
+
Input features.
|
258 |
+
|
259 |
+
Returns
|
260 |
+
-------
|
261 |
+
list[dict[str, int]]
|
262 |
+
Calibrated predictions.
|
263 |
+
"""
|
264 |
+
proba = self.predict_proba(src, is_embedding)
|
265 |
+
return [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba]
|
266 |
+
|
267 |
+
def save(self,
|
268 |
+
filepath_state_dict: str,
|
269 |
+
) -> None:
|
270 |
+
"""
|
271 |
+
Save the state dict and the underlying model to the given paths.
|
272 |
+
|
273 |
+
Parameters
|
274 |
+
----------
|
275 |
+
filepath_state_dict : str
|
276 |
+
File path to save the state_dict which includes the background
|
277 |
+
dataset and the regressor information.
|
278 |
+
filepath_wrapped_model : str | None, default=None
|
279 |
+
File path to save the wrapped model. If None, the model won't be
|
280 |
+
saved.
|
281 |
+
"""
|
282 |
+
# save state dict
|
283 |
+
state_dict = {
|
284 |
+
'background_src': self.background_src,
|
285 |
+
'background_tgt': self.background_tgt,
|
286 |
+
'background_is_embedding': self.background_is_embedding,
|
287 |
+
'method': self.method,
|
288 |
+
}
|
289 |
+
with open(filepath_state_dict, 'wb') as f:
|
290 |
+
dump(state_dict, f)
|
291 |
+
|
292 |
+
@classmethod
|
293 |
+
def from_ckpt(cls,
|
294 |
+
filepath_state_dict: str,
|
295 |
+
filepath_wrapped_model: str,
|
296 |
+
) -> Self:
|
297 |
+
"""
|
298 |
+
Alternative constructor which loads from checkpoint.
|
299 |
+
|
300 |
+
Parameters
|
301 |
+
----------
|
302 |
+
filepath_state_dict : str
|
303 |
+
File path to load the state_dict which includes the background
|
304 |
+
dataset and the regressor information.
|
305 |
+
filepath_wrapped_model : str
|
306 |
+
File path of the wrapped model.
|
307 |
+
|
308 |
+
Returns
|
309 |
+
-------
|
310 |
+
Self
|
311 |
+
CalibratedClassifier class object.
|
312 |
+
"""
|
313 |
+
with open(filepath_state_dict, 'rb') as f:
|
314 |
+
kwargs = load(f)
|
315 |
+
kwargs['model'] = ADRDModel.from_ckpt(filepath_wrapped_model)
|
316 |
+
return cls(**kwargs)
|
317 |
+
|
318 |
+
|
319 |
+
class DynamicCalibratedClassifier(CalibratedClassifier):
|
320 |
+
"""
|
321 |
+
The dynamic approach generates background predictions based on the
|
322 |
+
missingness pattern of each input. With an astronomical number of
|
323 |
+
missingness patterns, calibrating each sample requires a comprehensive
|
324 |
+
process that involves running the ADRDModel on the majority of the
|
325 |
+
background data and training a corresponding regressor. This results in a
|
326 |
+
computationally intensive calculation.
|
327 |
+
"""
|
328 |
+
def predict_proba(self,
|
329 |
+
src: list[dict[str, Any]],
|
330 |
+
is_embedding: dict[str, bool] | None = None,
|
331 |
+
) -> list[dict[str, float]]:
|
332 |
+
|
333 |
+
# initialize mask generator and format inputs
|
334 |
+
msk_gen = MissingMasker(self.src_modalities)
|
335 |
+
fmt_src = Formatter(self.src_modalities)
|
336 |
+
src = [fmt_src(smp) for smp in src]
|
337 |
+
|
338 |
+
# calculate calibrated probabilities
|
339 |
+
calibrated_prob: list[dict[str, float]] = []
|
340 |
+
for smp in src:
|
341 |
+
# model output and missingness pattern
|
342 |
+
prob = self.model.predict_proba([smp], is_embedding)[0]
|
343 |
+
mask = tuple(msk_gen(smp).values())
|
344 |
+
|
345 |
+
# get/fit core and calculate calibrated probabilities
|
346 |
+
core = self._fit_core(mask)
|
347 |
+
calibrated_prob.append({k: core[k].predict([prob[k]])[0] for k in self.tgt_modalities})
|
348 |
+
|
349 |
+
return calibrated_prob
|
350 |
+
|
351 |
+
# @lru_cache(maxsize = None)
|
352 |
+
def _fit_core(self,
|
353 |
+
missingness_pattern: tuple[bool],
|
354 |
+
) -> dict[str, CalibrationCore]:
|
355 |
+
''' ... '''
|
356 |
+
# remove features from all background samples accordingly
|
357 |
+
background_src, background_tgt = [], []
|
358 |
+
for src, tgt in zip(self.background_src, self.background_tgt):
|
359 |
+
src = {k: v for j, (k, v) in enumerate(src.items()) if missingness_pattern[j] == False}
|
360 |
+
|
361 |
+
# make sure there is at least one feature available
|
362 |
+
if len([v is not None for v in src.values()]) == 0: continue
|
363 |
+
background_src.append(src)
|
364 |
+
background_tgt.append(tgt)
|
365 |
+
|
366 |
+
# run model on background samples and collection predictions
|
367 |
+
background_prob = self.model.predict_proba(background_src, self.background_is_embedding, _batch_size=1024)
|
368 |
+
|
369 |
+
# list[dict] -> dict[list]
|
370 |
+
N = len(background_src)
|
371 |
+
background_prob = {k: [background_prob[i][k] for i in range(N)] for k in self.tgt_modalities}
|
372 |
+
background_true = {k: [background_tgt[i][k] for i in range(N)] for k in self.tgt_modalities}
|
373 |
+
|
374 |
+
# now, fit cores
|
375 |
+
core: dict[str, CalibrationCore] = dict()
|
376 |
+
for k in self.tgt_modalities:
|
377 |
+
prob_true, prob_pred = calibration_curve(
|
378 |
+
background_true[k], background_prob[k],
|
379 |
+
ratio = self.background_ratio[k],
|
380 |
+
)
|
381 |
+
core[k] = CalibrationCore(self.method).fit(prob_pred, prob_true)
|
382 |
+
|
383 |
+
return core
|
384 |
+
|
385 |
+
@cached_property
|
386 |
+
def background_ratio(self) -> dict[str, float]:
|
387 |
+
''' The ratio of positives over negatives in the background dataset. '''
|
388 |
+
return {k: self.background_n_pos[k] / self.background_n_neg[k] for k in self.tgt_modalities}
|
389 |
+
|
390 |
+
@cached_property
|
391 |
+
def background_n_pos(self) -> dict[str, int]:
|
392 |
+
''' Number of positives w.r.t each target in the background dataset. '''
|
393 |
+
return {k: sum([d[k] for d in self.background_tgt]) for k in self.tgt_modalities}
|
394 |
+
|
395 |
+
@cached_property
|
396 |
+
def background_n_neg(self) -> dict[str, int]:
|
397 |
+
''' Number of negatives w.r.t each target in the background dataset. '''
|
398 |
+
return {k: len(self.background_tgt) - self.background_n_pos[k] for k in self.tgt_modalities}
|
399 |
+
|
400 |
+
|
401 |
+
class StaticCalibratedClassifier(CalibratedClassifier):
|
402 |
+
"""
|
403 |
+
The static approach generates background predictions without considering the
|
404 |
+
missingness patterns.
|
405 |
+
"""
|
406 |
+
def predict_proba(self,
|
407 |
+
src: list[dict[str, Any]],
|
408 |
+
is_embedding: dict[str, bool] | None = None,
|
409 |
+
) -> list[dict[str, float]]:
|
410 |
+
|
411 |
+
# number of input samples
|
412 |
+
N = len(src)
|
413 |
+
|
414 |
+
# format inputs, and run ADRDModel, and convert to dict[list]
|
415 |
+
fmt_src = Formatter(self.src_modalities)
|
416 |
+
src = [fmt_src(smp) for smp in src]
|
417 |
+
prob = self.model.predict_proba(src, is_embedding)
|
418 |
+
prob = {k: [prob[i][k] for i in range(N)] for k in self.tgt_modalities}
|
419 |
+
|
420 |
+
# calibrate probabilities
|
421 |
+
core = self._fit_core()
|
422 |
+
calibrated_prob = {k: core[k].predict(prob[k]) for k in self.tgt_modalities}
|
423 |
+
|
424 |
+
# convert back to list[dict]
|
425 |
+
calibrated_prob: list[dict[str, float]] = [
|
426 |
+
{k: calibrated_prob[k][i] for k in self.tgt_modalities} for i in range(N)
|
427 |
+
]
|
428 |
+
return calibrated_prob
|
429 |
+
|
430 |
+
@lru_cache(maxsize = None)
|
431 |
+
def _fit_core(self) -> dict[str, CalibrationCore]:
|
432 |
+
''' ... '''
|
433 |
+
# run model on background samples and collection predictions
|
434 |
+
background_prob = self.model.predict_proba(self.background_src, self.background_is_embedding, _batch_size=1024)
|
435 |
+
|
436 |
+
# list[dict] -> dict[list]
|
437 |
+
N = len(self.background_src)
|
438 |
+
background_prob = {k: [background_prob[i][k] for i in range(N)] for k in self.tgt_modalities}
|
439 |
+
background_true = {k: [self.background_tgt[i][k] for i in range(N)] for k in self.tgt_modalities}
|
440 |
+
|
441 |
+
# now, fit cores
|
442 |
+
core: dict[str, CalibrationCore] = dict()
|
443 |
+
for k in self.tgt_modalities:
|
444 |
+
prob_true, prob_pred = calibration_curve(
|
445 |
+
background_true[k], background_prob[k],
|
446 |
+
ratio = 1.0,
|
447 |
+
)
|
448 |
+
core[k] = CalibrationCore(self.method).fit(prob_pred, prob_true)
|
449 |
+
|
450 |
+
return core
|
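A minimal end-to-end sketch of the calibration utilities above on synthetic data; the toy labels, probabilities, and bin count here are illustrative assumptions, not part of the committed code:

import numpy as np
from adrd.model.calibration import calibration_curve, CalibrationCore

# toy binary labels and deliberately miscalibrated probabilities (illustrative only)
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=1000).tolist()
y_pred = np.clip(np.array(y_true) * 0.6 + rng.normal(0.2, 0.2, size=1000), 0, 1).tolist()

# estimate the reliability curve, then fit an isotonic calibration core on it
prob_true, prob_pred = calibration_curve(y_true, y_pred, n_bins=10, ratio=1.0)
core = CalibrationCore(method='isotonic').fit(prob_pred, prob_true)

# map new uncalibrated scores to calibrated probabilities
print(core.predict([0.1, 0.5, 0.9]))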
adrd/model/cnn_resnet3d_with_linear_classifier.py
ADDED
@@ -0,0 +1,533 @@
__all__ = ['CNNResNet3DWithLinearClassifier']

import torch
from torch.utils.data import DataLoader
import numpy as np
import tqdm
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_is_fitted
from sklearn.model_selection import train_test_split
from scipy.special import expit
from copy import deepcopy
from contextlib import suppress
from typing import Any, Self, Type
from functools import wraps
Tensor = Type[torch.Tensor]
Module = Type[torch.nn.Module]

from ..utils.misc import ProgressBar
from ..utils.misc import get_metrics_multitask, print_metrics_multitask
from ..utils.misc import convert_args_kwargs_to_kwargs

from .. import nn
from ..utils import TransformerTrainingDataset
from ..utils import Transformer2ndOrderBalancedTrainingDataset
from ..utils import TransformerValidationDataset
from ..utils import TransformerTestingDataset


def _manage_ctx_fit(func):
    ''' ... '''
    @wraps(func)
    def wrapper(*args, **kwargs):
        # format arguments
        kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs)

        if kwargs['self']._device_ids is None:
            return func(**kwargs)
        else:
            # change primary device
            default_device = kwargs['self'].device
            kwargs['self'].device = kwargs['self']._device_ids[0]
            rtn = func(**kwargs)

            # the actual module is wrapped
            kwargs['self'].net_ = kwargs['self'].net_.module
            kwargs['self'].to(default_device)
            return rtn
    return wrapper


class CNNResNet3DWithLinearClassifier(BaseEstimator):

    def __init__(self,
        src_modalities: dict[str, dict[str, Any]],
        tgt_modalities: dict[str, dict[str, Any]],
        num_epochs: int = 32,
        batch_size: int = 8,
        batch_size_multiplier: int = 1,
        lr: float = 1e-2,
        weight_decay: float = 0.0,
        beta: float = 0.9999,
        gamma: float = 2.0,
        scale: float = 1.0,
        criterion: str | None = None,
        device: str = 'cpu',
        verbose: int = 0,
        _device_ids: list | None = None,
        _dataloader_num_workers: int = 0,
        _amp_enabled: bool = False,
        _tmp_ckpt_filepath: str | None = None,
    ) -> None:
        ''' ... '''
        # for multiprocessing
        self._rank = 0
        self._lock = None

        # positional parameters
        self.src_modalities = src_modalities
        self.tgt_modalities = tgt_modalities

        # training parameters
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.batch_size_multiplier = batch_size_multiplier
        self.lr = lr
        self.weight_decay = weight_decay
        self.beta = beta
        self.gamma = gamma
        self.scale = scale
        self.criterion = criterion
        self.device = device
        self.verbose = verbose
        self._device_ids = _device_ids
        self._dataloader_num_workers = _dataloader_num_workers
        self._amp_enabled = _amp_enabled
        self._tmp_ckpt_filepath = _tmp_ckpt_filepath

    @_manage_ctx_fit
    def fit(self, x, y) -> Self:
        ''' ... '''
        # for PyTorch computational efficiency
        torch.set_num_threads(1)

        # initialize neural network
        self.net_ = self._init_net()

        # initialize dataloaders
        ldr_trn, ldr_vld = self._init_dataloader(x, y)

        # initialize optimizer and scheduler
        optimizer = self._init_optimizer()
        scheduler = self._init_scheduler(optimizer)

        # gradient scaler for AMP
        if self._amp_enabled: scaler = torch.cuda.amp.GradScaler()

        # initialize loss function (sigmoid focal loss)
        loss_func = self._init_loss_func({
            k: (
                sum([_[k] == 0 for _ in ldr_trn.dataset.tgt]),
                sum([_[k] == 1 for _ in ldr_trn.dataset.tgt]),
            ) for k in self.tgt_modalities
        })

        # to record the best validation performance criterion
        if self.criterion is not None: best_crit = None

        # progress bar for epoch loops
        if self.verbose == 1:
            with self._lock if self._lock is not None else suppress():
                pbr_epoch = tqdm.tqdm(
                    desc = 'Rank {:02d}'.format(self._rank),
                    total = self.num_epochs,
                    position = self._rank,
                    ascii = True,
                    leave = False,
                    bar_format='{l_bar}{r_bar}'
                )

        # training loop
        for epoch in range(self.num_epochs):
            # progress bar for batch loops
            if self.verbose > 1:
                pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch))

            # set model to train mode
            torch.set_grad_enabled(True)
            self.net_.train()

            scores_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
            y_true_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
            losses_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
            for n_iter, (x_batch, y_batch, _, mask_y) in enumerate(ldr_trn):
                # mount data to the proper device
                x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
                y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities}
                # mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
                mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities}

                # forward
                with torch.autocast(
                    device_type = 'cpu' if self.device == 'cpu' else 'cuda',
                    dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
                    enabled = self._amp_enabled,
                ):
                    outputs = self.net_(x_batch)

                    # calculate multitask loss
                    loss = 0
                    for i, tgt_k in enumerate(self.tgt_modalities):
                        loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k])
                        loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze()))
                        loss += loss_k.mean()
                        losses_trn[tgt_k] += loss_k.detach().cpu().numpy().tolist()

                # backward
                if self._amp_enabled:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

                # update parameters
                if n_iter != 0 and n_iter % self.batch_size_multiplier == 0:
                    if self._amp_enabled:
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()
                    else:
                        optimizer.step()
                        optimizer.zero_grad()

                # save outputs to evaluate performance later
                for tgt_k in self.tgt_modalities:
                    tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
                    scores_trn[tgt_k] += tmp.detach().cpu().numpy().tolist()
                    tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
                    y_true_trn[tgt_k] += tmp.cpu().numpy().tolist()

                # update progress bar
                if self.verbose > 1:
                    batch_size = len(next(iter(x_batch.values())))
                    pbr_batch.update(batch_size, {})
                    pbr_batch.refresh()

            # for better tqdm progress bar display
            if self.verbose > 1:
                pbr_batch.close()

            # set scheduler
            scheduler.step()

            # calculate and print training performance metrics
            y_pred_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
            y_prob_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
            for tgt_k in self.tgt_modalities:
                for i in range(len(scores_trn[tgt_k])):
                    y_pred_trn[tgt_k].append(1 if scores_trn[tgt_k][i] > 0 else 0)
                    y_prob_trn[tgt_k].append(expit(scores_trn[tgt_k][i]))
            met_trn = get_metrics_multitask(y_true_trn, y_pred_trn, y_prob_trn)

            # add loss to metrics
            for tgt_k in self.tgt_modalities:
                met_trn[tgt_k]['Loss'] = np.mean(losses_trn[tgt_k])

            if self.verbose > 2:
                print_metrics_multitask(met_trn)

            # progress bar for validation
            if self.verbose > 1:
                pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch))

            # set model to validation mode
            torch.set_grad_enabled(False)
            self.net_.eval()

            scores_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
            y_true_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
            losses_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
            for x_batch, y_batch, _, mask_y in ldr_vld:
                # mount data to the proper device
                x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
                y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities}
                # mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
                mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities}

                # forward
                with torch.autocast(
                    device_type = 'cpu' if self.device == 'cpu' else 'cuda',
                    dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
                    enabled = self._amp_enabled
                ):
                    outputs = self.net_(x_batch)

                    # calculate multitask loss
                    for i, tgt_k in enumerate(self.tgt_modalities):
                        loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k])
                        loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze()))
                        losses_vld[tgt_k] += loss_k.detach().cpu().numpy().tolist()

                # save outputs to evaluate performance later
                for tgt_k in self.tgt_modalities:
                    tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
                    scores_vld[tgt_k] += tmp.detach().cpu().numpy().tolist()
                    tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
                    y_true_vld[tgt_k] += tmp.cpu().numpy().tolist()

                # update progress bar
                if self.verbose > 1:
                    batch_size = len(next(iter(x_batch.values())))
                    pbr_batch.update(batch_size, {})
                    pbr_batch.refresh()

            # for better tqdm progress bar display
            if self.verbose > 1:
                pbr_batch.close()

            # calculate and print validation performance metrics
            y_pred_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
            y_prob_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
            for tgt_k in self.tgt_modalities:
                for i in range(len(scores_vld[tgt_k])):
                    y_pred_vld[tgt_k].append(1 if scores_vld[tgt_k][i] > 0 else 0)
                    y_prob_vld[tgt_k].append(expit(scores_vld[tgt_k][i]))
            met_vld = get_metrics_multitask(y_true_vld, y_pred_vld, y_prob_vld)

            # add loss to metrics
            for tgt_k in self.tgt_modalities:
                met_vld[tgt_k]['Loss'] = np.mean(losses_vld[tgt_k])

            if self.verbose > 2:
                print_metrics_multitask(met_vld)

            # save the model if it has the best validation performance criterion so far
            if self.criterion is None: continue

            # is current criterion better than previous best?
            curr_crit = np.mean([met_vld[k][self.criterion] for k in self.tgt_modalities])
            if best_crit is None or np.isnan(best_crit):
                is_better = True
            elif self.criterion == 'Loss' and best_crit >= curr_crit:
                is_better = True
            elif self.criterion != 'Loss' and best_crit <= curr_crit:
                is_better = True
            else:
                is_better = False

            # update best criterion
            if is_better:
                best_crit = curr_crit
                best_state_dict = deepcopy(self.net_.state_dict())

                if self._tmp_ckpt_filepath is not None:
                    self.save(self._tmp_ckpt_filepath)

            if self.verbose > 2:
                print('Best {}: {}'.format(self.criterion, best_crit))

            if self.verbose == 1:
                with self._lock if self._lock is not None else suppress():
                    pbr_epoch.update(1)
                    pbr_epoch.refresh()

        if self.verbose == 1:
            with self._lock if self._lock is not None else suppress():
                pbr_epoch.close()

        # restore the model of the best validation performance across all epochs
        if ldr_vld is not None and self.criterion is not None:
            self.net_.load_state_dict(best_state_dict)

        return self

    def predict_logits(self,
        x: list[dict[str, Any]],
        _batch_size: int | None = None,
    ) -> list[dict[str, float]]:
        """
        The input x can be a single sample or a list of samples.
        """
        # input validation
        check_is_fitted(self)

        # for PyTorch computational efficiency
        torch.set_num_threads(1)

        # set model to eval mode
        torch.set_grad_enabled(False)
        self.net_.eval()

        # initialize dataset and dataloader object
        dat = TransformerTestingDataset(x, self.src_modalities)
        ldr = DataLoader(
            dataset = dat,
            batch_size = _batch_size if _batch_size is not None else len(x),
            shuffle = False,
            drop_last = False,
            num_workers = 0,
            collate_fn = TransformerTestingDataset.collate_fn,
        )

        # run model and collect results
        logits: list[dict[str, float]] = []
        for x_batch, _ in ldr:
            # mount data to the proper device
            x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}

            # forward
            output: dict[str, Tensor] = self.net_(x_batch)

            # convert output from dict-of-list to list of dict, then append
            tmp = {k: output[k].tolist() for k in self.tgt_modalities}
            tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))]
            logits += tmp

        return logits

    def predict_proba(self,
        x: list[dict[str, Any]],
        temperature: float = 1.0,
        _batch_size: int | None = None,
    ) -> list[dict[str, float]]:
        ''' ... '''
        logits = self.predict_logits(x, _batch_size)
        return [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]

    def predict(self,
        x: list[dict[str, Any]],
        _batch_size: int | None = None,
    ) -> list[dict[str, int]]:
        ''' ... '''
        logits = self.predict_logits(x, _batch_size)
        return [{k: int(smp[k] > 0.0) for k in self.tgt_modalities} for smp in logits]

    def save(self, filepath: str) -> None:
        ''' ... '''
        check_is_fitted(self)
        state_dict = self.net_.state_dict()

        # attach model hyper parameters
        state_dict['src_modalities'] = self.src_modalities
        state_dict['tgt_modalities'] = self.tgt_modalities
        print('Saving model checkpoint to {} ... '.format(filepath), end='')
        torch.save(state_dict, filepath)
        print('Done.')

    def load(self, filepath: str) -> None:
        ''' ... '''
        # load state_dict
        state_dict = torch.load(filepath, map_location='cpu')

        # load essential parameters
        self.src_modalities: dict[str, dict[str, Any]] = state_dict.pop('src_modalities')
        self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities')

        # initialize model
        self.net_ = nn.CNNResNet3DWithLinearClassifier(
            self.src_modalities,
            self.tgt_modalities,
        )

        # load model parameters
        self.net_.load_state_dict(state_dict)
        self.to(self.device)

    def to(self, device: str) -> Self:
        ''' Mount model to the given device. '''
        self.device = device
        if hasattr(self, 'net_'): self.net_ = self.net_.to(device)
        return self

    @classmethod
    def from_ckpt(cls, filepath: str) -> Self:
        ''' ... '''
        obj = cls(None, None)
        obj.load(filepath)
        return obj

    def _init_net(self):
        """ ... """
        net = nn.CNNResNet3DWithLinearClassifier(
            self.src_modalities,
            self.tgt_modalities,
        ).to(self.device)

        # train on multiple GPUs using torch.nn.DataParallel
        if self._device_ids is not None:
            net = torch.nn.DataParallel(net, device_ids=self._device_ids)

        # initialize model parameters using xavier_uniform
        for p in net.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)

        return net

    def _init_dataloader(self, x, y):
        """ ... """
        # split dataset
        x_trn, x_vld, y_trn, y_vld = train_test_split(
            x, y, test_size = 0.2, random_state = 0,
        )

        # initialize dataset and dataloader
        # dat_trn = TransformerTrainingDataset(
        dat_trn = Transformer2ndOrderBalancedTrainingDataset(
            x_trn, y_trn,
            self.src_modalities,
            self.tgt_modalities,
            dropout_rate = .5,
            # dropout_strategy = 'compensated',
            dropout_strategy = 'permutation',
        )

        dat_vld = TransformerValidationDataset(
            x_vld, y_vld,
            self.src_modalities,
            self.tgt_modalities,
        )

        ldr_trn = DataLoader(
            dataset = dat_trn,
            batch_size = self.batch_size,
            shuffle = True,
            drop_last = False,
            num_workers = self._dataloader_num_workers,
            collate_fn = TransformerTrainingDataset.collate_fn,
            # pin_memory = True
        )

        ldr_vld = DataLoader(
            dataset = dat_vld,
            batch_size = self.batch_size,
            shuffle = False,
            drop_last = False,
            num_workers = self._dataloader_num_workers,
            collate_fn = TransformerValidationDataset.collate_fn,
            # pin_memory = True
        )

        return ldr_trn, ldr_vld

    def _init_optimizer(self):
        """ ... """
        return torch.optim.AdamW(
            self.net_.parameters(),
            lr = self.lr,
            betas = (0.9, 0.98),
            weight_decay = self.weight_decay
        )

    def _init_scheduler(self, optimizer):
        """ ... """
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer = optimizer,
            max_lr = self.lr,
            total_steps = self.num_epochs,
            verbose = (self.verbose > 2)
        )

    def _init_loss_func(self,
        num_per_cls: dict[str, tuple[int, int]],
    ) -> dict[str, Module]:
        """ ... """
        return {k: nn.SigmoidFocalLoss(
            beta = self.beta,
            gamma = self.gamma,
            scale = self.scale,
            num_per_cls = num_per_cls[k],
            reduction = 'none',
        ) for k in self.tgt_modalities}
adrd/model/imaging_model.py
ADDED
@@ -0,0 +1,843 @@
1 |
+
__all__ = ['Transformer']
|
2 |
+
|
3 |
+
import wandb
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import functools
|
7 |
+
import inspect
|
8 |
+
import monai
|
9 |
+
import random
|
10 |
+
|
11 |
+
from tqdm import tqdm
|
12 |
+
from functools import wraps
|
13 |
+
from sklearn.base import BaseEstimator
|
14 |
+
from sklearn.utils.validation import check_is_fitted
|
15 |
+
from sklearn.model_selection import train_test_split
|
16 |
+
from scipy.special import expit
|
17 |
+
from copy import deepcopy
|
18 |
+
from contextlib import suppress
|
19 |
+
from typing import Any, Self, Type
|
20 |
+
Tensor = Type[torch.Tensor]
|
21 |
+
Module = Type[torch.nn.Module]
|
22 |
+
from torch.utils.data import DataLoader
|
23 |
+
from monai.utils.type_conversion import convert_to_tensor
|
24 |
+
from monai.transforms import (
|
25 |
+
LoadImaged,
|
26 |
+
Compose,
|
27 |
+
CropForegroundd,
|
28 |
+
CopyItemsd,
|
29 |
+
SpatialPadd,
|
30 |
+
EnsureChannelFirstd,
|
31 |
+
Spacingd,
|
32 |
+
OneOf,
|
33 |
+
ScaleIntensityRanged,
|
34 |
+
HistogramNormalized,
|
35 |
+
RandSpatialCropSamplesd,
|
36 |
+
RandSpatialCropd,
|
37 |
+
CenterSpatialCropd,
|
38 |
+
RandCoarseDropoutd,
|
39 |
+
RandCoarseShuffled,
|
40 |
+
Resized,
|
41 |
+
)
|
42 |
+
|
43 |
+
# for DistributedDataParallel
|
44 |
+
import torch.distributed as dist
|
45 |
+
import torch.multiprocessing as mp
|
46 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
47 |
+
|
48 |
+
from .. import nn
|
49 |
+
from ..utils.misc import ProgressBar
|
50 |
+
from ..utils.misc import get_metrics_multitask, print_metrics_multitask
|
51 |
+
from ..utils.misc import convert_args_kwargs_to_kwargs
|
52 |
+
|
53 |
+
import warnings
|
54 |
+
warnings.filterwarnings("ignore")
|
55 |
+
|
56 |
+
|
57 |
+
def _manage_ctx_fit(func):
|
58 |
+
''' ... '''
|
59 |
+
@wraps(func)
|
60 |
+
def wrapper(*args, **kwargs):
|
61 |
+
# format arguments
|
62 |
+
kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs)
|
63 |
+
|
64 |
+
if kwargs['self']._device_ids is None:
|
65 |
+
return func(**kwargs)
|
66 |
+
else:
|
67 |
+
# change primary device
|
68 |
+
default_device = kwargs['self'].device
|
69 |
+
kwargs['self'].device = kwargs['self']._device_ids[0]
|
70 |
+
rtn = func(**kwargs)
|
71 |
+
kwargs['self'].to(default_device)
|
72 |
+
return rtn
|
73 |
+
return wrapper
|
74 |
+
|
75 |
+
def collate_handle_corrupted(samples_list, dataset, labels, dtype=torch.half):
|
76 |
+
# print(len(samples_list))
|
77 |
+
orig_len = len(samples_list)
|
78 |
+
# for the loss to be consistent, we drop samples with NaN values in any of their corresponding crops
|
79 |
+
for i, s in enumerate(samples_list):
|
80 |
+
ic(s is None)
|
81 |
+
if s is None:
|
82 |
+
continue
|
83 |
+
samples_list = list(filter(lambda x: x is not None, samples_list))
|
84 |
+
|
85 |
+
if len(samples_list) == 0:
|
86 |
+
ic('recursive call')
|
87 |
+
return collate_handle_corrupted([dataset[random.randint(0, len(dataset)-1)] for _ in range(orig_len)], dataset, labels)
|
88 |
+
|
89 |
+
# collated_images = torch.stack([convert_to_tensor(s["image"]) for s in samples_list])
|
90 |
+
try:
|
91 |
+
if "image" in samples_list[0]:
|
92 |
+
samples_list = [s for s in samples_list if not torch.isnan(s["image"]).any()]
|
93 |
+
# print('samples list: ', len(samples_list))
|
94 |
+
collated_images = torch.stack([convert_to_tensor(s["image"]) for s in samples_list])
|
95 |
+
# print("here1")
|
96 |
+
collated_labels = {k: torch.Tensor([s["label"][k] if s["label"][k] is not None else 0 for s in samples_list]) for k in labels}
|
97 |
+
# print("here2")
|
98 |
+
collated_mask = {k: torch.Tensor([1 if s["label"][k] is not None else 0 for s in samples_list]) for k in labels}
|
99 |
+
# print("here3")
|
100 |
+
return {"image": collated_images,
|
101 |
+
"label": collated_labels,
|
102 |
+
"mask": collated_mask}
|
103 |
+
except:
|
104 |
+
return collate_handle_corrupted([dataset[random.randint(0, len(dataset)-1)] for _ in range(orig_len)], dataset, labels)
|
105 |
+
|
106 |
+
|
107 |
+
|
108 |
+
def get_backend(img_backend):
|
109 |
+
if img_backend == 'C3D':
|
110 |
+
return nn.C3D
|
111 |
+
elif img_backend == 'DenseNet':
|
112 |
+
return nn.DenseNet
|
113 |
+
|
114 |
+
|
115 |
+
class ImagingModel(BaseEstimator):
|
116 |
+
''' ... '''
|
117 |
+
def __init__(self,
|
118 |
+
tgt_modalities: list[str],
|
119 |
+
label_fractions: dict[str, float],
|
120 |
+
num_epochs: int = 32,
|
121 |
+
batch_size: int = 8,
|
122 |
+
batch_size_multiplier: int = 1,
|
123 |
+
lr: float = 1e-2,
|
124 |
+
weight_decay: float = 0.0,
|
125 |
+
beta: float = 0.9999,
|
126 |
+
gamma: float = 2.0,
|
127 |
+
bn_size: int = 4,
|
128 |
+
growth_rate: int = 12,
|
129 |
+
block_config: tuple = (3, 3, 3),
|
130 |
+
compression: float = 0.5,
|
131 |
+
num_init_features: int = 16,
|
132 |
+
drop_rate: float = 0.2,
|
133 |
+
criterion: str | None = None,
|
134 |
+
device: str = 'cpu',
|
135 |
+
cuda_devices: list = [1],
|
136 |
+
ckpt_path: str = '/home/skowshik/ADRD_repo/adrd_tool/dev/ckpt/ckpt.pt',
|
137 |
+
load_from_ckpt: bool = True,
|
138 |
+
save_intermediate_ckpts: bool = False,
|
139 |
+
data_parallel: bool = False,
|
140 |
+
verbose: int = 0,
|
141 |
+
img_backend: str | None = None,
|
142 |
+
label_distribution: dict = {},
|
143 |
+
wandb_ = 1,
|
144 |
+
_device_ids: list | None = None,
|
145 |
+
_dataloader_num_workers: int = 4,
|
146 |
+
_amp_enabled: bool = False,
|
147 |
+
) -> None:
|
148 |
+
''' ... '''
|
149 |
+
# for multiprocessing
|
150 |
+
self._rank = 0
|
151 |
+
self._lock = None
|
152 |
+
|
153 |
+
# positional parameters
|
154 |
+
self.tgt_modalities = tgt_modalities
|
155 |
+
|
156 |
+
# training parameters
|
157 |
+
self.label_fractions = label_fractions
|
158 |
+
self.num_epochs = num_epochs
|
159 |
+
self.batch_size = batch_size
|
160 |
+
self.batch_size_multiplier = batch_size_multiplier
|
161 |
+
self.lr = lr
|
162 |
+
self.weight_decay = weight_decay
|
163 |
+
self.beta = beta
|
164 |
+
self.gamma = gamma
|
165 |
+
self.bn_size = bn_size
|
166 |
+
self.growth_rate = growth_rate
|
167 |
+
self.block_config = block_config
|
168 |
+
self.compression = compression
|
169 |
+
self.num_init_features = num_init_features
|
170 |
+
self.drop_rate = drop_rate
|
171 |
+
self.criterion = criterion
|
172 |
+
self.device = device
|
173 |
+
self.cuda_devices = cuda_devices
|
174 |
+
self.ckpt_path = ckpt_path
|
175 |
+
self.load_from_ckpt = load_from_ckpt
|
176 |
+
self.save_intermediate_ckpts = save_intermediate_ckpts
|
177 |
+
self.data_parallel = data_parallel
|
178 |
+
self.verbose = verbose
|
179 |
+
self.img_backend = img_backend
|
180 |
+
self.label_distribution = label_distribution
|
181 |
+
self.wandb_ = wandb_
|
182 |
+
self._device_ids = _device_ids
|
183 |
+
self._dataloader_num_workers = _dataloader_num_workers
|
184 |
+
self._amp_enabled = _amp_enabled
|
185 |
+
self.scaler = torch.cuda.amp.GradScaler()
|
186 |
+
|
187 |
+
@_manage_ctx_fit
|
188 |
+
def fit(self, trn_list, vld_list, img_train_trans=None, img_vld_trans=None) -> Self:
|
189 |
+
# def fit(self, x, y) -> Self:
|
190 |
+
''' ... '''
|
191 |
+
|
192 |
+
# start a new wandb run to track this script
|
193 |
+
if self.wandb_ == 1:
|
194 |
+
wandb.init(
|
195 |
+
# set the wandb project where this run will be logged
|
196 |
+
project="ADRD_main",
|
197 |
+
|
198 |
+
# track hyperparameters and run metadata
|
199 |
+
config={
|
200 |
+
"Model": "DenseNet",
|
201 |
+
"Loss": 'Focalloss',
|
202 |
+
"EMB": "ALL_EMB",
|
203 |
+
"epochs": 256,
|
204 |
+
}
|
205 |
+
)
|
206 |
+
wandb.run.log_code("/home/skowshik/ADRD_repo/pipeline_v1_main/adrd_tool")
|
207 |
+
else:
|
208 |
+
wandb.init(mode="disabled")
|
209 |
+
# for PyTorch computational efficiency
|
210 |
+
torch.set_num_threads(1)
|
211 |
+
print(self.criterion)
|
212 |
+
|
213 |
+
# initialize neural network
|
214 |
+
self._init_net()
|
215 |
+
|
216 |
+
# for k, info in self.src_modalities.items():
|
217 |
+
# if info['type'] == 'imaging' and self.img_net != 'EMB':
|
218 |
+
# info['shape'] = (1,) + (self.img_size,) * 3
|
219 |
+
# info['img_shape'] = (1,) + (self.img_size,) * 3
|
220 |
+
# print(info['shape'])
|
221 |
+
|
222 |
+
# initialize dataloaders
|
223 |
+
# ldr_trn, ldr_vld = self._init_dataloader(x, y)
|
224 |
+
# ldr_trn, ldr_vld = self._init_dataloader(x_trn, x_vld, y_trn, y_vld)
|
225 |
+
ldr_trn, ldr_vld = self._init_dataloader(trn_list, vld_list, img_train_trans=img_train_trans, img_vld_trans=img_vld_trans)
|
226 |
+
|
227 |
+
# initialize optimizer and scheduler
|
228 |
+
if not self.load_from_ckpt:
|
229 |
+
self.optimizer = self._init_optimizer()
|
230 |
+
self.scheduler = self._init_scheduler(self.optimizer)
|
231 |
+
|
232 |
+
# gradient scaler for AMP
|
233 |
+
if self._amp_enabled:
|
234 |
+
self.scaler = torch.cuda.amp.GradScaler()
|
235 |
+
|
236 |
+
# initialize focal loss function
|
237 |
+
self.loss_fn = {}
|
238 |
+
|
239 |
+
for k in self.tgt_modalities:
|
240 |
+
if self.label_fractions[k] >= 0.3:
|
241 |
+
alpha = -1
|
242 |
+
else:
|
243 |
+
alpha = pow((1 - self.label_fractions[k]), 2)
|
244 |
+
# alpha = -1
|
245 |
+
self.loss_fn[k] = nn.SigmoidFocalLoss(
|
246 |
+
alpha = alpha,
|
247 |
+
gamma = self.gamma,
|
248 |
+
reduction = 'none'
|
249 |
+
)
|
250 |
+
|
251 |
+
# to record the best validation performance criterion
|
252 |
+
if self.criterion is not None:
|
253 |
+
best_crit = None
|
254 |
+
best_crit_AUPR = None
|
255 |
+
|
256 |
+
# progress bar for epoch loops
|
257 |
+
if self.verbose == 1:
|
258 |
+
with self._lock if self._lock is not None else suppress():
|
259 |
+
pbr_epoch = tqdm(
|
260 |
+
desc = 'Rank {:02d}'.format(self._rank),
|
261 |
+
total = self.num_epochs,
|
262 |
+
position = self._rank,
|
263 |
+
ascii = True,
|
264 |
+
leave = False,
|
265 |
+
bar_format='{l_bar}{r_bar}'
|
266 |
+
)
|
267 |
+
|
268 |
+
# Define a hook function to print and store the gradient of a layer
|
269 |
+
def print_and_store_grad(grad, grad_list):
|
270 |
+
grad_list.append(grad)
|
271 |
+
# print(grad)
|
272 |
+
|
273 |
+
# grad_list = []
|
274 |
+
# self.net_.modules_emb_src['img_MRI_T1'].downsample[0].weight.register_hook(lambda grad: print_and_store_grad(grad, grad_list))
|
275 |
+
|
276 |
+
# lambda_coeff = 0.0001
|
277 |
+
# margin_loss = torch.nn.MarginRankingLoss(reduction='sum', margin=0.05)
|
278 |
+
|
279 |
+
# training loop
|
280 |
+
for epoch in range(self.start_epoch, self.num_epochs):
|
281 |
+
met_trn = self.train_one_epoch(ldr_trn, epoch)
|
282 |
+
met_vld = self.validate_one_epoch(ldr_vld, epoch)
|
283 |
+
|
284 |
+
print(self.ckpt_path.split('/')[-1])
|
285 |
+
|
286 |
+
# save the model if it has the best validation performance criterion by far
|
287 |
+
if self.criterion is None: continue
|
288 |
+
|
289 |
+
|
290 |
+
# is current criterion better than previous best?
|
291 |
+
curr_crit = np.mean([met_vld[i][self.criterion] for i in range(len(self.tgt_modalities))])
|
292 |
+
curr_crit_AUPR = np.mean([met_vld[i]["AUC (PR)"] for i in range(len(self.tgt_modalities))])
|
293 |
+
# AUROC
|
294 |
+
if best_crit is None or np.isnan(best_crit):
|
295 |
+
is_better = True
|
296 |
+
elif self.criterion == 'Loss' and best_crit >= curr_crit:
|
297 |
+
is_better = True
|
298 |
+
elif self.criterion != 'Loss' and best_crit <= curr_crit :
|
299 |
+
is_better = True
|
300 |
+
else:
|
301 |
+
is_better = False
|
302 |
+
|
303 |
+
# AUPR
|
304 |
+
if best_crit_AUPR is None or np.isnan(best_crit_AUPR):
|
305 |
+
is_better_AUPR = True
|
306 |
+
elif best_crit_AUPR <= curr_crit_AUPR :
|
307 |
+
is_better_AUPR = True
|
308 |
+
else:
|
309 |
+
is_better_AUPR = False
|
310 |
+
|
311 |
+
# update best criterion
|
312 |
+
if is_better_AUPR:
|
313 |
+
best_crit_AUPR = curr_crit_AUPR
|
314 |
+
if self.save_intermediate_ckpts:
|
315 |
+
print(f"Saving the model to {self.ckpt_path[:-3]}_AUPR.pt...")
|
316 |
+
self.save(self.ckpt_path[:-3]+"_AUPR.pt", epoch)
|
317 |
+
|
318 |
+
if is_better:
|
319 |
+
best_crit = curr_crit
|
320 |
+
best_state_dict = deepcopy(self.net_.state_dict())
|
321 |
+
if self.save_intermediate_ckpts:
|
322 |
+
print(f"Saving the model to {self.ckpt_path}...")
|
323 |
+
self.save(self.ckpt_path, epoch)
|
324 |
+
|
325 |
+
if self.verbose > 2:
|
326 |
+
print('Best {}: {}'.format(self.criterion, best_crit))
|
327 |
+
print('Best {}: {}'.format('AUC (PR)', best_crit_AUPR))
|
328 |
+
|
329 |
+
if self.verbose == 1:
|
330 |
+
with self._lock if self._lock is not None else suppress():
|
331 |
+
pbr_epoch.update(1)
|
332 |
+
pbr_epoch.refresh()
|
333 |
+
|
334 |
+
return self
|
335 |
+
|
336 |
+
def train_one_epoch(self, ldr_trn, epoch):
|
337 |
+
|
338 |
+
# progress bar for batch loops
|
339 |
+
if self.verbose > 1:
|
340 |
+
pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch))
|
341 |
+
|
342 |
+
torch.set_grad_enabled(True)
|
343 |
+
self.net_.train()
|
344 |
+
|
345 |
+
scores_trn, y_true_trn, y_mask_trn = [], [], []
|
346 |
+
losses_trn = [[] for _ in self.tgt_modalities]
|
347 |
+
iters = len(ldr_trn)
|
348 |
+
print(iters)
|
349 |
+
for n_iter, batch_data in enumerate(ldr_trn):
|
350 |
+
# if len(batch_data["image"]) < self.batch_size:
|
351 |
+
# continue
|
352 |
+
|
353 |
+
x_batch = batch_data["image"].to(self.device, non_blocking=True)
|
354 |
+
y_batch = {k: v.to(self.device, non_blocking=True) for k,v in batch_data["label"].items()}
|
355 |
+
y_mask = {k: v.to(self.device, non_blocking=True) for k,v in batch_data["mask"].items()}
|
356 |
+
|
357 |
+
with torch.autocast(
|
358 |
+
device_type = 'cpu' if self.device == 'cpu' else 'cuda',
|
359 |
+
dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
|
360 |
+
enabled = self._amp_enabled,
|
361 |
+
):
|
362 |
+
|
363 |
+
outputs = self.net_(x_batch, shap=False)
|
364 |
+
# print(outputs.shape)
|
365 |
+
# calculate multitask loss
|
366 |
+
loss = 0
|
367 |
+
for i, k in enumerate(self.tgt_modalities):
|
368 |
+
loss_task = self.loss_fn[k](outputs[k], y_batch[k])
|
369 |
+
msk_loss_task = loss_task * y_mask[k]
|
370 |
+
msk_loss_mean = msk_loss_task.sum() / y_mask[k].sum()
|
371 |
+
loss += msk_loss_mean
|
372 |
+
losses_trn[i] += msk_loss_task.detach().cpu().numpy().tolist()
|
373 |
+
|
374 |
+
# backward
|
375 |
+
if self._amp_enabled:
|
376 |
+
self.scaler.scale(loss).backward()
|
377 |
+
else:
|
378 |
+
loss.backward()
|
379 |
+
|
380 |
+
# print(len(grad_list), len(grad_list[-1]))
|
381 |
+
# print(f"Gradient at {n_iter}: {grad_list[-1][0]}")
|
382 |
+
|
383 |
+
# update parameters
|
384 |
+
if n_iter != 0 and n_iter % self.batch_size_multiplier == 0:
|
385 |
+
if self._amp_enabled:
|
386 |
+
self.scaler.step(self.optimizer)
|
387 |
+
self.scaler.update()
|
388 |
+
self.optimizer.zero_grad()
|
389 |
+
else:
|
390 |
+
self.optimizer.step()
|
391 |
+
self.optimizer.zero_grad()
|
392 |
+
# set self.scheduler
|
393 |
+
self.scheduler.step(epoch + n_iter / iters)
|
394 |
+
# print(f"Weight: {self.net_.module.features[0].weight[0]}")
|
395 |
+
|
396 |
+
''' TODO: change array to dictionary later '''
|
397 |
+
outputs = torch.stack(list(outputs.values()), dim=1)
|
398 |
+
y_batch = torch.stack(list(y_batch.values()), dim=1)
|
399 |
+
y_mask = torch.stack(list(y_mask.values()), dim=1)
|
400 |
+
|
401 |
+
# save outputs to evaluate performance later
|
402 |
+
scores_trn.append(outputs.detach().to(torch.float).cpu())
|
403 |
+
y_true_trn.append(y_batch.cpu())
|
404 |
+
y_mask_trn.append(y_mask.cpu())
|
405 |
+
|
406 |
+
# log metrics to wandb
|
407 |
+
|
408 |
+
# update progress bar
|
409 |
+
if self.verbose > 1:
|
410 |
+
batch_size = len(x_batch)
|
411 |
+
pbr_batch.update(batch_size, {})
|
412 |
+
pbr_batch.refresh()
|
413 |
+
|
414 |
+
# clear cuda cache
|
415 |
+
if "cuda" in self.device:
|
416 |
+
torch.cuda.empty_cache()
|
417 |
+
|
418 |
+
# for better tqdm progress bar display
|
419 |
+
if self.verbose > 1:
|
420 |
+
pbr_batch.close()
|
421 |
+
|
422 |
+
# # set self.scheduler
|
423 |
+
# self.scheduler.step()
|
424 |
+
|
425 |
+
# calculate and print training performance metrics
|
426 |
+
scores_trn = torch.cat(scores_trn)
|
427 |
+
y_true_trn = torch.cat(y_true_trn)
|
428 |
+
y_mask_trn = torch.cat(y_mask_trn)
|
429 |
+
y_pred_trn = (scores_trn > 0).to(torch.int)
|
430 |
+
y_prob_trn = torch.sigmoid(scores_trn)
|
431 |
+
met_trn = get_metrics_multitask(
|
432 |
+
y_true_trn.numpy(),
|
433 |
+
y_pred_trn.numpy(),
|
434 |
+
y_prob_trn.numpy(),
|
435 |
+
y_mask_trn.numpy()
|
436 |
+
)
|
437 |
+
|
438 |
+
# add loss to metrics
|
439 |
+
for i in range(len(self.tgt_modalities)):
|
440 |
+
met_trn[i]['Loss'] = np.mean(losses_trn[i])
|
441 |
+
|
442 |
+
wandb.log({f"Train loss {list(self.tgt_modalities)[i]}": met_trn[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch)
|
443 |
+
wandb.log({f"Train Balanced Accuracy {list(self.tgt_modalities)[i]}": met_trn[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch)
|
444 |
+
|
445 |
+
wandb.log({f"Train AUC (ROC) {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch)
|
446 |
+
wandb.log({f"Train AUPR {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch)
|
447 |
+
|
448 |
+
if self.verbose > 2:
|
449 |
+
print_metrics_multitask(met_trn)
|
450 |
+
|
451 |
+
return met_trn
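
The loss above is averaged per task over the observed labels only (y_mask) and the optimizer is stepped every batch_size_multiplier iterations. A minimal self-contained sketch of that masked-loss plus gradient-accumulation pattern, using plain BCEWithLogitsLoss and hypothetical shapes instead of the repo's focal loss (the sketch also rescales the loss by the accumulation factor, which the loop above does not do):

import torch

# hypothetical shapes: logits/labels/mask are (batch,) for one task
loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none')
net = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3)
accum_steps = 4  # plays the role of batch_size_multiplier

for n_iter in range(16):
    x = torch.randn(8, 4)
    y = torch.randint(0, 2, (8,)).float()
    mask = torch.randint(0, 2, (8,)).float()        # 1 = label observed
    logits = net(x).squeeze(-1)
    per_sample = loss_fn(logits, y)                 # (batch,)
    masked = per_sample * mask                      # zero out missing labels
    loss = masked.sum() / mask.sum().clamp(min=1)   # mean over observed labels only
    (loss / accum_steps).backward()                 # scale so accumulated grads match a large batch
    if n_iter != 0 and n_iter % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()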
|
452 |
+
|
453 |
+
# @torch.no_grad()
|
454 |
+
def validate_one_epoch(self, ldr_vld, epoch):
|
455 |
+
# progress bar for validation
|
456 |
+
if self.verbose > 1:
|
457 |
+
pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch))
|
458 |
+
|
459 |
+
# set model to validation mode
|
460 |
+
torch.set_grad_enabled(False)
|
461 |
+
self.net_.eval()
|
462 |
+
|
463 |
+
scores_vld, y_true_vld, y_mask_vld = [], [], []
|
464 |
+
losses_vld = [[] for _ in self.tgt_modalities]
|
465 |
+
for batch_data in ldr_vld:
|
466 |
+
# if len(batch_data["image"]) < self.batch_size:
|
467 |
+
# continue
|
468 |
+
x_batch = batch_data["image"].to(self.device, non_blocking=True)
|
469 |
+
y_batch = {k: v.to(self.device, non_blocking=True) for k,v in batch_data["label"].items()}
|
470 |
+
y_mask = {k: v.to(self.device, non_blocking=True) for k,v in batch_data["mask"].items()}
|
471 |
+
|
472 |
+
# forward
|
473 |
+
with torch.autocast(
|
474 |
+
device_type = 'cpu' if self.device == 'cpu' else 'cuda',
|
475 |
+
dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
|
476 |
+
enabled = self._amp_enabled
|
477 |
+
):
|
478 |
+
|
479 |
+
outputs = self.net_(x_batch, shap=False)
|
480 |
+
|
481 |
+
# calculate multitask loss
|
482 |
+
for i, k in enumerate(self.tgt_modalities):
|
483 |
+
loss_task = self.loss_fn[k](outputs[k], y_batch[k])
|
484 |
+
msk_loss_task = loss_task * y_mask[k]
|
485 |
+
losses_vld[i] += msk_loss_task.detach().cpu().numpy().tolist()
|
486 |
+
|
487 |
+
''' TODO: change array to dictionary later '''
|
488 |
+
outputs = torch.stack(list(outputs.values()), dim=1)
|
489 |
+
y_batch = torch.stack(list(y_batch.values()), dim=1)
|
490 |
+
y_mask = torch.stack(list(y_mask.values()), dim=1)
|
491 |
+
|
492 |
+
# save outputs to evaluate performance later
|
493 |
+
scores_vld.append(outputs.detach().to(torch.float).cpu())
|
494 |
+
y_true_vld.append(y_batch.cpu())
|
495 |
+
y_mask_vld.append(y_mask.cpu())
|
496 |
+
|
497 |
+
# update progress bar
|
498 |
+
if self.verbose > 1:
|
499 |
+
batch_size = len(x_batch)
|
500 |
+
pbr_batch.update(batch_size, {})
|
501 |
+
pbr_batch.refresh()
|
502 |
+
|
503 |
+
# clear cuda cache
|
504 |
+
if "cuda" in self.device:
|
505 |
+
torch.cuda.empty_cache()
|
506 |
+
|
507 |
+
# for better tqdm progress bar display
|
508 |
+
if self.verbose > 1:
|
509 |
+
pbr_batch.close()
|
510 |
+
|
511 |
+
# calculate and print validation performance metrics
|
512 |
+
scores_vld = torch.cat(scores_vld)
|
513 |
+
y_true_vld = torch.cat(y_true_vld)
|
514 |
+
y_mask_vld = torch.cat(y_mask_vld)
|
515 |
+
y_pred_vld = (scores_vld > 0).to(torch.int)
|
516 |
+
y_prob_vld = torch.sigmoid(scores_vld)
|
517 |
+
met_vld = get_metrics_multitask(
|
518 |
+
y_true_vld.numpy(),
|
519 |
+
y_pred_vld.numpy(),
|
520 |
+
y_prob_vld.numpy(),
|
521 |
+
y_mask_vld.numpy()
|
522 |
+
)
|
523 |
+
|
524 |
+
# add loss to metrics
|
525 |
+
for i in range(len(self.tgt_modalities)):
|
526 |
+
met_vld[i]['Loss'] = np.mean(losses_vld[i])
|
527 |
+
|
528 |
+
wandb.log({f"Validation loss {list(self.tgt_modalities)[i]}": met_vld[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch)
|
529 |
+
wandb.log({f"Validation Balanced Accuracy {list(self.tgt_modalities)[i]}": met_vld[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch)
|
530 |
+
|
531 |
+
wandb.log({f"Validation AUC (ROC) {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch)
|
532 |
+
wandb.log({f"Validation AUPR {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch)
|
533 |
+
|
534 |
+
if self.verbose > 2:
|
535 |
+
print_metrics_multitask(met_vld)
|
536 |
+
|
537 |
+
return met_vld
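
Both epochs threshold the raw scores at 0 and pass the stacked scores together with the label mask to get_metrics_multitask. A plausible reading is that each task's metrics are computed only on rows where the mask is set; a small sketch of that idea with sklearn (the function below is illustrative, not the repo's implementation):

import numpy as np
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, average_precision_score

def masked_task_metrics(y_true, y_pred, y_prob, y_mask):
    """Per-task metrics restricted to observed labels.

    All inputs are (n_samples, n_tasks) arrays; y_mask is 1 where the label exists.
    """
    metrics = []
    for i in range(y_true.shape[1]):
        keep = y_mask[:, i].astype(bool)
        metrics.append({
            'Balanced Accuracy': balanced_accuracy_score(y_true[keep, i], y_pred[keep, i]),
            'AUC (ROC)': roc_auc_score(y_true[keep, i], y_prob[keep, i]),
            'AUC (PR)': average_precision_score(y_true[keep, i], y_prob[keep, i]),
        })
    return metrics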
|
538 |
+
|
539 |
+
|
540 |
+
def save(self, filepath: str, epoch: int = 0) -> None:
|
541 |
+
''' ... '''
|
542 |
+
check_is_fitted(self)
|
543 |
+
if self.data_parallel:
|
544 |
+
state_dict = self.net_.module.state_dict()
|
545 |
+
else:
|
546 |
+
state_dict = self.net_.state_dict()
|
547 |
+
|
548 |
+
# attach model hyper parameters
|
549 |
+
state_dict['tgt_modalities'] = self.tgt_modalities
|
550 |
+
state_dict['optimizer'] = self.optimizer
|
551 |
+
state_dict['bn_size'] = self.bn_size
|
552 |
+
state_dict['growth_rate'] = self.growth_rate
|
553 |
+
state_dict['block_config'] = self.block_config
|
554 |
+
state_dict['compression'] = self.compression
|
555 |
+
state_dict['num_init_features'] = self.num_init_features
|
556 |
+
state_dict['drop_rate'] = self.drop_rate
|
557 |
+
state_dict['epoch'] = epoch
|
558 |
+
|
559 |
+
if self.scaler is not None:
|
560 |
+
state_dict['scaler'] = self.scaler.state_dict()
|
561 |
+
if self.label_distribution:
|
562 |
+
state_dict['label_distribution'] = self.label_distribution
|
563 |
+
|
564 |
+
torch.save(state_dict, filepath)
|
565 |
+
|
566 |
+
def load(self, filepath: str, map_location: str = 'cpu', how='latest') -> None:
|
567 |
+
''' ... '''
|
568 |
+
# load state_dict
|
569 |
+
if how == 'latest':
|
570 |
+
if torch.load(filepath)['epoch'] > torch.load(f'{filepath[:-3]}_AUPR.pt')['epoch']:
|
571 |
+
print("Loading model saved using AUROC")
|
572 |
+
state_dict = torch.load(filepath, map_location=map_location)
|
573 |
+
else:
|
574 |
+
print("Loading model saved using AUPR")
|
575 |
+
state_dict = torch.load(f'{filepath[:-3]}_AUPR.pt', map_location=map_location)
|
576 |
+
else:
|
577 |
+
state_dict = torch.load(filepath, map_location=map_location)
|
578 |
+
|
579 |
+
# load data modalities
|
580 |
+
self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities')
|
581 |
+
if 'label_distribution' in state_dict:
|
582 |
+
self.label_distribution: dict[str, dict[int, int]] = state_dict.pop('label_distribution')
|
583 |
+
if 'optimizer' in state_dict:
|
584 |
+
self.optimizer = state_dict.pop('optimizer')
|
585 |
+
if 'bn_size' in state_dict:
|
586 |
+
self.bn_size = state_dict.pop('bn_size')
|
587 |
+
if 'growth_rate' in state_dict:
|
588 |
+
self.growth_rate = state_dict.pop('growth_rate')
|
589 |
+
if 'block_config' in state_dict:
|
590 |
+
self.block_config = state_dict.pop('block_config')
|
591 |
+
if 'compression' in state_dict:
|
592 |
+
self.compression = state_dict.pop('compression')
|
593 |
+
if 'num_init_features' in state_dict:
|
594 |
+
self.num_init_features = state_dict.pop('num_init_features')
|
595 |
+
if 'drop_rate' in state_dict:
|
596 |
+
self.drop_rate = state_dict.pop('drop_rate')
|
597 |
+
if 'epoch' in state_dict:
|
598 |
+
self.start_epoch = state_dict.pop('epoch')
|
599 |
+
print(f'Epoch: {self.start_epoch}')
|
600 |
+
|
601 |
+
# initialize model
|
602 |
+
|
603 |
+
self.net_ = get_backend(self.img_backend)(
|
604 |
+
tgt_modalities = self.tgt_modalities,
|
605 |
+
bn_size = self.bn_size,
|
606 |
+
growth_rate=self.growth_rate,
|
607 |
+
block_config=self.block_config,
|
608 |
+
compression=self.compression,
|
609 |
+
num_init_features=self.num_init_features,
|
610 |
+
drop_rate=self.drop_rate,
|
611 |
+
load_from_ckpt=self.load_from_ckpt
|
612 |
+
)
|
613 |
+
print(self.net_)
|
614 |
+
|
615 |
+
if 'scaler' in state_dict and state_dict['scaler']:
|
616 |
+
self.scaler.load_state_dict(state_dict.pop('scaler'))
|
617 |
+
self.net_.load_state_dict(state_dict)
|
618 |
+
check_is_fitted(self)
|
619 |
+
self.net_.to(self.device)
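
save() stores model hyperparameters in the same dict as the network weights, and load() pops them back out before the weights reach load_state_dict. A stripped-down sketch of that round trip, with hypothetical helper names and a caller-supplied model constructor:

import torch

def save_ckpt(net, filepath, **hparams):
    state_dict = net.state_dict()
    # extra (non-tensor) entries ride along in the same file
    for k, v in hparams.items():
        state_dict[k] = v
    torch.save(state_dict, filepath)

def load_ckpt(make_net, filepath, hparam_keys, map_location='cpu'):
    state_dict = torch.load(filepath, map_location=map_location)
    # pop the hyperparameters first so only real weights reach load_state_dict
    hparams = {k: state_dict.pop(k) for k in hparam_keys if k in state_dict}
    net = make_net(**hparams)
    net.load_state_dict(state_dict)
    return net, hparams

# e.g. save_ckpt(net, 'ckpt.pt', growth_rate=12, drop_rate=0.1) and later
# load_ckpt(MyDenseNet, 'ckpt.pt', ['growth_rate', 'drop_rate']) with MyDenseNet a hypothetical constructor.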
|
620 |
+
|
621 |
+
    def to(self, device: str) -> Self:
        ''' Mount model to the given device. '''
        self.device = device
        # move the fitted network if it exists; the fitted attribute is net_, not 'model'
        if hasattr(self, 'net_'): self.net_ = self.net_.to(device)
        return self
|
626 |
+
|
627 |
+
@classmethod
|
628 |
+
def from_ckpt(cls, filepath: str, device='cpu', img_backend=None, load_from_ckpt=True, how='latest') -> Self:
|
629 |
+
''' ... '''
|
630 |
+
obj = cls(None, None, None,device=device)
|
631 |
+
if device == 'cuda':
|
632 |
+
obj.device = "{}:{}".format(obj.device, str(obj.cuda_devices[0]))
|
633 |
+
print(obj.device)
|
634 |
+
obj.img_backend=img_backend
|
635 |
+
obj.load_from_ckpt = load_from_ckpt
|
636 |
+
obj.load(filepath, map_location=obj.device, how=how)
|
637 |
+
return obj
|
638 |
+
|
639 |
+
def _init_net(self):
|
640 |
+
""" ... """
|
641 |
+
self.start_epoch = 0
|
642 |
+
# set the device for use
|
643 |
+
if self.device == 'cuda':
|
644 |
+
self.device = "{}:{}".format(self.device, str(self.cuda_devices[0]))
|
645 |
+
# self.load(self.ckpt_path, map_location=self.device)
|
646 |
+
# print("Loading model from checkpoint...")
|
647 |
+
# self.load(self.ckpt_path, map_location=self.device)
|
648 |
+
|
649 |
+
if self.load_from_ckpt:
|
650 |
+
try:
|
651 |
+
print("Loading model from checkpoint...")
|
652 |
+
self.load(self.ckpt_path, map_location=self.device)
|
653 |
+
except Exception:
|
654 |
+
print("Cannot load from checkpoint. Initializing new model...")
|
655 |
+
self.load_from_ckpt = False
|
656 |
+
|
657 |
+
if not self.load_from_ckpt:
|
658 |
+
self.net_ = get_backend(self.img_backend)(
|
659 |
+
tgt_modalities = self.tgt_modalities,
|
660 |
+
bn_size = self.bn_size,
|
661 |
+
growth_rate=self.growth_rate,
|
662 |
+
block_config=self.block_config,
|
663 |
+
compression=self.compression,
|
664 |
+
num_init_features=self.num_init_features,
|
665 |
+
drop_rate=self.drop_rate,
|
666 |
+
load_from_ckpt=self.load_from_ckpt
|
667 |
+
)
|
668 |
+
|
669 |
+
# # intialize model parameters using xavier_uniform
|
670 |
+
# for p in self.net_.parameters():
|
671 |
+
# if p.dim() > 1:
|
672 |
+
# torch.nn.init.xavier_uniform_(p)
|
673 |
+
|
674 |
+
self.net_.to(self.device)
|
675 |
+
|
676 |
+
# Initialize the number of GPUs
|
677 |
+
if self.data_parallel and torch.cuda.device_count() > 1:
|
678 |
+
print("Available", torch.cuda.device_count(), "GPUs!")
|
679 |
+
self.net_ = torch.nn.DataParallel(self.net_, device_ids=self.cuda_devices)
|
680 |
+
|
681 |
+
# return net
|
682 |
+
|
683 |
+
def _init_dataloader(self, trn_list, vld_list, img_train_trans=None, img_vld_trans=None):
|
684 |
+
# def _init_dataloader(self, x, y):
|
685 |
+
""" ... """
|
686 |
+
# # split dataset
|
687 |
+
# x_trn, x_vld, y_trn, y_vld = train_test_split(
|
688 |
+
# x, y, test_size = 0.2, random_state = 0,
|
689 |
+
# )
|
690 |
+
|
691 |
+
# # initialize dataset and dataloader
|
692 |
+
# dat_trn = CNNTrainingValidationDataset(
|
693 |
+
# x_trn, y_trn,
|
694 |
+
# self.tgt_modalities,
|
695 |
+
# img_transform=img_train_trans,
|
696 |
+
# )
|
697 |
+
|
698 |
+
# dat_vld = CNNTrainingValidationDataset(
|
699 |
+
# x_vld, y_vld,
|
700 |
+
# self.tgt_modalities,
|
701 |
+
# img_transform=img_vld_trans,
|
702 |
+
# )
|
703 |
+
|
704 |
+
dat_trn = monai.data.Dataset(data=trn_list, transform=img_train_trans)
|
705 |
+
dat_vld = monai.data.Dataset(data=vld_list, transform=img_vld_trans)
|
706 |
+
collate_fn_trn = functools.partial(collate_handle_corrupted, dataset=dat_trn, dtype=torch.FloatTensor, labels=self.tgt_modalities)
|
707 |
+
collate_fn_vld = functools.partial(collate_handle_corrupted, dataset=dat_vld, dtype=torch.FloatTensor, labels=self.tgt_modalities)
|
708 |
+
|
709 |
+
ldr_trn = DataLoader(
|
710 |
+
dataset = dat_trn,
|
711 |
+
batch_size = self.batch_size,
|
712 |
+
shuffle = True,
|
713 |
+
drop_last = False,
|
714 |
+
num_workers = self._dataloader_num_workers,
|
715 |
+
collate_fn = collate_fn_trn,
|
716 |
+
# pin_memory = True
|
717 |
+
)
|
718 |
+
|
719 |
+
ldr_vld = DataLoader(
|
720 |
+
dataset = dat_vld,
|
721 |
+
batch_size = self.batch_size,
|
722 |
+
shuffle = False,
|
723 |
+
drop_last = False,
|
724 |
+
num_workers = self._dataloader_num_workers,
|
725 |
+
collate_fn = collate_fn_vld,
|
726 |
+
# pin_memory = True
|
727 |
+
)
|
728 |
+
|
729 |
+
return ldr_trn, ldr_vld
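
collate_handle_corrupted is bound with functools.partial before being handed to the DataLoader; its internals are not shown in this file, but a common pattern for this kind of collate function is to drop samples that failed to load and batch the rest. A hypothetical minimal version, not the repo's actual implementation:

import functools
from torch.utils.data.dataloader import default_collate

def collate_skip_none(batch, dataset=None, dtype=None, labels=None):
    """Filter out samples that failed to load (None), then batch the rest."""
    batch = [smp for smp in batch if smp is not None]
    if len(batch) == 0:
        return {}              # caller is expected to skip empty batches
    return default_collate(batch)

# binding extra arguments the same way the loader setup above does
# collate_fn = functools.partial(collate_skip_none, dataset=dat_trn, labels=tgt_modalities)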
|
730 |
+
|
731 |
+
def _init_optimizer(self):
|
732 |
+
""" ... """
|
733 |
+
params = list(self.net_.parameters())
|
734 |
+
# for p in params:
|
735 |
+
# print(p.requires_grad)
|
736 |
+
return torch.optim.AdamW(
|
737 |
+
params,
|
738 |
+
lr = self.lr,
|
739 |
+
betas = (0.9, 0.98),
|
740 |
+
weight_decay = self.weight_decay
|
741 |
+
)
|
742 |
+
|
743 |
+
def _init_scheduler(self, optimizer):
|
744 |
+
""" ... """
|
745 |
+
# return torch.optim.lr_scheduler.OneCycleLR(
|
746 |
+
# optimizer = optimizer,
|
747 |
+
# max_lr = self.lr,
|
748 |
+
# total_steps = self.num_epochs,
|
749 |
+
# verbose = (self.verbose > 2)
|
750 |
+
# )
|
751 |
+
|
752 |
+
# return torch.optim.lr_scheduler.CosineAnnealingLR(
|
753 |
+
# optimizer=optimizer,
|
754 |
+
# T_max=64,
|
755 |
+
# verbose=(self.verbose > 2)
|
756 |
+
# )
|
757 |
+
|
758 |
+
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
759 |
+
optimizer=optimizer,
|
760 |
+
T_0=64,
|
761 |
+
T_mult=2,
|
762 |
+
eta_min = 0,
|
763 |
+
verbose=(self.verbose > 2)
|
764 |
+
)
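
CosineAnnealingWarmRestarts is stepped inside the batch loop with a fractional epoch (epoch + n_iter / iters), which is the scheduler's documented way of annealing within an epoch. A small sketch of the schedule in isolation:

import torch

net = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=64, T_mult=2, eta_min=0,
)

iters_per_epoch = 10
for epoch in range(3):
    for n_iter in range(iters_per_epoch):
        optimizer.step()                                   # normally after loss.backward()
        scheduler.step(epoch + n_iter / iters_per_epoch)   # fractional epoch -> smooth annealing
    print(epoch, optimizer.param_groups[0]['lr'])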
|
765 |
+
|
766 |
+
def _init_loss_func(self,
|
767 |
+
num_per_cls: dict[str, tuple[int, int]],
|
768 |
+
) -> dict[str, Module]:
|
769 |
+
""" ... """
|
770 |
+
return {k: nn.SigmoidFocalLossBeta(
|
771 |
+
beta = self.beta,
|
772 |
+
gamma = self.gamma,
|
773 |
+
num_per_cls = num_per_cls[k],
|
774 |
+
reduction = 'none',
|
775 |
+
) for k in self.tgt_modalities}
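
SigmoidFocalLossBeta takes beta together with the per-class counts, which matches the "effective number of samples" class-balancing scheme (Cui et al., 2019); the actual loss lives in adrd/nn/focal_loss.py and is not reproduced here. A sketch of how such weights are commonly derived, under that assumption:

import numpy as np

def class_balanced_weights(num_per_cls, beta=0.9999):
    """Weights proportional to 1 / E_n with E_n = (1 - beta**n) / (1 - beta)."""
    num_per_cls = np.asarray(num_per_cls, dtype=float)    # e.g. (n_negative, n_positive)
    effective_num = (1.0 - np.power(beta, num_per_cls)) / (1.0 - beta)
    weights = 1.0 / effective_num
    return weights / weights.sum() * len(num_per_cls)     # normalize so weights sum to n_classes

# example: 900 negatives, 100 positives
print(class_balanced_weights((900, 100)))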
|
776 |
+
|
777 |
+
def _proc_fit(self):
|
778 |
+
""" ... """
|
779 |
+
|
780 |
+
def _init_test_dataloader(self, batch_size, tst_list, img_tst_trans=None):
|
781 |
+
# input validation
|
782 |
+
check_is_fitted(self)
|
783 |
+
print(self.device)
|
784 |
+
|
785 |
+
# for PyTorch computational efficiency
|
786 |
+
torch.set_num_threads(1)
|
787 |
+
|
788 |
+
# set model to eval mode
|
789 |
+
torch.set_grad_enabled(False)
|
790 |
+
self.net_.eval()
|
791 |
+
|
792 |
+
dat_tst = monai.data.Dataset(data=tst_list, transform=img_tst_trans)
|
793 |
+
collate_fn_tst = functools.partial(collate_handle_corrupted, dataset=dat_tst, dtype=torch.FloatTensor, labels=self.tgt_modalities)
|
794 |
+
# print(collate_fn_tst)
|
795 |
+
|
796 |
+
ldr_tst = DataLoader(
|
797 |
+
dataset = dat_tst,
|
798 |
+
batch_size = batch_size,
|
799 |
+
shuffle = False,
|
800 |
+
drop_last = False,
|
801 |
+
num_workers = self._dataloader_num_workers,
|
802 |
+
collate_fn = collate_fn_tst,
|
803 |
+
# pin_memory = True
|
804 |
+
)
|
805 |
+
return ldr_tst
|
806 |
+
|
807 |
+
|
808 |
+
def predict_logits(self,
|
809 |
+
ldr_tst: Any | None = None,
|
810 |
+
) -> list[dict[str, float]]:
|
811 |
+
|
812 |
+
# run model and collect results
|
813 |
+
logits: list[dict[str, float]] = []
|
814 |
+
for batch_data in tqdm(ldr_tst):
|
815 |
+
# print(batch_data["image"])
|
816 |
+
if len(batch_data) == 0:
|
817 |
+
continue
|
818 |
+
x_batch = batch_data["image"].to(self.device, non_blocking=True)
|
819 |
+
outputs = self.net_(x_batch, shap=False)
|
820 |
+
|
821 |
+
# convert output from dict-of-list to list of dict, then append
|
822 |
+
tmp = {k: outputs[k].tolist() for k in self.tgt_modalities}
|
823 |
+
tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))]
|
824 |
+
logits += tmp
|
825 |
+
|
826 |
+
return logits
|
827 |
+
|
828 |
+
def predict_proba(self,
|
829 |
+
ldr_tst: Any | None = None,
|
830 |
+
temperature: float = 1.0,
|
831 |
+
) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
|
832 |
+
''' ... '''
|
833 |
+
logits = self.predict_logits(ldr_tst)
|
834 |
+
print("got logits")
|
835 |
+
return logits, [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]
|
836 |
+
|
837 |
+
def predict(self,
|
838 |
+
ldr_tst: Any | None = None,
|
839 |
+
) -> tuple[list[dict[str, float]], list[dict[str, float]], list[dict[str, int]]]:
|
840 |
+
''' ... '''
|
841 |
+
logits, proba = self.predict_proba(ldr_tst)
|
842 |
+
print("got proba")
|
843 |
+
return logits, proba, [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba]
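
predict_proba divides each logit by a temperature before the sigmoid, and predict thresholds the resulting probability at 0.5; with temperature 1.0 this is equivalent to thresholding the raw logit at 0. A tiny sketch of that conversion with hypothetical task names:

from scipy.special import expit  # numerically stable sigmoid

def to_probabilities(logits, temperature=1.0):
    """logits: list of dicts mapping task name -> raw score."""
    return [{k: expit(v / temperature) for k, v in smp.items()} for smp in logits]

def to_labels(proba, threshold=0.5):
    return [{k: int(p > threshold) for k, p in smp.items()} for smp in proba]

proba = to_probabilities([{'AD': 2.3, 'PD': -0.7}], temperature=1.0)
print(proba, to_labels(proba))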
|
adrd/model/train_resnet.py
ADDED
@@ -0,0 +1,484 @@
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import tqdm
|
4 |
+
from sklearn.base import BaseEstimator
|
5 |
+
from sklearn.utils.validation import check_is_fitted
|
6 |
+
from sklearn.model_selection import train_test_split
|
7 |
+
from scipy.special import expit
|
8 |
+
from copy import deepcopy
|
9 |
+
from contextlib import suppress
|
10 |
+
from typing import Any, Self
|
11 |
+
from icecream import ic
|
12 |
+
|
13 |
+
from .. import nn
|
14 |
+
from ..utils import TransformerTrainingDataset
|
15 |
+
from ..utils import TransformerValidationDataset
|
16 |
+
from ..utils import MissingMasker
|
17 |
+
from ..utils import ConstantImputer
|
18 |
+
from ..utils import Formatter
|
19 |
+
from ..utils.misc import ProgressBar
|
20 |
+
from ..utils.misc import get_metrics_multitask, print_metrics_multitask
|
21 |
+
|
22 |
+
|
23 |
+
class TrainResNet(BaseEstimator):
|
24 |
+
''' ... '''
|
25 |
+
def __init__(self,
|
26 |
+
src_modalities: dict[str, dict[str, Any]],
|
27 |
+
tgt_modalities: dict[str, dict[str, Any]],
|
28 |
+
label_fractions: dict[str, float],
|
29 |
+
num_epochs: int = 32,
|
30 |
+
batch_size: int = 8,
|
31 |
+
lr: float = 1e-2,
|
32 |
+
weight_decay: float = 0.0,
|
33 |
+
gamma: float = 0.0,
|
34 |
+
criterion: str | None = None,
|
35 |
+
device: str = 'cpu',
|
36 |
+
cuda_devices: list = [1,2],
|
37 |
+
mri_feature: str = 'img_MRI_T1',
|
38 |
+
ckpt_path: str = '/home/skowshik/ADRD_repo/adrd_tool/adrd/dev/ckpt/ckpt.pt',
|
39 |
+
load_from_ckpt: bool = True,
|
40 |
+
save_intermediate_ckpts: bool = False,
|
41 |
+
data_parallel: bool = False,
|
42 |
+
verbose: int = 0,
|
43 |
+
):
|
44 |
+
''' ... '''
|
45 |
+
# for multiprocessing
|
46 |
+
self._rank = 0
|
47 |
+
self._lock = None
|
48 |
+
|
49 |
+
# positional parameters
|
50 |
+
self.src_modalities = src_modalities
|
51 |
+
self.tgt_modalities = tgt_modalities
|
52 |
+
|
53 |
+
# training parameters
|
54 |
+
self.label_fractions = label_fractions
|
55 |
+
self.num_epochs = num_epochs
|
56 |
+
self.batch_size = batch_size
|
57 |
+
self.lr = lr
|
58 |
+
self.weight_decay = weight_decay
|
59 |
+
self.gamma = gamma
|
60 |
+
self.criterion = criterion
|
61 |
+
self.device = device
|
62 |
+
self.cuda_devices = cuda_devices
|
63 |
+
self.mri_feature = mri_feature
|
64 |
+
self.ckpt_path = ckpt_path
|
65 |
+
self.load_from_ckpt = load_from_ckpt
|
66 |
+
self.save_intermediate_ckpts = save_intermediate_ckpts
|
67 |
+
self.data_parallel = data_parallel
|
68 |
+
self.verbose = verbose
|
69 |
+
|
70 |
+
def fit(self, x, y):
|
71 |
+
''' ... '''
|
72 |
+
# for PyTorch computational efficiency
|
73 |
+
torch.set_num_threads(1)
|
74 |
+
|
75 |
+
# set the device for use
|
76 |
+
if self.device == 'cuda':
|
77 |
+
self.device = "{}:{}".format(self.device, str(self.cuda_devices[0]))
|
78 |
+
|
79 |
+
# initialize model
|
80 |
+
if self.load_from_ckpt:
|
81 |
+
try:
|
82 |
+
print("Loading model from checkpoint...")
|
83 |
+
self.load(self.ckpt_path, map_location=self.device)
|
84 |
+
except Exception:
|
85 |
+
print("Cannot load from checkpoint. Initializing new model...")
|
86 |
+
self.load_from_ckpt = False
|
87 |
+
|
88 |
+
# initialize model
|
89 |
+
if not self.load_from_ckpt:
|
90 |
+
self.net_ = nn.ResNetModel(
|
91 |
+
self.tgt_modalities,
|
92 |
+
mri_feature = self.mri_feature
|
93 |
+
)
|
94 |
+
# initialize model parameters using xavier_uniform
|
95 |
+
for p in self.net_.parameters():
|
96 |
+
if p.dim() > 1:
|
97 |
+
torch.nn.init.xavier_uniform_(p)
|
98 |
+
|
99 |
+
self.net_.to(self.device)
|
100 |
+
|
101 |
+
# Initialize the number of GPUs
|
102 |
+
if self.data_parallel and torch.cuda.device_count() > 1:
|
103 |
+
print("Available", torch.cuda.device_count(), "GPUs!")
|
104 |
+
self.net_ = torch.nn.DataParallel(self.net_, device_ids=self.cuda_devices)
|
105 |
+
|
106 |
+
|
107 |
+
# split dataset
|
108 |
+
x_trn, x_vld, y_trn, y_vld = train_test_split(
|
109 |
+
x, y, test_size = 0.2, random_state = 0,
|
110 |
+
)
|
111 |
+
|
112 |
+
# initialize dataset and dataloader
|
113 |
+
dat_trn = TransformerTrainingDataset(
|
114 |
+
x_trn, y_trn,
|
115 |
+
self.src_modalities,
|
116 |
+
self.tgt_modalities,
|
117 |
+
dropout_rate = .5,
|
118 |
+
dropout_strategy = 'compensated',
|
119 |
+
mri_feature = self.mri_feature,
|
120 |
+
)
|
121 |
+
|
122 |
+
dat_vld = TransformerValidationDataset(
|
123 |
+
x_vld, y_vld,
|
124 |
+
self.src_modalities,
|
125 |
+
self.tgt_modalities,
|
126 |
+
mri_feature = self.mri_feature,
|
127 |
+
)
|
128 |
+
|
129 |
+
# ic(dat_trn[0])
|
130 |
+
|
131 |
+
ldr_trn = torch.utils.data.DataLoader(
|
132 |
+
dat_trn,
|
133 |
+
batch_size = self.batch_size,
|
134 |
+
shuffle = True,
|
135 |
+
drop_last = False,
|
136 |
+
num_workers = 0,
|
137 |
+
collate_fn = TransformerTrainingDataset.collate_fn,
|
138 |
+
# pin_memory = True
|
139 |
+
)
|
140 |
+
|
141 |
+
ldr_vld = torch.utils.data.DataLoader(
|
142 |
+
dat_vld,
|
143 |
+
batch_size = self.batch_size,
|
144 |
+
shuffle = False,
|
145 |
+
drop_last = False,
|
146 |
+
num_workers = 0,
|
147 |
+
collate_fn = TransformerTrainingDataset.collate_fn,
|
148 |
+
# pin_memory = True
|
149 |
+
)
|
150 |
+
|
151 |
+
# initialize optimizer
|
152 |
+
optimizer = torch.optim.AdamW(
|
153 |
+
self.net_.parameters(),
|
154 |
+
lr = self.lr,
|
155 |
+
betas = (0.9, 0.98),
|
156 |
+
weight_decay = self.weight_decay
|
157 |
+
)
|
158 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=64, verbose=(self.verbose > 2))
|
159 |
+
|
160 |
+
# initialize loss function (binary cross entropy)
|
161 |
+
loss_fn = {}
|
162 |
+
|
163 |
+
for k in self.tgt_modalities:
|
164 |
+
alpha = pow((1 - self.label_fractions[k]), self.gamma)
|
165 |
+
# if alpha < 0.5:
|
166 |
+
# alpha = -1
|
167 |
+
loss_fn[k] = nn.SigmoidFocalLoss(
|
168 |
+
alpha = alpha,
|
169 |
+
gamma = self.gamma,
|
170 |
+
reduction = 'none'
|
171 |
+
)
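
nn.SigmoidFocalLoss is constructed here with alpha derived from the label fraction and the shared gamma; its implementation lives in adrd/nn/focal_loss.py. For reference, a minimal stand-alone version of the standard sigmoid focal loss (Lin et al., 2017), which may differ in detail from the repo's class:

import torch
import torch.nn.functional as F

def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction='none'):
    """Standard binary focal loss on raw logits; targets are 0/1 floats."""
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p_t = p * targets + (1 - p) * (1 - targets)        # probability of the true class
    loss = ce * (1 - p_t) ** gamma                     # down-weight easy examples
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss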
|
172 |
+
|
173 |
+
# to record the best validation performance criterion
|
174 |
+
if self.criterion is not None:
|
175 |
+
best_crit = None
|
176 |
+
|
177 |
+
# progress bar for epoch loops
|
178 |
+
if self.verbose == 1:
|
179 |
+
with self._lock if self._lock is not None else suppress():
|
180 |
+
pbr_epoch = tqdm.tqdm(
|
181 |
+
desc = 'Rank {:02d}'.format(self._rank),
|
182 |
+
total = self.num_epochs,
|
183 |
+
position = self._rank,
|
184 |
+
ascii = True,
|
185 |
+
leave = False,
|
186 |
+
bar_format='{l_bar}{r_bar}'
|
187 |
+
)
|
188 |
+
|
189 |
+
# Define a hook function to print and store the gradient of a layer
|
190 |
+
def print_and_store_grad(grad, grad_list):
|
191 |
+
grad_list.append(grad)
|
192 |
+
# print(grad)
|
193 |
+
|
194 |
+
# grad_list = []
|
195 |
+
# self.net_.module.img_net_.featurizer.down_tr64.ops[0].conv1.weight.register_hook(lambda grad: print_and_store_grad(grad, grad_list))
|
196 |
+
# self.net_.module.modules_emb_src['gender'].weight.register_hook(lambda grad: print_and_store_grad(grad, grad_list))
|
197 |
+
|
198 |
+
|
199 |
+
# training loop
|
200 |
+
for epoch in range(self.num_epochs):
|
201 |
+
# progress bar for batch loops
|
202 |
+
if self.verbose > 1:
|
203 |
+
pbr_batch = ProgressBar(len(dat_trn), 'Epoch {:03d} (TRN)'.format(epoch))
|
204 |
+
|
205 |
+
# set model to train mode
|
206 |
+
torch.set_grad_enabled(True)
|
207 |
+
self.net_.train()
|
208 |
+
|
209 |
+
scores_trn, y_true_trn = [], []
|
210 |
+
losses_trn = [[] for _ in self.tgt_modalities]
|
211 |
+
for x_batch, y_batch, mask in ldr_trn:
|
212 |
+
|
213 |
+
# mount data to the proper device
|
214 |
+
x_batch = {k: x_batch[k].to(self.device) for k in x_batch}
|
215 |
+
y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch}
|
216 |
+
|
217 |
+
# forward
|
218 |
+
outputs = self.net_(x_batch)
|
219 |
+
|
220 |
+
# calculate multitask loss
|
221 |
+
loss = 0
|
222 |
+
for i, k in enumerate(self.tgt_modalities):
|
223 |
+
loss_task = loss_fn[k](outputs[k], y_batch[k])
|
224 |
+
loss += loss_task.mean()
|
225 |
+
losses_trn[i] += loss_task.detach().cpu().numpy().tolist()
|
226 |
+
|
227 |
+
# backward
|
228 |
+
optimizer.zero_grad(set_to_none=True)
|
229 |
+
loss.backward()
|
230 |
+
optimizer.step()
|
231 |
+
|
232 |
+
''' TODO: change array to dictionary later '''
|
233 |
+
outputs = torch.stack(list(outputs.values()), dim=1)
|
234 |
+
y_batch = torch.stack(list(y_batch.values()), dim=1)
|
235 |
+
|
236 |
+
# save outputs to evaluate performance later
|
237 |
+
scores_trn.append(outputs.detach().to(torch.float).cpu())
|
238 |
+
y_true_trn.append(y_batch.cpu())
|
239 |
+
|
240 |
+
# update progress bar
|
241 |
+
if self.verbose > 1:
|
242 |
+
batch_size = len(next(iter(x_batch.values())))
|
243 |
+
pbr_batch.update(batch_size, {})
|
244 |
+
pbr_batch.refresh()
|
245 |
+
|
246 |
+
# clear cuda cache
|
247 |
+
if "cuda" in self.device:
|
248 |
+
torch.cuda.empty_cache()
|
249 |
+
|
250 |
+
# for better tqdm progress bar display
|
251 |
+
if self.verbose > 1:
|
252 |
+
pbr_batch.close()
|
253 |
+
|
254 |
+
# set scheduler
|
255 |
+
scheduler.step()
|
256 |
+
|
257 |
+
# calculate and print training performance metrics
|
258 |
+
scores_trn = torch.cat(scores_trn)
|
259 |
+
y_true_trn = torch.cat(y_true_trn)
|
260 |
+
y_pred_trn = (scores_trn > 0).to(torch.int)
|
261 |
+
y_prob_trn = torch.sigmoid(scores_trn)
|
262 |
+
met_trn = get_metrics_multitask(
|
263 |
+
y_true_trn.numpy(),
|
264 |
+
y_pred_trn.numpy(),
|
265 |
+
y_prob_trn.numpy()
|
266 |
+
)
|
267 |
+
|
268 |
+
# add loss to metrics
|
269 |
+
for i in range(len(self.tgt_modalities)):
|
270 |
+
met_trn[i]['Loss'] = np.mean(losses_trn[i])
|
271 |
+
|
272 |
+
if self.verbose > 2:
|
273 |
+
print_metrics_multitask(met_trn)
|
274 |
+
|
275 |
+
# progress bar for validation
|
276 |
+
if self.verbose > 1:
|
277 |
+
pbr_batch = ProgressBar(len(dat_vld), 'Epoch {:03d} (VLD)'.format(epoch))
|
278 |
+
|
279 |
+
# set model to validation mode
|
280 |
+
torch.set_grad_enabled(False)
|
281 |
+
self.net_.eval()
|
282 |
+
|
283 |
+
scores_vld, y_true_vld = [], []
|
284 |
+
losses_vld = [[] for _ in self.tgt_modalities]
|
285 |
+
for x_batch, y_batch, mask in ldr_vld:
|
286 |
+
# mount data to the proper device
|
287 |
+
x_batch = {k: x_batch[k].to(self.device) for k in x_batch}
|
288 |
+
y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch}
|
289 |
+
|
290 |
+
# forward
|
291 |
+
outputs = self.net_(x_batch)
|
292 |
+
|
293 |
+
# calculate multitask loss
|
294 |
+
for i, k in enumerate(self.tgt_modalities):
|
295 |
+
loss_task = loss_fn[k](outputs[k], y_batch[k])
|
296 |
+
losses_vld[i] += loss_task.detach().cpu().numpy().tolist()
|
297 |
+
|
298 |
+
''' TODO: change array to dictionary later '''
|
299 |
+
outputs = torch.stack(list(outputs.values()), dim=1)
|
300 |
+
y_batch = torch.stack(list(y_batch.values()), dim=1)
|
301 |
+
|
302 |
+
# save outputs to evaluate performance later
|
303 |
+
scores_vld.append(outputs.detach().to(torch.float).cpu())
|
304 |
+
y_true_vld.append(y_batch.cpu())
|
305 |
+
|
306 |
+
# update progress bar
|
307 |
+
if self.verbose > 1:
|
308 |
+
batch_size = len(next(iter(x_batch.values())))
|
309 |
+
pbr_batch.update(batch_size, {})
|
310 |
+
pbr_batch.refresh()
|
311 |
+
|
312 |
+
# clear cuda cache
|
313 |
+
if "cuda" in self.device:
|
314 |
+
torch.cuda.empty_cache()
|
315 |
+
|
316 |
+
# for better tqdm progress bar display
|
317 |
+
if self.verbose > 1:
|
318 |
+
pbr_batch.close()
|
319 |
+
|
320 |
+
# calculate and print validation performance metrics
|
321 |
+
scores_vld = torch.cat(scores_vld)
|
322 |
+
y_true_vld = torch.cat(y_true_vld)
|
323 |
+
y_pred_vld = (scores_vld > 0).to(torch.int)
|
324 |
+
y_prob_vld = torch.sigmoid(scores_vld)
|
325 |
+
met_vld = get_metrics_multitask(
|
326 |
+
y_true_vld.numpy(),
|
327 |
+
y_pred_vld.numpy(),
|
328 |
+
y_prob_vld.numpy()
|
329 |
+
)
|
330 |
+
|
331 |
+
# add loss to metrics
|
332 |
+
for i in range(len(self.tgt_modalities)):
|
333 |
+
met_vld[i]['Loss'] = np.mean(losses_vld[i])
|
334 |
+
|
335 |
+
if self.verbose > 2:
|
336 |
+
print_metrics_multitask(met_vld)
|
337 |
+
|
338 |
+
# save the model if it has the best validation performance criterion by far
|
339 |
+
if self.criterion is None: continue
|
340 |
+
|
341 |
+
# is current criterion better than previous best?
|
342 |
+
curr_crit = np.mean([met_vld[i][self.criterion] for i in range(len(self.tgt_modalities))])
|
343 |
+
if best_crit is None or np.isnan(best_crit):
|
344 |
+
is_better = True
|
345 |
+
elif self.criterion == 'Loss' and best_crit >= curr_crit:
|
346 |
+
is_better = True
|
347 |
+
elif self.criterion != 'Loss' and best_crit <= curr_crit:
|
348 |
+
is_better = True
|
349 |
+
else:
|
350 |
+
is_better = False
|
351 |
+
|
352 |
+
# update best criterion
|
353 |
+
if is_better:
|
354 |
+
best_crit = curr_crit
|
355 |
+
best_state_dict = deepcopy(self.net_.state_dict())
|
356 |
+
if self.save_intermediate_ckpts:
|
357 |
+
print("Saving the model...")
|
358 |
+
self.save(self.ckpt_path)
|
359 |
+
|
360 |
+
if self.verbose > 2:
|
361 |
+
print('Best {}: {}'.format(self.criterion, best_crit))
|
362 |
+
|
363 |
+
if self.verbose == 1:
|
364 |
+
with self._lock if self._lock is not None else suppress():
|
365 |
+
pbr_epoch.update(1)
|
366 |
+
pbr_epoch.refresh()
|
367 |
+
|
368 |
+
if self.verbose == 1:
|
369 |
+
with self._lock if self._lock is not None else suppress():
|
370 |
+
pbr_epoch.close()
|
371 |
+
|
372 |
+
# restore the model with the best validation performance across all epochs
|
373 |
+
if ldr_vld is not None and self.criterion is not None:
|
374 |
+
self.net_.load_state_dict(best_state_dict)
|
375 |
+
|
376 |
+
return self
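
The epoch loop keeps a deepcopy of the weights whenever the chosen validation criterion improves (lower is better only for 'Loss') and restores it at the end of training. A condensed sketch of that selection logic, with a hypothetical evaluate() callable standing in for the validation pass:

from copy import deepcopy
import numpy as np

def fit_with_model_selection(net, train_one_epoch, evaluate, num_epochs, criterion='AUC (ROC)'):
    """evaluate() is assumed to return a dict of metric name -> value on the validation set."""
    best_crit, best_state_dict = None, None
    for epoch in range(num_epochs):
        train_one_epoch(epoch)
        curr_crit = evaluate()[criterion]
        lower_is_better = (criterion == 'Loss')
        if (best_crit is None or np.isnan(best_crit)
                or (lower_is_better and curr_crit <= best_crit)
                or (not lower_is_better and curr_crit >= best_crit)):
            best_crit = curr_crit
            best_state_dict = deepcopy(net.state_dict())
    if best_state_dict is not None:
        net.load_state_dict(best_state_dict)   # restore the best epoch
    return net, best_crit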
|
377 |
+
|
378 |
+
def predict_logits(self,
|
379 |
+
x: list[dict[str, Any]],
|
380 |
+
) -> list[dict[str, float]]:
|
381 |
+
'''
|
382 |
+
The input x can be a single sample or a list of samples.
|
383 |
+
'''
|
384 |
+
# input validation
|
385 |
+
check_is_fitted(self)
|
386 |
+
|
387 |
+
# for PyTorch computational efficiency
|
388 |
+
torch.set_num_threads(1)
|
389 |
+
|
390 |
+
# set model to eval mode
|
391 |
+
torch.set_grad_enabled(False)
|
392 |
+
self.net_.eval()
|
393 |
+
|
394 |
+
# number of samples to evaluate
|
395 |
+
n_samples = len(x)
|
396 |
+
|
397 |
+
# format x
|
398 |
+
fmt = Formatter(self.src_modalities)
|
399 |
+
x = [fmt(smp) for smp in x]
|
400 |
+
|
401 |
+
# generate missing mask (BEFORE IMPUTATION)
|
402 |
+
msk = MissingMasker(self.src_modalities)
|
403 |
+
mask = [msk(smp) for smp in x]
|
404 |
+
|
405 |
+
# reformat x and then impute by 0s
|
406 |
+
imp = ConstantImputer(self.src_modalities)
|
407 |
+
x = [imp(smp) for smp in x]
|
408 |
+
|
409 |
+
# convert list-of-dict to dict-of-list
|
410 |
+
x = {k: [smp[k] for smp in x] for k in self.src_modalities}
|
411 |
+
mask = {k: [smp[k] for smp in mask] for k in self.src_modalities}
|
412 |
+
|
413 |
+
# to tensor
|
414 |
+
x = {k: torch.as_tensor(np.array(v)).to(self.device) for k, v in x.items()}
|
415 |
+
mask = {k: torch.as_tensor(np.array(v)).to(self.device) for k, v in mask.items()}
|
416 |
+
|
417 |
+
# calculate logits
|
418 |
+
logits = self.net_(x)
|
419 |
+
|
420 |
+
# convert dict-of-list to list-of-dict
|
421 |
+
logits = {k: logits[k].tolist() for k in self.tgt_modalities}
|
422 |
+
logits = [{k: logits[k][i] for k in self.tgt_modalities} for i in range(n_samples)]
|
423 |
+
|
424 |
+
return logits
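
predict_logits formats the raw samples, records which features are missing before imputation, and then fills the gaps with constants. Formatter, MissingMasker and ConstantImputer live in adrd/utils; the helpers below are hypothetical stand-ins that only illustrate the mask-then-impute order:

def missing_mask(sample, feature_names):
    """1 where the feature is absent or None, 0 where it is observed."""
    return {k: int(k not in sample or sample[k] is None) for k in feature_names}

def constant_impute(sample, feature_names, fill_value=0):
    return {k: sample.get(k, fill_value) if sample.get(k) is not None else fill_value
            for k in feature_names}

features = ['age', 'mmse']
smp = {'age': 72}                      # 'mmse' is missing
mask = missing_mask(smp, features)     # computed BEFORE imputation, as in the code above
x = constant_impute(smp, features)
print(mask, x)                         # {'age': 0, 'mmse': 1} {'age': 72, 'mmse': 0}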
|
425 |
+
|
426 |
+
def predict_proba(self,
|
427 |
+
x: list[dict[str, Any]],
|
428 |
+
temperature: float = 1.0
|
429 |
+
) -> list[dict[str, float]]:
|
430 |
+
''' ... '''
|
431 |
+
# calculate logits
|
432 |
+
logits = self.predict_logits(x)
|
433 |
+
|
434 |
+
# convert logits to probabilities
|
435 |
+
proba = [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]
|
436 |
+
return proba
|
437 |
+
|
438 |
+
def predict(self,
|
439 |
+
x: list[dict[str, Any]],
|
440 |
+
) -> list[dict[str, int]]:
|
441 |
+
''' ... '''
|
442 |
+
proba = self.predict_proba(x)
|
443 |
+
return [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba]
|
444 |
+
|
445 |
+
def save(self, filepath: str) -> None:
|
446 |
+
''' ... '''
|
447 |
+
check_is_fitted(self)
|
448 |
+
if self.data_parallel:
|
449 |
+
state_dict = self.net_.module.state_dict()
|
450 |
+
else:
|
451 |
+
state_dict = self.net_.state_dict()
|
452 |
+
|
453 |
+
# attach model hyper parameters
|
454 |
+
state_dict['src_modalities'] = self.src_modalities
|
455 |
+
state_dict['tgt_modalities'] = self.tgt_modalities
|
456 |
+
state_dict['mri_feature'] = self.mri_feature
|
457 |
+
|
458 |
+
torch.save(state_dict, filepath)
|
459 |
+
|
460 |
+
def load(self, filepath: str, map_location: str='cpu') -> None:
|
461 |
+
''' ... '''
|
462 |
+
# load state_dict
|
463 |
+
state_dict = torch.load(filepath, map_location=map_location)
|
464 |
+
|
465 |
+
# load data modalities
|
466 |
+
self.src_modalities = state_dict.pop('src_modalities')
|
467 |
+
self.tgt_modalities = state_dict.pop('tgt_modalities')
|
468 |
+
|
469 |
+
# initialize model
|
470 |
+
self.net_ = nn.ResNetModel(
|
471 |
+
self.tgt_modalities,
|
472 |
+
mri_feature = state_dict.pop('mri_feature')
|
473 |
+
)
|
474 |
+
|
475 |
+
# load model parameters
|
476 |
+
self.net_.load_state_dict(state_dict)
|
477 |
+
self.net_.to(self.device)
|
478 |
+
|
479 |
+
@classmethod
|
480 |
+
def from_ckpt(cls, filepath: str, device='cpu') -> Self:
|
481 |
+
''' ... '''
|
482 |
+
obj = cls(None, None, None,device=device)
|
483 |
+
obj.load(filepath)
|
484 |
+
return obj
|
adrd/model/transformer.py
ADDED
@@ -0,0 +1,600 @@
1 |
+
__all__ = ['Transformer']
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
import numpy as np
|
6 |
+
import tqdm
|
7 |
+
from sklearn.base import BaseEstimator
|
8 |
+
from sklearn.utils.validation import check_is_fitted
|
9 |
+
from sklearn.model_selection import train_test_split
|
10 |
+
from scipy.special import expit
|
11 |
+
from copy import deepcopy
|
12 |
+
from contextlib import suppress
|
13 |
+
from typing import Any, Self, Type
|
14 |
+
from functools import wraps
|
15 |
+
Tensor = Type[torch.Tensor]
|
16 |
+
Module = Type[torch.nn.Module]
|
17 |
+
|
18 |
+
from .. import nn
|
19 |
+
from ..utils import TransformerTrainingDataset
|
20 |
+
from ..utils import TransformerBalancedTrainingDataset
|
21 |
+
from ..utils import Transformer2ndOrderBalancedTrainingDataset
|
22 |
+
from ..utils import TransformerValidationDataset
|
23 |
+
from ..utils import TransformerTestingDataset
|
24 |
+
from ..utils.misc import ProgressBar
|
25 |
+
from ..utils.misc import get_metrics_multitask, print_metrics_multitask
|
26 |
+
from ..utils.misc import convert_args_kwargs_to_kwargs
|
27 |
+
|
28 |
+
|
29 |
+
def _manage_ctx_fit(func):
|
30 |
+
''' ... '''
|
31 |
+
@wraps(func)
|
32 |
+
def wrapper(*args, **kwargs):
|
33 |
+
# format arguments
|
34 |
+
kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs)
|
35 |
+
|
36 |
+
if kwargs['self']._device_ids is None:
|
37 |
+
return func(**kwargs)
|
38 |
+
else:
|
39 |
+
# change primary device
|
40 |
+
default_device = kwargs['self'].device
|
41 |
+
kwargs['self'].device = kwargs['self']._device_ids[0]
|
42 |
+
rtn = func(**kwargs)
|
43 |
+
kwargs['self'].to(default_device)
|
44 |
+
return rtn
|
45 |
+
return wrapper
|
46 |
+
|
47 |
+
|
48 |
+
class Transformer(BaseEstimator):
|
49 |
+
''' ... '''
|
50 |
+
def __init__(self,
|
51 |
+
src_modalities: dict[str, dict[str, Any]],
|
52 |
+
tgt_modalities: dict[str, dict[str, Any]],
|
53 |
+
d_model: int = 32,
|
54 |
+
nhead: int = 1,
|
55 |
+
num_layers: int = 1,
|
56 |
+
num_epochs: int = 32,
|
57 |
+
batch_size: int = 8,
|
58 |
+
batch_size_multiplier: int = 1,
|
59 |
+
lr: float = 1e-2,
|
60 |
+
weight_decay: float = 0.0,
|
61 |
+
beta: float = 0.9999,
|
62 |
+
gamma: float = 2.0,
|
63 |
+
scale: float = 1.0,
|
64 |
+
lambd: float = 0.0,
|
65 |
+
criterion: str | None = None,
|
66 |
+
device: str = 'cpu',
|
67 |
+
verbose: int = 0,
|
68 |
+
_device_ids: list | None = None,
|
69 |
+
_dataloader_num_workers: int = 0,
|
70 |
+
_amp_enabled: bool = False,
|
71 |
+
) -> None:
|
72 |
+
''' ... '''
|
73 |
+
# for multiprocessing
|
74 |
+
self._rank = 0
|
75 |
+
self._lock = None
|
76 |
+
|
77 |
+
# positional parameters
|
78 |
+
self.src_modalities = src_modalities
|
79 |
+
self.tgt_modalities = tgt_modalities
|
80 |
+
|
81 |
+
# training parameters
|
82 |
+
self.d_model = d_model
|
83 |
+
self.nhead = nhead
|
84 |
+
self.num_layers = num_layers
|
85 |
+
self.num_epochs = num_epochs
|
86 |
+
self.batch_size = batch_size
|
87 |
+
self.batch_size_multiplier = batch_size_multiplier
|
88 |
+
self.lr = lr
|
89 |
+
self.weight_decay = weight_decay
|
90 |
+
self.beta = beta
|
91 |
+
self.gamma = gamma
|
92 |
+
self.scale = scale
|
93 |
+
self.lambd = lambd
|
94 |
+
self.criterion = criterion
|
95 |
+
self.device = device
|
96 |
+
self.verbose = verbose
|
97 |
+
self._device_ids = _device_ids
|
98 |
+
self._dataloader_num_workers = _dataloader_num_workers
|
99 |
+
self._amp_enabled = _amp_enabled
|
100 |
+
|
101 |
+
@_manage_ctx_fit
|
102 |
+
def fit(self,
|
103 |
+
x, y,
|
104 |
+
is_embedding: dict[str, bool] | None = None,
|
105 |
+
) -> Self:
|
106 |
+
''' ... '''
|
107 |
+
# for PyTorch computational efficiency
|
108 |
+
torch.set_num_threads(1)
|
109 |
+
|
110 |
+
# initialize neural network
|
111 |
+
self.net_ = self._init_net()
|
112 |
+
|
113 |
+
# initialize dataloaders
|
114 |
+
ldr_trn, ldr_vld = self._init_dataloader(x, y, is_embedding)
|
115 |
+
|
116 |
+
# initialize optimizer and scheduler
|
117 |
+
optimizer = self._init_optimizer()
|
118 |
+
scheduler = self._init_scheduler(optimizer)
|
119 |
+
|
120 |
+
# gradient scaler for AMP
|
121 |
+
if self._amp_enabled: scaler = torch.cuda.amp.GradScaler()
|
122 |
+
|
123 |
+
# initialize loss function (binary cross entropy)
|
124 |
+
loss_func = self._init_loss_func({
|
125 |
+
k: (
|
126 |
+
sum([_[k] == 0 for _ in ldr_trn.dataset.tgt]),
|
127 |
+
sum([_[k] == 1 for _ in ldr_trn.dataset.tgt]),
|
128 |
+
) for k in self.tgt_modalities
|
129 |
+
})
|
130 |
+
|
131 |
+
# to record the best validation performance criterion
|
132 |
+
if self.criterion is not None: best_crit = None
|
133 |
+
|
134 |
+
# progress bar for epoch loops
|
135 |
+
if self.verbose == 1:
|
136 |
+
with self._lock if self._lock is not None else suppress():
|
137 |
+
pbr_epoch = tqdm.tqdm(
|
138 |
+
desc = 'Rank {:02d}'.format(self._rank),
|
139 |
+
total = self.num_epochs,
|
140 |
+
position = self._rank,
|
141 |
+
ascii = True,
|
142 |
+
leave = False,
|
143 |
+
bar_format='{l_bar}{r_bar}'
|
144 |
+
)
|
145 |
+
|
146 |
+
# training loop
|
147 |
+
for epoch in range(self.num_epochs):
|
148 |
+
# progress bar for batch loops
|
149 |
+
if self.verbose > 1:
|
150 |
+
pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch))
|
151 |
+
|
152 |
+
# set model to train mode
|
153 |
+
torch.set_grad_enabled(True)
|
154 |
+
self.net_.train()
|
155 |
+
|
156 |
+
scores_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
|
157 |
+
y_true_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
|
158 |
+
losses_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
|
159 |
+
for n_iter, (x_batch, y_batch, mask_x, mask_y) in enumerate(ldr_trn):
|
160 |
+
# mount data to the proper device
|
161 |
+
x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
|
162 |
+
y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities}
|
163 |
+
mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
|
164 |
+
mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities}
|
165 |
+
|
166 |
+
# forward
|
167 |
+
with torch.autocast(
|
168 |
+
device_type = 'cpu' if self.device == 'cpu' else 'cuda',
|
169 |
+
dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
|
170 |
+
enabled = self._amp_enabled,
|
171 |
+
):
|
172 |
+
outputs = self.net_(x_batch, mask_x, is_embedding)
|
173 |
+
|
174 |
+
# calculate multitask loss
|
175 |
+
loss = 0
|
176 |
+
for i, tgt_k in enumerate(self.tgt_modalities):
|
177 |
+
loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k])
|
178 |
+
loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze()))
|
179 |
+
loss += loss_k.mean()
|
180 |
+
losses_trn[tgt_k] += loss_k.detach().cpu().numpy().tolist()
|
181 |
+
|
182 |
+
# if self.lambd != 0:
|
183 |
+
|
184 |
+
# backward
|
185 |
+
if self._amp_enabled:
|
186 |
+
scaler.scale(loss).backward()
|
187 |
+
else:
|
188 |
+
loss.backward()
|
189 |
+
|
190 |
+
# update parameters
|
191 |
+
if n_iter != 0 and n_iter % self.batch_size_multiplier == 0:
|
192 |
+
if self._amp_enabled:
|
193 |
+
scaler.step(optimizer)
|
194 |
+
scaler.update()
|
195 |
+
optimizer.zero_grad()
|
196 |
+
else:
|
197 |
+
optimizer.step()
|
198 |
+
optimizer.zero_grad()
|
199 |
+
|
200 |
+
# save outputs to evaluate performance later
|
201 |
+
for tgt_k in self.tgt_modalities:
|
202 |
+
tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
|
203 |
+
scores_trn[tgt_k] += tmp.detach().cpu().numpy().tolist()
|
204 |
+
tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
|
205 |
+
y_true_trn[tgt_k] += tmp.cpu().numpy().tolist()
|
206 |
+
|
207 |
+
# update progress bar
|
208 |
+
if self.verbose > 1:
|
209 |
+
batch_size = len(next(iter(x_batch.values())))
|
210 |
+
pbr_batch.update(batch_size, {})
|
211 |
+
pbr_batch.refresh()
|
212 |
+
|
213 |
+
# for better tqdm progress bar display
|
214 |
+
if self.verbose > 1:
|
215 |
+
pbr_batch.close()
|
216 |
+
|
217 |
+
# set scheduler
|
218 |
+
scheduler.step()
|
219 |
+
|
220 |
+
# calculate and print training performance metrics
|
221 |
+
y_pred_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
|
222 |
+
y_prob_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
|
223 |
+
for tgt_k in self.tgt_modalities:
|
224 |
+
for i in range(len(scores_trn[tgt_k])):
|
225 |
+
y_pred_trn[tgt_k].append(1 if scores_trn[tgt_k][i] > 0 else 0)
|
226 |
+
y_prob_trn[tgt_k].append(expit(scores_trn[tgt_k][i]))
|
227 |
+
met_trn = get_metrics_multitask(y_true_trn, y_pred_trn, y_prob_trn)
|
228 |
+
|
229 |
+
# add loss to metrics
|
230 |
+
for tgt_k in self.tgt_modalities:
|
231 |
+
met_trn[tgt_k]['Loss'] = np.mean(losses_trn[tgt_k])
|
232 |
+
|
233 |
+
if self.verbose > 2:
|
234 |
+
print_metrics_multitask(met_trn)
|
235 |
+
|
236 |
+
# progress bar for validation
|
237 |
+
if self.verbose > 1:
|
238 |
+
pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch))
|
239 |
+
|
240 |
+
# set model to validation mode
|
241 |
+
torch.set_grad_enabled(False)
|
242 |
+
self.net_.eval()
|
243 |
+
|
244 |
+
scores_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
|
245 |
+
y_true_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
|
246 |
+
losses_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
|
247 |
+
for x_batch, y_batch, mask_x, mask_y in ldr_vld:
|
248 |
+
# mount data to the proper device
|
249 |
+
x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
|
250 |
+
y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities}
|
251 |
+
mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
|
252 |
+
mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities}
|
253 |
+
|
254 |
+
# forward
|
255 |
+
with torch.autocast(
|
256 |
+
device_type = 'cpu' if self.device == 'cpu' else 'cuda',
|
257 |
+
dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
|
258 |
+
enabled = self._amp_enabled
|
259 |
+
):
|
260 |
+
outputs = self.net_(x_batch, mask_x, is_embedding)
|
261 |
+
|
262 |
+
# calculate multitask loss
|
263 |
+
for i, tgt_k in enumerate(self.tgt_modalities):
|
264 |
+
loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k])
|
265 |
+
loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze()))
|
266 |
+
losses_vld[tgt_k] += loss_k.detach().cpu().numpy().tolist()
|
267 |
+
|
268 |
+
# save outputs to evaluate performance later
|
269 |
+
for tgt_k in self.tgt_modalities:
|
270 |
+
tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
|
271 |
+
scores_vld[tgt_k] += tmp.detach().cpu().numpy().tolist()
|
272 |
+
tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
|
273 |
+
y_true_vld[tgt_k] += tmp.cpu().numpy().tolist()
|
274 |
+
|
275 |
+
# update progress bar
|
276 |
+
if self.verbose > 1:
|
277 |
+
batch_size = len(next(iter(x_batch.values())))
|
278 |
+
pbr_batch.update(batch_size, {})
|
279 |
+
pbr_batch.refresh()
|
280 |
+
|
281 |
+
# for better tqdm progress bar display
|
282 |
+
if self.verbose > 1:
|
283 |
+
pbr_batch.close()
|
284 |
+
|
285 |
+
# calculate and print validation performance metrics
|
286 |
+
y_pred_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
|
287 |
+
y_prob_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
|
288 |
+
for tgt_k in self.tgt_modalities:
|
289 |
+
for i in range(len(scores_vld[tgt_k])):
|
290 |
+
y_pred_vld[tgt_k].append(1 if scores_vld[tgt_k][i] > 0 else 0)
|
291 |
+
y_prob_vld[tgt_k].append(expit(scores_vld[tgt_k][i]))
|
292 |
+
met_vld = get_metrics_multitask(y_true_vld, y_pred_vld, y_prob_vld)
|
293 |
+
|
294 |
+
# add loss to metrics
|
295 |
+
for tgt_k in self.tgt_modalities:
|
296 |
+
met_vld[tgt_k]['Loss'] = np.mean(losses_vld[tgt_k])
|
297 |
+
|
298 |
+
if self.verbose > 2:
|
299 |
+
print_metrics_multitask(met_vld)
|
300 |
+
|
301 |
+
# save the model if it has the best validation performance criterion by far
|
302 |
+
if self.criterion is None: continue
|
303 |
+
|
304 |
+
# is current criterion better than previous best?
|
305 |
+
curr_crit = np.mean([met_vld[k][self.criterion] for k in self.tgt_modalities])
|
306 |
+
if best_crit is None or np.isnan(best_crit):
|
307 |
+
is_better = True
|
308 |
+
elif self.criterion == 'Loss' and best_crit >= curr_crit:
|
309 |
+
is_better = True
|
310 |
+
elif self.criterion != 'Loss' and best_crit <= curr_crit:
|
311 |
+
is_better = True
|
312 |
+
else:
|
313 |
+
is_better = False
|
314 |
+
|
315 |
+
# update best criterion
|
316 |
+
if is_better:
|
317 |
+
best_crit = curr_crit
|
318 |
+
best_state_dict = deepcopy(self.net_.state_dict())
|
319 |
+
|
320 |
+
if self.verbose > 2:
|
321 |
+
print('Best {}: {}'.format(self.criterion, best_crit))
|
322 |
+
|
323 |
+
if self.verbose == 1:
|
324 |
+
with self._lock if self._lock is not None else suppress():
|
325 |
+
pbr_epoch.update(1)
|
326 |
+
pbr_epoch.refresh()
|
327 |
+
|
328 |
+
if self.verbose == 1:
|
329 |
+
with self._lock if self._lock is not None else suppress():
|
330 |
+
pbr_epoch.close()
|
331 |
+
|
332 |
+
# restore the model with the best validation performance across all epochs
|
333 |
+
if ldr_vld is not None and self.criterion is not None:
|
334 |
+
self.net_.load_state_dict(best_state_dict)
|
335 |
+
|
336 |
+
return self
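
Here missing labels are excluded with torch.masked_select rather than by multiplying with a mask: mask_y is True where the label is missing, so logical_not selects the observed entries before the loss is averaged. A compact sketch of that selection on made-up values:

import torch

loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none')

logits = torch.tensor([0.8, -1.2, 2.0, 0.1])
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])
missing = torch.tensor([False, True, False, False])    # True = label not available

per_sample = loss_fn(logits, labels)
observed = torch.masked_select(per_sample, torch.logical_not(missing))
loss = observed.mean()          # averaged over observed labels only
print(observed.shape, loss)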
|
337 |
+
|
338 |
+
def predict_logits(self,
|
339 |
+
x: list[dict[str, Any]],
|
340 |
+
is_embedding: dict[str, bool] | None = None,
|
341 |
+
_batch_size: int | None = None,
|
342 |
+
) -> list[dict[str, float]]:
|
343 |
+
'''
|
344 |
+
The input x can be a single sample or a list of samples.
|
345 |
+
'''
|
346 |
+
# input validation
|
347 |
+
check_is_fitted(self)
|
348 |
+
|
349 |
+
# for PyTorch computational efficiency
|
350 |
+
torch.set_num_threads(1)
|
351 |
+
|
352 |
+
# set model to eval mode
|
353 |
+
torch.set_grad_enabled(False)
|
354 |
+
self.net_.eval()
|
355 |
+
|
356 |
+
# initialize dataset and dataloader objects
|
357 |
+
dat = TransformerTestingDataset(x, self.src_modalities, is_embedding)
|
358 |
+
ldr = DataLoader(
|
359 |
+
dataset = dat,
|
360 |
+
batch_size = _batch_size if _batch_size is not None else len(x),
|
361 |
+
shuffle = False,
|
362 |
+
drop_last = False,
|
363 |
+
num_workers = 0,
|
364 |
+
collate_fn = TransformerTestingDataset.collate_fn,
|
365 |
+
)
|
366 |
+
|
367 |
+
# run model and collect results
|
368 |
+
logits: list[dict[str, float]] = []
|
369 |
+
for x_batch, mask_x in ldr:
|
370 |
+
# mount data to the proper device
|
371 |
+
x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
|
372 |
+
mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
|
373 |
+
|
374 |
+
# forward
|
375 |
+
output: dict[str, Tensor] = self.net_(x_batch, mask_x, is_embedding)
|
376 |
+
|
377 |
+
# convert output from dict-of-list to list of dict, then append
|
378 |
+
tmp = {k: output[k].tolist() for k in self.tgt_modalities}
|
379 |
+
tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))]
|
380 |
+
logits += tmp
|
381 |
+
|
382 |
+
return logits
|
383 |
+
|
384 |
+
def predict_proba(self,
|
385 |
+
x: list[dict[str, Any]],
|
386 |
+
is_embedding: dict[str, bool] | None = None,
|
387 |
+
temperature: float = 1.0,
|
388 |
+
_batch_size: int | None = None,
|
389 |
+
) -> list[dict[str, float]]:
|
390 |
+
''' ... '''
|
391 |
+
logits = self.predict_logits(x, is_embedding, _batch_size)
|
392 |
+
return [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]
|
393 |
+
|
394 |
+
def predict(self,
|
395 |
+
x: list[dict[str, Any]],
|
396 |
+
is_embedding: dict[str, bool] | None = None,
|
397 |
+
_batch_size: int | None = None,
|
398 |
+
) -> list[dict[str, int]]:
|
399 |
+
''' ... '''
|
400 |
+
logits = self.predict_logits(x, is_embedding, _batch_size)
|
401 |
+
return [{k: int(smp[k] > 0.0) for k in self.tgt_modalities} for smp in logits]
|
402 |
+
|
403 |
+
def save(self, filepath: str) -> None:
|
404 |
+
''' ... '''
|
405 |
+
check_is_fitted(self)
|
406 |
+
state_dict = self.net_.state_dict()
|
407 |
+
|
408 |
+
# attach model hyper parameters
|
409 |
+
state_dict['src_modalities'] = self.src_modalities
|
410 |
+
state_dict['tgt_modalities'] = self.tgt_modalities
|
411 |
+
state_dict['d_model'] = self.d_model
|
412 |
+
state_dict['nhead'] = self.nhead
|
413 |
+
state_dict['num_layers'] = self.num_layers
|
414 |
+
torch.save(state_dict, filepath)
|
415 |
+
|
416 |
+
def load(self, filepath: str) -> None:
|
417 |
+
''' ... '''
|
418 |
+
# load state_dict
|
419 |
+
state_dict = torch.load(filepath, map_location='cpu')
|
420 |
+
|
421 |
+
# load essential parameters
|
422 |
+
self.src_modalities: dict[str, dict[str, Any]] = state_dict.pop('src_modalities')
|
423 |
+
self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities')
|
424 |
+
self.d_model = state_dict.pop('d_model')
|
425 |
+
self.nhead = state_dict.pop('nhead')
|
426 |
+
self.num_layers = state_dict.pop('num_layers')
|
427 |
+
|
428 |
+
# initialize model
|
429 |
+
self.net_ = nn.Transformer(
|
430 |
+
self.src_modalities,
|
431 |
+
self.tgt_modalities,
|
432 |
+
self.d_model,
|
433 |
+
self.nhead,
|
434 |
+
self.num_layers,
|
435 |
+
)
|
436 |
+
|
437 |
+
# load model parameters
|
438 |
+
self.net_.load_state_dict(state_dict)
|
439 |
+
self.to(self.device)
|
440 |
+
|
441 |
+
def to(self, device: str) -> Self:
|
442 |
+
''' Mount model to the given device. '''
|
443 |
+
self.device = device
|
444 |
+
if hasattr(self, 'net_'): self.net_ = self.net_.to(device)
|
445 |
+
return self
|
446 |
+
|
447 |
+
@classmethod
|
448 |
+
def from_ckpt(cls, filepath: str) -> Self:
|
449 |
+
''' ... '''
|
450 |
+
obj = cls(None, None)
|
451 |
+
obj.load(filepath)
|
452 |
+
return obj
|
453 |
+
|
454 |
+
def _init_net(self):
|
455 |
+
""" ... """
|
456 |
+
net = nn.Transformer(
|
457 |
+
self.src_modalities,
|
458 |
+
self.tgt_modalities,
|
459 |
+
self.d_model,
|
460 |
+
self.nhead,
|
461 |
+
self.num_layers,
|
462 |
+
).to(self.device)
|
463 |
+
|
464 |
+
# train on multiple GPUs using torch.nn.DataParallel
|
465 |
+
if self._device_ids is not None:
|
466 |
+
net = torch.nn.DataParallel(net, device_ids=self._device_ids)
|
467 |
+
|
468 |
+
# initialize model parameters using xavier_uniform
|
469 |
+
for p in net.parameters():
|
470 |
+
if p.dim() > 1:
|
471 |
+
torch.nn.init.xavier_uniform_(p)
|
472 |
+
|
473 |
+
return net
|
474 |
+
|
475 |
+
def _init_dataloader(self, x, y, is_embedding):
|
476 |
+
""" ... """
|
477 |
+
# split dataset
|
478 |
+
x_trn, x_vld, y_trn, y_vld = train_test_split(
|
479 |
+
x, y, test_size = 0.2, random_state = 0,
|
480 |
+
)
|
481 |
+
|
482 |
+
# initialize dataset and dataloader
|
483 |
+
# dat_trn = TransformerTrainingDataset(
|
484 |
+
# dat_trn = TransformerBalancedTrainingDataset(
|
485 |
+
dat_trn = Transformer2ndOrderBalancedTrainingDataset(
|
486 |
+
x_trn, y_trn,
|
487 |
+
self.src_modalities,
|
488 |
+
self.tgt_modalities,
|
489 |
+
dropout_rate = .5,
|
490 |
+
# dropout_strategy = 'compensated',
|
491 |
+
dropout_strategy = 'permutation',
|
492 |
+
)
|
493 |
+
|
494 |
+
dat_vld = TransformerValidationDataset(
|
495 |
+
x_vld, y_vld,
|
496 |
+
self.src_modalities,
|
497 |
+
self.tgt_modalities,
|
498 |
+
is_embedding,
|
499 |
+
)
|
500 |
+
|
501 |
+
ldr_trn = DataLoader(
|
502 |
+
dataset = dat_trn,
|
503 |
+
batch_size = self.batch_size,
|
504 |
+
shuffle = True,
|
505 |
+
drop_last = False,
|
506 |
+
num_workers = self._dataloader_num_workers,
|
507 |
+
collate_fn = TransformerTrainingDataset.collate_fn,
|
508 |
+
# pin_memory = True
|
509 |
+
)
|
510 |
+
|
511 |
+
ldr_vld = DataLoader(
|
512 |
+
dataset = dat_vld,
|
513 |
+
batch_size = self.batch_size,
|
514 |
+
shuffle = False,
|
515 |
+
drop_last = False,
|
516 |
+
num_workers = self._dataloader_num_workers,
|
517 |
+
collate_fn = TransformerValidationDataset.collate_fn,
|
518 |
+
# pin_memory = True
|
519 |
+
)
|
520 |
+
|
521 |
+
return ldr_trn, ldr_vld
|
522 |
+
|
523 |
+
def _init_optimizer(self):
|
524 |
+
""" ... """
|
525 |
+
return torch.optim.AdamW(
|
526 |
+
self.net_.parameters(),
|
527 |
+
lr = self.lr,
|
528 |
+
betas = (0.9, 0.98),
|
529 |
+
weight_decay = self.weight_decay
|
530 |
+
)
|
531 |
+
|
532 |
+
def _init_scheduler(self, optimizer):
|
533 |
+
""" ... """
|
534 |
+
return torch.optim.lr_scheduler.OneCycleLR(
|
535 |
+
optimizer = optimizer,
|
536 |
+
max_lr = self.lr,
|
537 |
+
total_steps = self.num_epochs,
|
538 |
+
verbose = (self.verbose > 2)
|
539 |
+
)
|
540 |
+
|
541 |
+
def _init_loss_func(self,
|
542 |
+
num_per_cls: dict[str, tuple[int, int]],
|
543 |
+
) -> dict[str, Module]:
|
544 |
+
""" ... """
|
545 |
+
return {k: nn.SigmoidFocalLoss(
|
546 |
+
beta = self.beta,
|
547 |
+
gamma = self.gamma,
|
548 |
+
scale = self.scale,
|
549 |
+
num_per_cls = num_per_cls[k],
|
550 |
+
reduction = 'none',
|
551 |
+
) for k in self.tgt_modalities}
|
552 |
+
|
553 |
+
def _extract_embedding(self,
|
554 |
+
x: list[dict[str, Any]],
|
555 |
+
is_embedding: dict[str, bool] | None = None,
|
556 |
+
_batch_size: int | None = None,
|
557 |
+
) -> list[dict[str, Any]]:
|
558 |
+
""" ... """
|
559 |
+
# input validation
|
560 |
+
check_is_fitted(self)
|
561 |
+
|
562 |
+
# for PyTorch computational efficiency
|
563 |
+
torch.set_num_threads(1)
|
564 |
+
|
565 |
+
# set model to eval mode
|
566 |
+
torch.set_grad_enabled(False)
|
567 |
+
self.net_.eval()
|
568 |
+
|
569 |
+
# intialize dataset and dataloader object
|
570 |
+
dat = TransformerTestingDataset(x, self.src_modalities, is_embedding)
|
571 |
+
ldr = DataLoader(
|
572 |
+
dataset = dat,
|
573 |
+
batch_size = _batch_size if _batch_size is not None else len(x),
|
574 |
+
shuffle = False,
|
575 |
+
drop_last = False,
|
576 |
+
num_workers = 0,
|
577 |
+
collate_fn = TransformerTestingDataset.collate_fn,
|
578 |
+
)
|
579 |
+
|
580 |
+
# run model and extract embeddings
|
581 |
+
embeddings: list[dict[str, Any]] = []
|
582 |
+
for x_batch, _ in ldr:
|
583 |
+
# mount data to the proper device
|
584 |
+
x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
|
585 |
+
|
586 |
+
# forward
|
587 |
+
out: dict[str, Tensor] = self.net_.forward_emb(x_batch, is_embedding)
|
588 |
+
|
589 |
+
# convert output from dict-of-list to list of dict, then append
|
590 |
+
tmp = {k: out[k].detach().cpu().numpy() for k in self.src_modalities}
|
591 |
+
tmp = [{k: tmp[k][i] for k in self.src_modalities} for i in range(len(next(iter(tmp.values()))))]
|
592 |
+
embeddings += tmp
|
593 |
+
|
594 |
+
# remove imputed embeddings
|
595 |
+
for i in range(len(x)):
|
596 |
+
avail = [k for k, v in x[i].items() if v is not None]
|
597 |
+
embeddings[i] = {k: embeddings[i][k] for k in avail}
|
598 |
+
|
599 |
+
return embeddings
|
600 |
+
|
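
For reference, a minimal usage sketch of the prediction API defined above. This is not part of the package; it assumes a fitted model instance `mdl` of this class, and the sample keys shown are illustrative placeholders for whatever `src_modalities` the model was fitted with.

    from scipy.special import expit  # logistic sigmoid, as used by predict_proba above

    # each sample maps a source modality name to its value (None marks a missing feature)
    x = [{'age': 72.0, 'img_MRI_T1': None}]

    logits = mdl.predict_logits(x)                  # list of {target: logit}
    proba  = mdl.predict_proba(x, temperature=1.0)  # expit(logit / temperature) per target
    labels = mdl.predict(x)                         # int(logit > 0.0) per target
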
adrd/nn/__init__.py
ADDED
@@ -0,0 +1,12 @@
from .transformer import Transformer
from .vitautoenc import ViTAutoEnc
from .unet import UNet3D
from .unet_3d import UNet3DBase
from .focal_loss import SigmoidFocalLoss
from .unet_img_model import ImageModel
from .img_model_wrapper import ImagingModelWrapper
from .resnet_img_model import ResNetModel
from .c3d import C3D
from .dense_net import DenseNet
from .cnn_resnet3d import CNNResNet3D
from .cnn_resnet3d_with_linear_classifier import CNNResNet3DWithLinearClassifier
adrd/nn/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (961 Bytes).
adrd/nn/__pycache__/blocks.cpython-311.pyc
ADDED
Binary file (2.88 kB).
adrd/nn/__pycache__/c3d.cpython-311.pyc
ADDED
Binary file (5.17 kB).
adrd/nn/__pycache__/cnn_resnet3d.cpython-311.pyc
ADDED
Binary file (4.37 kB).
adrd/nn/__pycache__/cnn_resnet3d_with_linear_classifier.cpython-311.pyc
ADDED
Binary file (4.06 kB).
adrd/nn/__pycache__/dense_net.cpython-311.pyc
ADDED
Binary file (13.8 kB).
adrd/nn/__pycache__/focal_loss.cpython-311.pyc
ADDED
Binary file (6.22 kB).
adrd/nn/__pycache__/img_model_wrapper.cpython-311.pyc
ADDED
Binary file (8.7 kB).
adrd/nn/__pycache__/net_resnet3d.cpython-311.pyc
ADDED
Binary file (17.2 kB).
adrd/nn/__pycache__/resnet3d.cpython-311.pyc
ADDED
Binary file (13.3 kB).
adrd/nn/__pycache__/resnet_img_model.cpython-311.pyc
ADDED
Binary file (2.85 kB).
adrd/nn/__pycache__/selfattention.cpython-311.pyc
ADDED
Binary file (3.58 kB).
adrd/nn/__pycache__/transformer.cpython-311.pyc
ADDED
Binary file (14.1 kB).
adrd/nn/__pycache__/unet.cpython-311.pyc
ADDED
Binary file (15.8 kB).
adrd/nn/__pycache__/unet_3d.cpython-311.pyc
ADDED
Binary file (3.03 kB).
adrd/nn/__pycache__/unet_img_model.cpython-311.pyc
ADDED
Binary file (14.1 kB).
adrd/nn/__pycache__/vitautoenc.cpython-311.pyc
ADDED
Binary file (8.61 kB).
adrd/nn/blocks.py
ADDED
@@ -0,0 +1,57 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from monai.networks.blocks.mlp import MLPBlock
from typing import Sequence, Union
import torch
import torch.nn as nn

from ..nn.selfattention import SABlock

class TransformerBlock(nn.Module):
    """
    A transformer block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
    """

    def __init__(
        self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False
    ) -> None:
        """
        Args:
            hidden_size: dimension of hidden layer.
            mlp_dim: dimension of feedforward layer.
            num_heads: number of attention heads.
            dropout_rate: fraction of the input units to drop.
            qkv_bias: apply bias term for the qkv linear layer

        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias)
        self.norm2 = nn.LayerNorm(hidden_size)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + y
        x = x + self.mlp(self.norm2(x))
        return x
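
A small shape-check sketch for TransformerBlock (illustrative only; it assumes MONAI and the local SABlock are importable, and the sizes below are arbitrary):

    import torch

    blk = TransformerBlock(hidden_size=768, mlp_dim=3072, num_heads=12, dropout_rate=0.1)
    tokens = torch.randn(2, 196, 768)           # (batch, num_tokens, hidden_size)
    out = blk(tokens)                           # attention + MLP residual paths, same shape as input
    attn = blk(tokens, return_attention=True)   # returns the attention map from SABlock instead
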
adrd/nn/c3d.py
ADDED
@@ -0,0 +1,99 @@
# From https://github.com/xmuyzz/3D-CNN-PyTorch/blob/master/models/C3DNet.py

import torch
import torch.nn as nn
import sys
# from icecream import ic
import math

class C3D(torch.nn.Module):

    def __init__(self, tgt_modalities, in_channels=1, load_from_ckpt=None):

        super(C3D, self).__init__()
        self.conv_group1 = nn.Sequential(
            nn.Conv3d(in_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 2, 2)))
        self.conv_group2 = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)))
        self.conv_group3 = nn.Sequential(
            nn.Conv3d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm3d(256),
            nn.ReLU(),
            nn.Conv3d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm3d(256),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        )
        self.conv_group4 = nn.Sequential(
            nn.Conv3d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm3d(512),
            nn.ReLU(),
            nn.Conv3d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm3d(512),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1))
        )

        # last_duration = int(math.floor(128 / 16))
        # last_size = int(math.ceil(128 / 32))
        self.fc1 = nn.Sequential(
            nn.Linear((512 * 15 * 9 * 9) , 512),
            nn.ReLU(),
            nn.Dropout(0.5))
        self.fc2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5))
        # self.fc = nn.Sequential(
        #     nn.Linear(4096, num_classes))

        self.fc = torch.nn.ModuleDict()
        for k in tgt_modalities:
            self.fc[k] = torch.nn.Linear(256, 1)

    def forward(self, x):
        # for k in x.keys():
        #     x[k] = x[k].to(torch.float32)

        # x = torch.stack([o for o in x.values()], dim=0)[0]
        # print(x.shape)

        out = self.conv_group1(x)
        out = self.conv_group2(out)
        out = self.conv_group3(out)
        out = self.conv_group4(out)
        out = out.view(out.size(0), -1)
        # print(out.shape)
        out = self.fc1(out)
        out = self.fc2(out)
        # out = self.fc(out)

        tgt_iter = self.fc.keys()
        out_tgt = {k: self.fc[k](out).squeeze(1) for k in tgt_iter}
        return out_tgt


if __name__ == "__main__":
    model = C3D(tgt_modalities=['NC', 'MCI', 'DE'])
    print(model)
    x = torch.rand((1, 1, 128, 128, 128))
    # layers = list(model.features.named_children())
    # features = nn.Sequential(*list(model.features.children()))(x)
    # print(features.shape)
    print(sum(p.numel() for p in model.parameters()))
    # layer_found = False
    # features = None
    # desired_layer_name = 'transition3'

    # for name, layer in layers:
    #     if name == desired_layer_name:
    #         x = layer(x)
    # print(x)
    # model(x)
    # print(features)
adrd/nn/cnn_resnet3d.py
ADDED
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
from typing import Any, Type
Tensor = Type[torch.Tensor]

from .resnet3d import r3d_18


class CNNResNet3D(nn.Module):

    def __init__(self,
        src_modalities: dict[str, dict[str, Any]],
        tgt_modalities: dict[str, dict[str, Any]]
    ) -> None:
        """ ... """
        super().__init__()

        # resnet
        # embedding modules for source
        self.modules_emb_src = nn.ModuleDict()
        for k, info in src_modalities.items():
            if info['type'] == 'imaging' and len(info['img_shape']) == 4:
                self.modules_emb_src[k] = nn.Sequential(
                    r3d_18(),
                    nn.Dropout(0.5)
                )
            else:
                # unrecognized
                raise ValueError('{} is an unrecognized data modality'.format(k))

        # classifiers (binary only)
        self.modules_cls = nn.ModuleDict()
        for k, info in tgt_modalities.items():
            if info['type'] == 'categorical' and info['num_categories'] == 2:
                # categorical
                self.modules_cls[k] = nn.Linear(256, 1)
            else:
                # unrecognized
                raise ValueError

    def forward(self,
        x: dict[str, Tensor],
    ) -> dict[str, Tensor]:
        """ ... """
        out_emb = self.forward_emb(x)
        out_emb = out_emb[list(out_emb.keys())[0]]
        out_cls = self.forward_cls(out_emb)
        return out_cls

    def forward_emb(self,
        x: dict[str, Tensor],
    ) -> dict[str, Tensor]:
        """ ... """
        out_emb = dict()
        for k in self.modules_emb_src.keys():
            out_emb[k] = self.modules_emb_src[k](x[k])
        return out_emb

    def forward_cls(self,
        out_emb: dict[str, Tensor]
    ) -> dict[str, Tensor]:
        """ ... """
        out_cls = dict()
        for k in self.modules_cls.keys():
            out_cls[k] = self.modules_cls[k](out_emb).squeeze(1)
        return out_cls


# for testing purpose only
if __name__ == '__main__':
    src_modalities = {
        'img_MRI_T1': {'type': 'imaging', 'img_shape': [1, 182, 218, 182]}
    }
    tgt_modalities = {
        'AD': {'type': 'categorical', 'num_categories': 2},
        'PD': {'type': 'categorical', 'num_categories': 2}
    }
    net = CNNResNet3D(src_modalities, tgt_modalities)
    net.eval()
    x = {'img_MRI_T1': torch.zeros(2, 1, 182, 218, 182)}
    print(net(x))
adrd/nn/cnn_resnet3d_with_linear_classifier.py
ADDED
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn
from typing import Any, Type
Tensor = Type[torch.Tensor]

from .resnet3d import r3d_18

class CNNResNet3DWithLinearClassifier(nn.Module):

    def __init__(self,
        src_modalities: dict[str, dict[str, Any]],
        tgt_modalities: dict[str, dict[str, Any]]
    ) -> None:
        """ ... """
        super().__init__()
        self.core = _CNNResNet3DWithLinearClassifier(len(tgt_modalities))
        self.src_modalities = src_modalities
        self.tgt_modalities = tgt_modalities

    def forward(self,
        x: dict[str, Tensor],
    ) -> dict[str, Tensor]:
        """ x is expected to be a singleton dictionary """
        src_k = list(x.keys())[0]
        x = x[src_k]
        out = self.core(x)
        out = {tgt_k: out[:, i] for i, tgt_k in enumerate(self.tgt_modalities)}
        return out


class _CNNResNet3DWithLinearClassifier(nn.Module):

    def __init__(self,
        len_tgt_modalities: int,
    ) -> None:
        """ ... """
        super().__init__()
        self.cnn = r3d_18()
        self.cls = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256, len_tgt_modalities),
        )

    def forward(self, x: Tensor) -> Tensor:
        """ ... """
        out_emb = self.forward_emb(x)
        out_cls = self.forward_cls(out_emb)
        return out_cls

    def forward_emb(self, x: Tensor) -> Tensor:
        """ ... """
        return self.cnn(x)

    def forward_cls(self, out_emb: Tensor) -> Tensor:
        """ ... """
        return self.cls(out_emb)
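
A minimal forward-pass sketch for the wrapper above, mirroring the test block in cnn_resnet3d.py (the modality names and tensor shape are illustrative):

    import torch

    src = {'img_MRI_T1': {'type': 'imaging', 'img_shape': [1, 182, 218, 182]}}
    tgt = {'AD': {'type': 'categorical', 'num_categories': 2},
           'PD': {'type': 'categorical', 'num_categories': 2}}
    net = CNNResNet3DWithLinearClassifier(src, tgt)
    net.eval()
    # input is a singleton dict; output is one logit tensor of shape (batch,) per target
    out = net({'img_MRI_T1': torch.zeros(2, 1, 182, 218, 182)})
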
adrd/nn/dense_net.py
ADDED
@@ -0,0 +1,211 @@
# This implementation is based on the DenseNet-BC implementation in torchvision
# https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
# https://github.com/gpleiss/efficient_densenet_pytorch/blob/master/models/densenet.py


import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict


def _bn_function_factory(norm, relu, conv):
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = conv(relu(norm(concated_features)))
        return bottleneck_output

    return bn_function


class _DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm3d(num_input_features)),
        self.add_module('relu1', nn.ReLU(inplace=True)),
        self.add_module('conv1', nn.Conv3d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)),
        self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate)),
        self.add_module('relu2', nn.ReLU(inplace=True)),
        self.add_module('conv2', nn.Conv3d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)),
        self.drop_rate = drop_rate
        self.efficient = efficient

    def forward(self, *prev_features):
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        else:
            bottleneck_output = bn_function(*prev_features)
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm3d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv3d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2))


class _DenseBlock(nn.Module):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, efficient=False):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                efficient=efficient,
            )
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, init_features):
        features = [init_features]
        for name, layer in self.named_children():
            new_features = layer(*features)
            features.append(new_features)
        return torch.cat(features, 1)


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 3 or 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        tgt_modalities (list) - list of target modalities
        efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower.
    """
    # def __init__(self, tgt_modalities, growth_rate=12, block_config=(3, 3, 3), compression=0.5,
    #              num_init_features=16, bn_size=4, drop_rate=0, efficient=False, load_from_ckpt=False): # config 1

    def __init__(self, tgt_modalities, growth_rate=12, block_config=(3, 3, 3), compression=0.5,
                 num_init_features=16, bn_size=4, drop_rate=0, efficient=False, load_from_ckpt=False): # config 2

        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([('conv0', nn.Conv3d(1, num_init_features, kernel_size=7, stride=2, padding=0, bias=False)),]))
        self.features.add_module('norm0', nn.BatchNorm3d(num_init_features))
        self.features.add_module('relu0', nn.ReLU(inplace=True))
        self.features.add_module('pool0', nn.MaxPool3d(kernel_size=3, stride=2, padding=0, ceil_mode=False))
        self.tgt_modalities = tgt_modalities

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                efficient=efficient,
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config):
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=int(num_features * compression))
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = int(num_features * compression)

        # Final batch norm
        self.features.add_module('norm_final', nn.BatchNorm3d(num_features))

        # Classification heads
        self.tgt = torch.nn.ModuleDict()
        for k in tgt_modalities:
            # self.tgt[k] = torch.nn.Linear(621, 1) # config 2
            self.tgt[k] = torch.nn.Sequential(
                torch.nn.Linear(self.test_size(), 256),
                torch.nn.ReLU(),
                torch.nn.Linear(256, 1)
            )

        print(f'load_from_ckpt: {load_from_ckpt}')
        # Initialization
        if not load_from_ckpt:
            for name, param in self.named_parameters():
                if 'conv' in name and 'weight' in name:
                    n = param.size(0) * param.size(2) * param.size(3) * param.size(4)
                    param.data.normal_().mul_(math.sqrt(2. / n))
                elif 'norm' in name and 'weight' in name:
                    param.data.fill_(1)
                elif 'norm' in name and 'bias' in name:
                    param.data.fill_(0)
                elif ('classifier' in name or 'tgt' in name) and 'bias' in name:
                    param.data.fill_(0)

        # self.size = self.test_size()

    def forward(self, x, shap=True):
        # print(x.shape)
        features = self.features(x)
        # print(features.shape)
        out = F.relu(features, inplace=True)
        # out = F.adaptive_avg_pool3d(out, (1, 1, 1))
        out = torch.flatten(out, 1)

        # print(out.shape)

        # out_tgt = self.tgt(out).squeeze(1)
        # print(out_tgt)
        # return F.softmax(out_tgt)

        tgt_iter = self.tgt.keys()
        out_tgt = {k: self.tgt[k](out).squeeze(1) for k in tgt_iter}
        if shap:
            out_tgt = torch.stack(list(out_tgt.values()))
            return out_tgt.T
        else:
            return out_tgt

    def test_size(self):
        case = torch.ones((1, 1, 182, 218, 182))
        output = self.features(case).view(-1).size(0)
        return output


if __name__ == "__main__":
    model = DenseNet(
        tgt_modalities=['NC', 'MCI', 'DE'],
        growth_rate=12,
        block_config=(2, 3, 2),
        compression=0.5,
        num_init_features=16,
        drop_rate=0.2)
    print(model)
    torch.manual_seed(42)
    x = torch.rand((1, 1, 182, 218, 182))
    # layers = list(model.features.named_children())
    features = nn.Sequential(*list(model.features.children()))(x)
    print(features.shape)
    print(sum(p.numel() for p in model.parameters()))
    # out = mdl.net_(x, shap=False)
    # print(out)

    out = model(x, shap=False)
    print(out)
    # layer_found = False
    # features = None
    # desired_layer_name = 'transition3'

    # for name, layer in layers:
    #     if name == desired_layer_name:
    #         x = layer(x)
    # print(x)
    # model(x)
    # print(features)
adrd/nn/focal_loss.py
ADDED
@@ -0,0 +1,120 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys

class SigmoidFocalLoss(nn.Module):
    ''' ... '''
    def __init__(
        self,
        alpha: float = -1,
        gamma: float = 2.0,
        reduction: str = 'mean',
    ):
        ''' ... '''
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input, target):
        ''' ... '''
        p = torch.sigmoid(input)
        ce_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')
        p_t = p * target + (1 - p) * (1 - target)
        loss = ce_loss * ((1 - p_t) ** self.gamma)

        if self.alpha >= 0:
            alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
            loss = alpha_t * loss

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss


class SigmoidFocalLossBeta(nn.Module):
    ''' ... '''
    def __init__(
        self,
        beta: float = 0.9999,
        gamma: float = 2.0,
        num_per_cls = (1, 1),
        reduction: str = 'mean',
    ):
        ''' ... '''
        super().__init__()
        eps = sys.float_info.epsilon
        self.gamma = gamma
        self.reduction = reduction

        # weights to balance loss
        self.weight_neg = ((1 - beta) / (1 - beta ** num_per_cls[0] + eps))
        self.weight_pos = ((1 - beta) / (1 - beta ** num_per_cls[1] + eps))
        # weight_neg = (1 - beta) / (1 - beta ** num_per_cls[0])
        # weight_pos = (1 - beta) / (1 - beta ** num_per_cls[1])
        # self.weight_neg = weight_neg / (weight_neg + weight_pos)
        # self.weight_pos = weight_pos / (weight_neg + weight_pos)

    def forward(self, input, target):
        ''' ... '''
        p = torch.sigmoid(input)
        p_t = p * target + (1 - p) * (1 - target)
        ce_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')
        loss = ce_loss * ((1 - p_t) ** self.gamma)

        alpha_t = self.weight_pos * target + self.weight_neg * (1 - target)
        loss = alpha_t * loss

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss

class AsymmetricLoss(nn.Module):
    def __init__(self, gamma_neg=4, gamma_pos=1, alpha=0.5, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
        super(AsymmetricLoss, self).__init__()
        self.alpha = alpha
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps


    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """
        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid
        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)
        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = self.alpha*los_pos + (1-self.alpha)*los_neg
        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            loss *= one_sided_w
        return -loss#.sum()
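
A quick sanity-check sketch for the two focal losses above (the logits, targets, and class counts are made up for illustration):

    import torch

    logits  = torch.tensor([ 2.0, -1.0, 0.5])
    targets = torch.tensor([ 1.0,  0.0, 1.0])

    # plain sigmoid focal loss with alpha weighting
    loss = SigmoidFocalLoss(alpha=0.25, gamma=2.0)(logits, targets)

    # class-balanced variant; num_per_cls = (negative count, positive count)
    loss_cb = SigmoidFocalLossBeta(beta=0.9999, num_per_cls=(900, 100))(logits, targets)
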
adrd/nn/img_model_wrapper.py
ADDED
@@ -0,0 +1,174 @@
import torch
from .. import nn
from .. import model
import numpy as np
from icecream import ic
from monai.networks.nets.swin_unetr import SwinUNETR
from typing import Any

class ImagingModelWrapper(torch.nn.Module):
    def __init__(
        self,
        arch: str = 'ViTAutoEnc',
        tgt_modalities: dict | None = {},
        img_size: int | None = 128,
        patch_size: int | None = 16,
        ckpt_path: str | None = None,
        train_backbone: bool = False,
        out_dim: int = 128,
        layers: int | None = 1,
        device: str = 'cpu',
        fusion_stage: str = 'middle',
    ):
        super(ImagingModelWrapper, self).__init__()

        self.arch = arch
        self.tgt_modalities = tgt_modalities
        self.img_size = img_size
        self.patch_size = patch_size
        self.train_backbone = train_backbone
        self.ckpt_path = ckpt_path
        self.device = device
        self.out_dim = out_dim
        self.layers = layers
        self.fusion_stage = fusion_stage


        if "swinunetr" in self.arch.lower():
            if "emb" not in self.arch.lower():
                ckpt_path = '/projectnb/ivc-ml/dlteif/pretrained_models/model_swinvit.pt'
                ckpt = torch.load(ckpt_path, map_location='cpu')
                self.img_model = SwinUNETR(
                    in_channels=1,
                    out_channels=1,
                    img_size=128,
                    feature_size=48,
                    use_checkpoint=True,
                )
                ckpt["state_dict"] = {k.replace("swinViT.", "module."): v for k, v in ckpt["state_dict"].items()}
                ic(ckpt["state_dict"].keys())
                self.img_model.load_from(ckpt)
            self.dim = 768

        elif "vit" in self.arch.lower():
            if "emb" not in self.arch.lower():
                # Initialize image model
                self.img_model = nn.__dict__[self.arch](
                    in_channels = 1,
                    img_size = self.img_size,
                    patch_size = self.patch_size,
                )

                if self.ckpt_path:
                    self.img_model.load(self.ckpt_path, map_location=self.device)
                self.dim = self.img_model.hidden_size
            else:
                self.dim = 768

        if "vit" in self.arch.lower() or "swinunetr" in self.arch.lower():
            dim = self.dim
            if self.fusion_stage == 'middle':
                downsample = torch.nn.ModuleList()
                # print('Number of layers: ', self.layers)
                for i in range(self.layers):
                    if i == self.layers - 1:
                        dim_out = self.out_dim
                        # print(layers)
                        ks = 2
                        stride = 2
                    else:
                        dim_out = dim // 2
                        ks = 2
                        stride = 2

                    downsample.append(
                        torch.nn.Conv1d(in_channels=dim, out_channels=dim_out, kernel_size=ks, stride=stride)
                    )

                    dim = dim_out

                downsample.append(
                    torch.nn.BatchNorm1d(dim)
                )
                downsample.append(
                    torch.nn.ReLU()
                )


                self.downsample = torch.nn.Sequential(*downsample)
            elif self.fusion_stage == 'late':
                self.downsample = torch.nn.Identity()
            else:
                pass

            # print('Downsample layers: ', self.downsample)

        elif "densenet" in self.arch.lower():
            if "emb" not in self.arch.lower():
                self.img_model = model.ImagingModel.from_ckpt(self.ckpt_path, device=self.device, img_backend=self.arch, load_from_ckpt=True).net_

            self.downsample = torch.nn.Linear(3900, self.out_dim)

            # randomly initialize weights for downsample block
            for p in self.downsample.parameters():
                if p.dim() > 1:
                    torch.nn.init.xavier_uniform_(p)
                p.requires_grad = True

        if "emb" not in self.arch.lower():
            # freeze imaging model parameters
            if "densenet" in self.arch.lower():
                for n, p in self.img_model.features.named_parameters():
                    if not self.train_backbone:
                        p.requires_grad = False
                    else:
                        p.requires_grad = True
                for n, p in self.img_model.tgt.named_parameters():
                    p.requires_grad = False
            else:
                for n, p in self.img_model.named_parameters():
                    # print(n, p.requires_grad)
                    if not self.train_backbone:
                        p.requires_grad = False
                    else:
                        p.requires_grad = True

    def forward(self, x):
        # print("--------ImagingModelWrapper forward--------")
        if "emb" not in self.arch.lower():
            if "swinunetr" in self.arch.lower():
                # print(x.size())
                out = self.img_model(x)
                # print(out.size())
                out = self.downsample(out)
                # print(out.size())
                out = torch.mean(out, dim=-1)
                # print(out.size())
            elif "vit" in self.arch.lower():
                out = self.img_model(x, return_emb=True)
                ic(out.size())
                out = self.downsample(out)
                out = torch.mean(out, dim=-1)
            elif "densenet" in self.arch.lower():
                out = torch.nn.Sequential(*list(self.img_model.features.children()))(x)
                # print(out.size())
                out = torch.flatten(out, 1)
                out = self.downsample(out)
        else:
            # print(x.size())
            if "swinunetr" in self.arch.lower():
                x = torch.squeeze(x, dim=1)
                x = x.view(x.size(0),self.dim, -1)
            # print('x: ', x.size())
            out = self.downsample(x)
            # print('out: ', out.size())
            if self.fusion_stage == 'middle':
                if "vit" in self.arch.lower() or "swinunetr" in self.arch.lower():
                    out = torch.mean(out, dim=-1)
                else:
                    out = torch.squeeze(out, dim=1)
            elif self.fusion_stage == 'late':
                pass

        return out
adrd/nn/net_resnet3d.py
ADDED
@@ -0,0 +1,338 @@
"""
Created on Sat Nov 21 10:49:39 2021

@author: cxue2
"""

import torch.nn as nn


__all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18']


class Conv3DSimple(nn.Conv3d):
    def __init__(self,
                 in_planes,
                 out_planes,
                 midplanes=None,
                 stride=1,
                 padding=1):

        super(Conv3DSimple, self).__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(3, 3, 3),
            stride=stride,
            padding=padding,
            bias=False)

    @staticmethod
    def get_downsample_stride(stride):
        return stride, stride, stride


class Conv2Plus1D(nn.Sequential):

    def __init__(self,
                 in_planes,
                 out_planes,
                 midplanes,
                 stride=1,
                 padding=1):
        super(Conv2Plus1D, self).__init__(
            nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
                      stride=(1, stride, stride), padding=(0, padding, padding),
                      bias=False),
            nn.BatchNorm3d(midplanes),
            nn.ReLU(inplace=True),
            nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1),
                      stride=(stride, 1, 1), padding=(padding, 0, 0),
                      bias=False))

    @staticmethod
    def get_downsample_stride(stride):
        return stride, stride, stride


class Conv3DNoTemporal(nn.Conv3d):

    def __init__(self,
                 in_planes,
                 out_planes,
                 midplanes=None,
                 stride=1,
                 padding=1):

        super(Conv3DNoTemporal, self).__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(1, 3, 3),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=False)

    @staticmethod
    def get_downsample_stride(stride):
        return 1, stride, stride


class BasicBlock(nn.Module):

    expansion = 1

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        super(BasicBlock, self).__init__()
        self.conv1 = nn.Sequential(
            conv_builder(inplanes, planes, midplanes, stride),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes),
            nn.BatchNorm3d(planes)
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):

        super(Bottleneck, self).__init__()
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        # 1x1x1
        self.conv1 = nn.Sequential(
            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True)
        )
        # Second kernel
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes, stride),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True)
        )

        # 1x1x1
        self.conv3 = nn.Sequential(
            nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes * self.expansion)
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class BasicStem(nn.Sequential):
    """The default conv-batchnorm-relu stem
    """
    def __init__(self):
        super(BasicStem, self).__init__(
            nn.Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2),
                      padding=(3, 3, 3), bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True))


class R2Plus1dStem(nn.Sequential):
    """R(2+1)D stem is different than the default one as it uses separated 3D convolution
    """
    def __init__(self):
        super(R2Plus1dStem, self).__init__(
            nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
                      stride=(1, 2, 2), padding=(0, 3, 3),
                      bias=False),
            nn.BatchNorm3d(45),
            nn.ReLU(inplace=True),
            nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
                      stride=(1, 1, 1), padding=(1, 0, 0),
                      bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True))


class VideoResNet(nn.Module):

    def __init__(self, block, conv_makers, layers,
                 stem, num_classes=16,
                 zero_init_residual=False):
        """Generic resnet video generator.
        Args:
            block (nn.Module): resnet building block
            conv_makers (list(functions)): generator function for each layer
            layers (List[int]): number of blocks per layer
            stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
            num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
            zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
        """
        super(VideoResNet, self).__init__()
        self.inplanes = 64

        self.stem = stem()

        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, conv_makers[2], 192, layers[2], stride=2)
        self.layer4 = self._make_layer(block, conv_makers[3], 256, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(256 * block.expansion, num_classes)

        # init weights
        self._initialize_weights()

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def forward(self, x):
        x = self.stem(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten the layer to fc
        x = x.flatten(1)
        x = self.fc(x)

        return x

    def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
        downsample = None

        if stride != 1 or self.inplanes != planes * block.expansion:
            ds_stride = conv_builder.get_downsample_stride(stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=ds_stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion)
            )
        layers = []
        layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, conv_builder))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
    model = VideoResNet(**kwargs)

    return model


def r3d_18(pretrained=False, progress=True, **kwargs):
    """Construct 18 layer Resnet3D model as in
    https://arxiv.org/abs/1711.11248
    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        progress (bool): If True, displays a progress bar of the download to stderr
    Returns:
        nn.Module: R3D-18 network
    """

    return _video_resnet('r3d_18',
                         pretrained, progress,
                         block=BasicBlock,
                         conv_makers=[Conv3DSimple] * 4,
                         layers=[2, 2, 2, 2],
                         stem=BasicStem, **kwargs)


def mc3_18(pretrained=False, progress=True, **kwargs):
    """Constructor for 18 layer Mixed Convolution network as in
    https://arxiv.org/abs/1711.11248
    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        progress (bool): If True, displays a progress bar of the download to stderr
    Returns:
        nn.Module: MC3 Network definition
    """
    return _video_resnet('mc3_18',
                         pretrained, progress,
                         block=BasicBlock,
                         conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,
                         layers=[2, 2, 2, 2],
                         stem=BasicStem, **kwargs)


def r2plus1d_18(pretrained=False, progress=True, **kwargs):
    """Constructor for the 18 layer deep R(2+1)D network as in
    https://arxiv.org/abs/1711.11248
    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        progress (bool): If True, displays a progress bar of the download to stderr
    Returns:
        nn.Module: R(2+1)D-18 network
    """
    return _video_resnet('r2plus1d_18',
                         pretrained, progress,
                         block=BasicBlock,
                         conv_makers=[Conv2Plus1D] * 4,
                         layers=[2, 2, 2, 2],
                         stem=R2Plus1dStem, **kwargs)


if __name__ == '__main__':

    import torch

    net = r3d_18().to(0)
    x = torch.zeros(3, 1, 182, 218, 182).to(0)

    print(net(x).shape)
    print(net)
adrd/nn/resnet3d.py
ADDED
@@ -0,0 +1,256 @@
"""
Simplified from torchvision.models.video.r3d_18. The citation information is
shown below.

@article{DBLP:journals/corr/abs-1711-11248,
  author    = {Du Tran and
               Heng Wang and
               Lorenzo Torresani and
               Jamie Ray and
               Yann LeCun and
               Manohar Paluri},
  title     = {A Closer Look at Spatiotemporal Convolutions for Action Recognition},
  journal   = {CoRR},
  volume    = {abs/1711.11248},
  year      = {2017},
  url       = {http://arxiv.org/abs/1711.11248},
  archivePrefix = {arXiv},
  eprint    = {1711.11248},
  timestamp = {Mon, 13 Aug 2018 16:46:39 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/abs-1711-11248.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""

import torch.nn as nn


class Conv3DSimple(nn.Conv3d):
    def __init__(self,
                 in_planes,
                 out_planes,
                 midplanes=None,
                 stride=1,
                 padding=1):

        super().__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(3, 3, 3),
            stride=stride,
            padding=padding,
            bias=False)

    @staticmethod
    def get_downsample_stride(stride):
        return stride, stride, stride


class BasicBlock(nn.Module):

    expansion = 1

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        super(BasicBlock, self).__init__()
        self.conv1 = nn.Sequential(
            conv_builder(inplanes, planes, midplanes, stride),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes),
            nn.BatchNorm3d(planes)
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):

        super(Bottleneck, self).__init__()
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        # 1x1x1
        self.conv1 = nn.Sequential(
            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True)
        )
        # Second kernel
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes, stride),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True)
        )

        # 1x1x1
        self.conv3 = nn.Sequential(
            nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes * self.expansion)
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class BasicStem(nn.Sequential):
    """The default conv-batchnorm-relu stem
    """
    def __init__(self):
        super(BasicStem, self).__init__(
            nn.Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2),
                      padding=(3, 3, 3), bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True))


class VideoResNet(nn.Module):

    def __init__(self, block, conv_makers, layers,
                 stem, num_classes=2,
                 zero_init_residual=False):
        """Generic resnet video generator.
        Args:
            block (nn.Module): resnet building block
            conv_makers (list(functions)): generator function for each layer
            layers (List[int]): number of blocks per layer
            stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
            num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
            zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
        """
        super(VideoResNet, self).__init__()
        self.inplanes = 64

        self.stem = stem()

        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, conv_makers[2], 192, layers[2], stride=2)
        self.layer4 = self._make_layer(block, conv_makers[3], 256, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        # self.fc = nn.Linear(256 * block.expansion, num_classes)

        # init weights
        self._initialize_weights()

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def forward(self, x):
        x = self.stem(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten the layer to fc
        x = x.flatten(1)
        # x = self.fc(x)

        return x

    def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
        downsample = None

        if stride != 1 or self.inplanes != planes * block.expansion:
            ds_stride = conv_builder.get_downsample_stride(stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=ds_stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion)
            )
        layers = []
        layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, conv_builder))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
|
214 |
+
nonlinearity='relu')
|
215 |
+
if m.bias is not None:
|
216 |
+
nn.init.constant_(m.bias, 0)
|
217 |
+
elif isinstance(m, nn.BatchNorm3d):
|
218 |
+
nn.init.constant_(m.weight, 1)
|
219 |
+
nn.init.constant_(m.bias, 0)
|
220 |
+
elif isinstance(m, nn.Linear):
|
221 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
222 |
+
nn.init.constant_(m.bias, 0)
|
223 |
+
|
224 |
+
|
225 |
+
def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
|
226 |
+
model = VideoResNet(**kwargs)
|
227 |
+
|
228 |
+
return model
|
229 |
+
|
230 |
+
|
231 |
+
def r3d_18(pretrained=False, progress=True, **kwargs):
|
232 |
+
"""Construct 18 layer Resnet3D model as in
|
233 |
+
https://arxiv.org/abs/1711.11248
|
234 |
+
Args:
|
235 |
+
pretrained (bool): If True, returns a model pre-trained on Kinetics-400
|
236 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
237 |
+
Returns:
|
238 |
+
nn.Module: R3D-18 network
|
239 |
+
"""
|
240 |
+
|
241 |
+
return _video_resnet('r3d_18',
|
242 |
+
pretrained, progress,
|
243 |
+
block=BasicBlock,
|
244 |
+
conv_makers=[Conv3DSimple] * 4,
|
245 |
+
layers=[2, 2, 2, 2],
|
246 |
+
stem=BasicStem, **kwargs)
|
247 |
+
|
248 |
+
|
249 |
+
if __name__ == '__main__':
|
250 |
+
""" ... """
|
251 |
+
import torch
|
252 |
+
|
253 |
+
net = r3d_18().to('cuda:1')
|
254 |
+
x = torch.zeros(8, 1, 182, 218, 182).to('cuda:1')
|
255 |
+
|
256 |
+
print(net(x).shape)
|
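A quick shape-check sketch (not part of the committed file; the small input size is chosen so it runs on CPU): with the final fc layer commented out, r3d_18 behaves as a feature extractor that maps a single-channel 3D volume to a 256-dimensional vector per sample.

    import torch
    from adrd.nn.net_resnet3d import r3d_18   # import path assumes the package layout added in this commit

    net = r3d_18().eval()                     # BasicBlock x [2, 2, 2, 2] with a single-channel stem
    x = torch.zeros(2, 1, 64, 64, 64)         # (batch, channel, depth, height, width)
    with torch.no_grad():
        feats = net(x)
    print(feats.shape)                        # torch.Size([2, 256])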
adrd/nn/resnet_img_model.py
ADDED
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
import sys
from icecream import ic
# sys.path.append('/home/skowshik/ADRD_repo/adrd_tool/adrd/')
from .net_resnet3d import r3d_18
# from dev.data.dataset_csv import CSVDataset


class ResNetModel(nn.Module):
    ''' ... '''
    def __init__(
        self,
        tgt_modalities,
        mri_feature = 'img_MRI_T1',
    ):
        ''' ... '''
        super().__init__()

        self.mri_feature = mri_feature

        self.img_net_ = r3d_18()

        # self.modules_emb_src = nn.Sequential(
        #     nn.BatchNorm1d(9),
        #     nn.Linear(9, d_model)
        # )

        # classifiers (binary only)
        self.modules_cls = nn.ModuleDict()
        for k, info in tgt_modalities.items():
            if info['type'] == 'categorical' and info['num_categories'] == 2:
                # categorical
                self.modules_cls[k] = nn.Linear(64, 1)

            else:
                # unrecognized
                raise ValueError

    def forward(self, x):
        ''' ... '''
        tgt_iter = self.modules_cls.keys()

        img_x_batch = x[self.mri_feature]
        img_out = self.img_net_(img_x_batch)

        # ic(img_out.shape)

        # run linear classifiers
        out = [self.modules_cls[k](img_out).squeeze(1) for i, k in enumerate(tgt_iter)]
        out = torch.stack(out, dim=1)

        # ic(out.shape)

        # out to dict
        out = {k: out[:, i] for i, k in enumerate(tgt_iter)}

        return out


if __name__ == '__main__':
    ''' for testing purpose only '''
    # import torch
    # import numpy as np

    # seed = 0
    # print('Loading training dataset ... ')
    # dat_trn = CSVDataset(mode=0, split=[1, 700], seed=seed)
    # print(len(dat_trn))
    # tgt_modalities = dat_trn.label_modalities
    # net = ResNetModel(tgt_modalities).to('cuda')
    # x = dat_trn.features
    # x = {k: torch.as_tensor(np.array([x[i][k] for i in range(len(x))])).to('cuda') for k in x[0]}
    # ic(x)


    # # print(net(x).shape)
    # print(net(x))
adrd/nn/selfattention.py
ADDED
@@ -0,0 +1,62 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from monai.utils import optional_import
import torch
import torch.nn as nn


Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")


class SABlock(nn.Module):
    """
    A self-attention block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None:
        """
        Args:
            hidden_size: dimension of hidden layer.
            num_heads: number of attention heads.
            dropout_rate: fraction of the input units to drop.
            qkv_bias: bias term for the qkv linear layer.

        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden size should be divisible by num_heads.")

        self.num_heads = num_heads
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
        self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
        self.out_rearrange = Rearrange("b h l d -> b l (h d)")
        self.drop_output = nn.Dropout(dropout_rate)
        self.drop_weights = nn.Dropout(dropout_rate)
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim**-0.5

    def forward(self, x):
        output = self.input_rearrange(self.qkv(x))
        q, k, v = output[0], output[1], output[2]
        att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
        att_mat = self.drop_weights(att_mat)
        x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
        x = self.out_rearrange(x)
        x = self.out_proj(x)
        x = self.drop_output(x)
        return x, att_mat
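A small calling sketch (assumes monai and einops are installed; tensor sizes are illustrative only): the block returns both the projected output and the per-head attention maps.

    import torch
    from adrd.nn.selfattention import SABlock

    blk = SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1)
    tokens = torch.randn(2, 10, 128)          # (batch, sequence length, hidden size)
    out, att = blk(tokens)
    print(out.shape)                          # torch.Size([2, 10, 128])
    print(att.shape)                          # torch.Size([2, 4, 10, 10]) -- one attention map per head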
adrd/nn/transformer.py
ADDED
@@ -0,0 +1,268 @@
import torch
import numpy as np
from .. import nn
# from ..nn import ImagingModelWrapper
from .net_resnet3d import r3d_18
from typing import Any, Type
import math
Tensor = Type[torch.Tensor]
from icecream import ic
ic.disable()

class Transformer(torch.nn.Module):
    ''' ... '''
    def __init__(self,
        src_modalities: dict[str, dict[str, Any]],
        tgt_modalities: dict[str, dict[str, Any]],
        d_model: int,
        nhead: int,
        num_encoder_layers: int = 1,
        num_decoder_layers: int = 1,
        device: str = 'cpu',
        cuda_devices: list = [3],
        img_net: str | None = None,
        layers: int = 3,
        img_size: int | None = 128,
        patch_size: int | None = 16,
        imgnet_ckpt: str | None = None,
        train_imgnet: bool = False,
        fusion_stage: str = 'middle',
    ) -> None:
        ''' ... '''
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.img_net = img_net
        self.img_size = img_size
        self.patch_size = patch_size
        self.imgnet_ckpt = imgnet_ckpt
        self.train_imgnet = train_imgnet
        self.layers = layers
        self.src_modalities = src_modalities
        self.tgt_modalities = tgt_modalities
        self.device = device
        self.fusion_stage = fusion_stage

        # embedding modules for source

        self.modules_emb_src = torch.nn.ModuleDict()
        print('Downsample layers: ', self.layers)
        self.img_model = nn.ImagingModelWrapper(arch=self.img_net, img_size=self.img_size, patch_size=self.patch_size, ckpt_path=self.imgnet_ckpt, train_backbone=self.train_imgnet, layers=self.layers, out_dim=self.d_model, device=self.device, fusion_stage=self.fusion_stage)

        for k, info in src_modalities.items():
            # ic(k)
            # for key, val in info.items():
            #     ic(key, val)
            if info['type'] == 'categorical':
                self.modules_emb_src[k] = torch.nn.Embedding(info['num_categories'], d_model)
            elif info['type'] == 'numerical':
                self.modules_emb_src[k] = torch.nn.Sequential(
                    torch.nn.BatchNorm1d(info['shape'][0]),
                    torch.nn.Linear(info['shape'][0], d_model)
                )
            elif info['type'] == 'imaging':
                # print(info['shape'], info['img_shape'])
                if self.img_net:
                    self.modules_emb_src[k] = self.img_model

            else:
                # unrecognized
                raise ValueError('{} is an unrecognized data modality'.format(k))

        # positional encoding
        self.pe = PositionalEncoding(d_model)

        # auxiliary embedding vectors for targets
        self.emb_aux = torch.nn.Parameter(
            torch.zeros(len(tgt_modalities), 1, d_model),
            requires_grad = True,
        )

        # transformer
        enc = torch.nn.TransformerEncoderLayer(
            self.d_model, self.nhead,
            dim_feedforward = self.d_model,
            activation = 'gelu',
            dropout = 0.3,
        )
        self.transformer = torch.nn.TransformerEncoder(enc, self.num_encoder_layers)


        # classifiers (binary only)
        self.modules_cls = torch.nn.ModuleDict()
        for k, info in tgt_modalities.items():
            if info['type'] == 'categorical' and info['num_categories'] == 2:
                self.modules_cls[k] = torch.nn.Linear(d_model, 1)
            else:
                # unrecognized
                raise ValueError

        # for n,p in self.named_parameters():
        #     print(n, p.requires_grad)

    def forward(self,
        x: dict[str, Tensor],
        mask: dict[str, Tensor],
        # x_img: dict[str, Tensor] | Any = None,
        skip_embedding: dict[str, bool] | None = None,
        return_out_emb: bool = False,
    ) -> dict[str, Tensor]:
        """ ... """

        out_emb = self.forward_emb(x, mask, skip_embedding)
        if self.fusion_stage == "late":
            out_emb = {k: v for k, v in out_emb.items() if "img_MRI" not in k}
            img_out_emb = {k: v for k, v in out_emb.items() if "img_MRI" in k}
            # for k,v in out_emb.items():
            #     print(k, v.size())
            mask_nonimg = {k: v for k, v in mask.items() if "img_MRI" not in k}
            out_trf = self.forward_trf(out_emb, mask_nonimg)  # (8,128) + (8,50,128)
            # print("out_trf: ", out_trf.size())
            out_trf = torch.concatenate()
        else:
            out_trf = self.forward_trf(out_emb, mask)

        out_cls = self.forward_cls(out_trf)

        if return_out_emb:
            return out_emb, out_cls
        return out_cls

    def forward_emb(self,
        x: dict[str, Tensor],
        mask: dict[str, Tensor],
        skip_embedding: dict[str, bool] | None = None,
        # x_img: dict[str, Tensor] | Any = None,
    ) -> dict[str, Tensor]:
        """ ... """
        # print("-------forward_emb--------")
        out_emb = dict()
        for k in self.modules_emb_src.keys():
            if skip_embedding is None or k not in skip_embedding or not skip_embedding[k]:
                if "img_MRI" in k:
                    # print("img_MRI in ", k)
                    if torch.all(mask[k]):
                        if "swinunetr" in self.img_net.lower() and self.fusion_stage == 'late':
                            out_emb[k] = torch.zeros((1, 768, 4, 4, 4))
                        else:
                            if 'cuda' in self.device:
                                device = x[k].device
                                # print(device)
                            else:
                                device = self.device
                            out_emb[k] = torch.zeros((mask[k].shape[0], self.d_model)).to(device, non_blocking=True)
                            # print("mask is True, out_emb[k]: ", out_emb[k].size())
                    else:
                        # print("calling modules_emb_src...")
                        out_emb[k] = self.modules_emb_src[k](x[k])
                        # print("mask is False, out_emb[k]: ", out_emb[k].size())

                else:
                    out_emb[k] = self.modules_emb_src[k](x[k])

                # out_emb[k] = self.modules_emb_src[k](x[k])
            else:
                out_emb[k] = x[k]
        return out_emb

    def forward_trf(self,
        out_emb: dict[str, Tensor],
        mask: dict[str, Tensor],
    ) -> dict[str, Tensor]:
        """ ... """
        # print('-----------forward_trf----------')
        N = len(next(iter(out_emb.values())))  # batch size
        S = len(self.modules_emb_src)  # number of sources
        T = len(self.modules_cls)  # number of targets
        if self.fusion_stage == 'late':
            src_iter = [k for k in self.modules_emb_src.keys() if "img_MRI" not in k]
            S = len(src_iter)  # number of sources

        else:
            src_iter = self.modules_emb_src.keys()
        tgt_iter = self.modules_cls.keys()

        emb_src = torch.stack([o for o in out_emb.values()], dim=0)
        # print('emb_src: ', emb_src.size())

        self.pe.index = -1
        emb_src = self.pe(emb_src)
        # print('emb_src + pe: ', emb_src.size())

        # target embedding
        # print('emb_aux: ', self.emb_aux.size())
        emb_tgt = self.emb_aux.repeat(1, N, 1)
        # print('emb_tgt: ', emb_tgt.size())

        # concatenate source embeddings and target embeddings
        emb_all = torch.concatenate((emb_tgt, emb_src), dim=0)

        # combine masks
        mask_src = [mask[k] for k in src_iter]
        mask_src = torch.stack(mask_src, dim=1)

        # target masks
        mask_tgt = torch.zeros((N, T), dtype=torch.bool, device=self.emb_aux.device)

        # concatenate source masks and target masks
        mask_all = torch.concatenate((mask_tgt, mask_src), dim=1)

        # repeat mask_all to fit transformer
        mask_all = mask_all.unsqueeze(1).expand(-1, S + T, -1).repeat(self.nhead, 1, 1)

        # run transformer
        out_trf = self.transformer(
            src = emb_all,
            mask = mask_all,
        )[0]
        # print('out_trf: ', out_trf.size())
        # out_trf = {k: out_trf[i] for i, k in enumerate(tgt_iter)}
        return out_trf

    def forward_cls(self,
        out_trf: dict[str, Tensor],
    ) -> dict[str, Tensor]:
        """ ... """
        tgt_iter = self.modules_cls.keys()
        out_cls = {k: self.modules_cls[k](out_trf).squeeze(1) for k in tgt_iter}
        return out_cls

class PositionalEncoding(torch.nn.Module):

    def __init__(self,
        d_model: int,
        max_len: int = 512
    ):
        """ ... """
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
        self.index = -1

    def forward(self, x: Tensor, pe_type: str = 'non_img') -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        # print('pe: ', self.pe.size())
        # print('x: ', x.size())
        if pe_type == 'img':
            self.index += 1
            return x + self.pe[self.index]
        else:
            self.index += 1
            return x + self.pe[self.index:x.size(0)+self.index]


if __name__ == '__main__':
    ''' for testing purpose only '''
    pass
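An illustrative sketch of the modality dictionaries this Transformer consumes (the feature and label names below are hypothetical, not taken from the repository): categorical sources are embedded with torch.nn.Embedding, numerical sources with BatchNorm1d + Linear, imaging sources go through ImagingModelWrapper, and each binary target gets a one-logit linear head. Building the model additionally needs d_model and nhead and, for MRI inputs, an img_net backbone name; at call time, forward takes a tensor dict and a boolean mask dict keyed the same way, with True marking missing entries.

    # Hypothetical modality dictionaries, shaped the way Transformer.__init__ reads them.
    src_modalities = {
        'his_SEX':  {'type': 'categorical', 'num_categories': 2},  # -> nn.Embedding(2, d_model)
        'bat_MMSE': {'type': 'numerical',   'shape': [1]},         # -> BatchNorm1d(1) + Linear(1, d_model)
    }
    tgt_modalities = {
        'AD': {'type': 'categorical', 'num_categories': 2},        # -> Linear(d_model, 1) head
    }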
adrd/nn/unet.py
ADDED
@@ -0,0 +1,232 @@
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torch.nn import init
import torch.nn.functional as F
from icecream import ic


class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
    def _check_input_dim(self, input):

        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'.format(input.dim()))
        # super(ContBatchNorm3d, self)._check_input_dim(input)

    def forward(self, input):
        self._check_input_dim(input)
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            True, self.momentum, self.eps)


class LUConv(nn.Module):
    def __init__(self, in_chan, out_chan, act):
        super(LUConv, self).__init__()
        self.conv1 = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1)
        self.bn1 = ContBatchNorm3d(out_chan)

        if act == 'relu':
            self.activation = nn.ReLU(out_chan)
        elif act == 'prelu':
            self.activation = nn.PReLU(out_chan)
        elif act == 'elu':
            self.activation = nn.ELU(inplace=True)
        else:
            raise

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        return out


def _make_nConv(in_channel, depth, act, double_chnnel=False):
    if double_chnnel:
        layer1 = LUConv(in_channel, 32 * (2 ** (depth+1)), act)
        layer2 = LUConv(32 * (2 ** (depth+1)), 32 * (2 ** (depth+1)), act)
    else:
        layer1 = LUConv(in_channel, 32*(2**depth), act)
        layer2 = LUConv(32*(2**depth), 32*(2**depth)*2, act)

    return nn.Sequential(layer1, layer2)


class DownTransition(nn.Module):
    def __init__(self, in_channel, depth, act):
        super(DownTransition, self).__init__()
        self.ops = _make_nConv(in_channel, depth, act)
        self.maxpool = nn.MaxPool3d(2)
        self.current_depth = depth

    def forward(self, x):
        if self.current_depth == 3:
            out = self.ops(x)
            out_before_pool = out
        else:
            out_before_pool = self.ops(x)
            out = self.maxpool(out_before_pool)
        return out, out_before_pool

class UpTransition(nn.Module):
    def __init__(self, inChans, outChans, depth, act):
        super(UpTransition, self).__init__()
        self.depth = depth
        self.up_conv = nn.ConvTranspose3d(inChans, outChans, kernel_size=2, stride=2)
        self.ops = _make_nConv(inChans + outChans//2, depth, act, double_chnnel=True)

    def forward(self, x, skip_x):
        out_up_conv = self.up_conv(x)
        concat = torch.cat((out_up_conv, skip_x), 1)
        out = self.ops(concat)
        return out

class OutputTransition(nn.Module):
    def __init__(self, inChans, n_labels):

        super(OutputTransition, self).__init__()
        self.final_conv = nn.Conv3d(inChans, n_labels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.sigmoid(self.final_conv(x))
        return out

class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, drop_rate, kernel, pooling, BN=True, relu_type='leaky'):
        super().__init__()
        kernel_size, kernel_stride, kernel_padding = kernel
        pool_kernel, pool_stride, pool_padding = pooling
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, kernel_stride, kernel_padding, bias=False)
        self.pooling = nn.MaxPool3d(pool_kernel, pool_stride, pool_padding)
        self.BN = nn.BatchNorm3d(out_channels)
        self.relu = nn.LeakyReLU(inplace=False) if relu_type=='leaky' else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(drop_rate, inplace=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.pooling(x)
        x = self.BN(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x

class AttentionModule(nn.Module):
    def __init__(self, in_channels, out_channels, drop_rate=0.1):
        super(AttentionModule, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=False)
        self.attention = ConvLayer(in_channels, out_channels, drop_rate, (1, 1, 0), (1, 1, 0))

    def forward(self, x, return_attention=True):
        feats = self.conv(x)
        att = F.softmax(self.attention(x))

        out = feats * att

        if return_attention:
            return att, out

        return out

class UNet3D(nn.Module):
    # the number of convolutions in each layer corresponds
    # to what is in the actual prototxt, not the intent
    def __init__(self, n_class=1, act='relu', pretrained=False, input_size=(1,1,182,218,182), attention=False, drop_rate=0.1, blocks=4):
        super(UNet3D, self).__init__()

        self.blocks = blocks
        self.down_tr64 = DownTransition(1, 0, act)
        self.down_tr128 = DownTransition(64, 1, act)
        self.down_tr256 = DownTransition(128, 2, act)
        self.down_tr512 = DownTransition(256, 3, act)

        self.up_tr256 = UpTransition(512, 512, 2, act)
        self.up_tr128 = UpTransition(256, 256, 1, act)
        self.up_tr64 = UpTransition(128, 128, 0, act)
        self.out_tr = OutputTransition(64, 1)

        self.pretrained = pretrained
        self.attention = attention
        if pretrained:
            print("Using image pretrained model checkpoint")
            weight_dir = '/home/skowshik/ADRD_repo/img_pretrained_ckpt/Genesis_Chest_CT.pt'
            checkpoint = torch.load(weight_dir)
            state_dict = checkpoint['state_dict']
            unParalled_state_dict = {}
            for key in state_dict.keys():
                unParalled_state_dict[key.replace("module.", "")] = state_dict[key]
            self.load_state_dict(unParalled_state_dict)
        del self.up_tr256
        del self.up_tr128
        del self.up_tr64
        del self.out_tr

        if self.blocks == 5:
            self.down_tr1024 = DownTransition(512, 4, act)


        # self.conv1 = nn.Conv3d(512, 256, 1, 1, 0, bias=False)
        # self.conv2 = nn.Conv3d(256, 128, 1, 1, 0, bias=False)
        # self.conv3 = nn.Conv3d(128, 64, 1, 1, 0, bias=False)

        if attention:
            self.attention_module = AttentionModule(1024 if self.blocks==5 else 512, n_class, drop_rate=drop_rate)
        # Output.
        self.avgpool = nn.AvgPool3d((6,7,6), stride=(6,6,6))

        dummy_inp = torch.rand(input_size)
        dummy_feats = self.forward(dummy_inp, stage='get_features')
        dummy_feats = dummy_feats[0]
        self.in_features = list(dummy_feats.shape)
        ic(self.in_features)

        self._init_weights()

    def _init_weights(self):
        if not self.pretrained:
            for m in self.modules():
                if isinstance(m, nn.Conv3d):
                    init.kaiming_normal_(m.weight)
                elif isinstance(m, ContBatchNorm3d):
                    init.constant_(m.weight, 1)
                    init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    init.kaiming_normal_(m.weight)
                    init.constant_(m.bias, 0)
        elif self.attention:
            for m in self.attention_module.modules():
                if isinstance(m, nn.Conv3d):
                    init.kaiming_normal_(m.weight)
                elif isinstance(m, nn.BatchNorm3d):
                    init.constant_(m.weight, 1)
                    init.constant_(m.bias, 0)
        else:
            pass
            # Zero initialize the last batchnorm in each residual branch.
            # for m in self.modules():
            #     if isinstance(m, BottleneckBlock):
            #         init.constant_(m.out_conv.bn.weight, 0)

    def forward(self, x, stage='normal', attention=False):
        ic('backbone forward')
        self.out64, self.skip_out64 = self.down_tr64(x)
        self.out128, self.skip_out128 = self.down_tr128(self.out64)
        self.out256, self.skip_out256 = self.down_tr256(self.out128)
        self.out512, self.skip_out512 = self.down_tr512(self.out256)
        if self.blocks == 5:
            self.out1024, self.skip_out1024 = self.down_tr1024(self.out512)
            ic(self.out1024.shape)
        # self.out = self.conv1(self.out512)
        # self.out = self.conv2(self.out)
        # self.out = self.conv3(self.out)
        # self.out = self.conv(self.out)
        ic(hasattr(self, 'attention_module'))
        if hasattr(self, 'attention_module'):
            att, feats = self.attention_module(self.out1024 if self.blocks==5 else self.out512)
        else:
            feats = self.out1024 if self.blocks==5 else self.out512
        ic(feats.shape)
        if attention:
            return att, feats
        return feats
adrd/nn/unet_3d.py
ADDED
@@ -0,0 +1,63 @@
import sys
sys.path.append('..')
# from feature_extractor.for_image_data.backbone import CNN_GAP, ResNet3D, UNet3D
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
# from . import UNet3D
from .unet import UNet3D
from icecream import ic


class UNet3DBase(nn.Module):
    def __init__(self, n_class=1, act='relu', attention=False, pretrained=False, drop_rate=0.1, blocks=4):
        super(UNet3DBase, self).__init__()
        model = UNet3D(n_class=n_class, attention=attention, pretrained=pretrained, blocks=blocks)

        self.blocks = blocks

        self.down_tr64 = model.down_tr64
        self.down_tr128 = model.down_tr128
        self.down_tr256 = model.down_tr256
        self.down_tr512 = model.down_tr512
        if self.blocks == 5:
            self.down_tr1024 = model.down_tr1024
        # self.block_modules = nn.ModuleList([self.down_tr64, self.down_tr128, self.down_tr256, self.down_tr512])

        self.in_features = model.in_features
        # ic(attention)
        if attention:
            self.attention_module = model.attention_module
            # self.attention_module = AttentionModule(512, n_class, drop_rate=drop_rate)
        # self.avgpool = nn.AvgPool3d((6,7,6), stride=(6,6,6))

    def forward(self, x, stage='normal', attention=False):
        # ic('UNet3DBase forward')
        self.out64, self.skip_out64 = self.down_tr64(x)
        # ic(self.out64.shape, self.skip_out64.shape)
        self.out128, self.skip_out128 = self.down_tr128(self.out64)
        # ic(self.out128.shape, self.skip_out128.shape)
        self.out256, self.skip_out256 = self.down_tr256(self.out128)
        # ic(self.out256.shape, self.skip_out256.shape)
        self.out512, self.skip_out512 = self.down_tr512(self.out256)
        # ic(self.out512.shape, self.skip_out512.shape)
        if self.blocks == 5:
            self.out1024, self.skip_out1024 = self.down_tr1024(self.out512)
            # ic(self.out1024.shape, self.skip_out1024.shape)
        # ic(hasattr(self, 'attention_module'))
        if hasattr(self, 'attention_module'):
            att, feats = self.attention_module(self.out1024 if self.blocks == 5 else self.out512)
        else:
            feats = self.out1024 if self.blocks == 5 else self.out512
        # ic(feats.shape)
        if attention:
            return att, feats
        return feats

        # self.out_up_256 = self.up_tr256(self.out512,self.skip_out256)
        # self.out_up_128 = self.up_tr128(self.out_up_256, self.skip_out128)
        # self.out_up_64 = self.up_tr64(self.out_up_128, self.skip_out64)
        # self.out = self.out_tr(self.out_up_64)

        # return self.out
adrd/nn/unet_img_model.py
ADDED
@@ -0,0 +1,211 @@
from pyexpat import features
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
import numpy as np
import re
from icecream import ic
import math
import torch.nn.utils.weight_norm as weightNorm

# from . import UNet3DBase
from .unet_3d import UNet3DBase


def init_weights(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
        nn.init.kaiming_uniform_(m.weight)
        nn.init.zeros_(m.bias)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight)
        nn.init.zeros_(m.bias)

class feat_classifier(nn.Module):
    def __init__(self, class_num, bottleneck_dim=256, type="linear"):
        super(feat_classifier, self).__init__()
        self.type = type
        # if type in ['conv', 'gap'] and len(bottleneck_dim) > 3:
        #     bottleneck_dim = bottleneck_dim[-3:]
        ic(bottleneck_dim)
        if type == 'wn':
            self.layer = weightNorm(
                nn.Linear(bottleneck_dim[1:], class_num), name="weight")
            # self.fc.apply(init_weights)
        elif type == 'gap':
            if len(bottleneck_dim) > 3:
                bottleneck_dim = bottleneck_dim[-3:]
            self.layer = nn.AvgPool3d(bottleneck_dim, stride=(1,1,1))
        elif type == 'conv':
            if len(bottleneck_dim) > 3:
                bottleneck_dim = bottleneck_dim[-4:]
            ic(bottleneck_dim)
            self.layer = nn.Conv3d(bottleneck_dim[0], class_num, kernel_size=bottleneck_dim[1:])
            ic(self.layer)
        else:
            print('bottleneck dim: ', bottleneck_dim)
            self.layer = nn.Sequential(
                torch.nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(math.prod(bottleneck_dim), class_num)
            )
            self.layer.apply(init_weights)

    def forward(self, x):
        # print('=> feat_classifier forward')
        # ic(x.size())
        x = self.layer(x)
        # ic(x.size())
        if self.type in ['gap','conv']:
            x = torch.squeeze(x)
            if len(x.shape) < 2:
                x = torch.unsqueeze(x, 0)
        # print('returning x: ', x.size())
        return x

class ImageModel(nn.Module):
    """
    Empirical Risk Minimization (ERM)
    """

    def __init__(
        self,
        counts=None,
        classifier='gap',
        accum_iter=8,
        save_emb=False,
        # ssl,
        num_classes=1,
        load_img_ckpt=False,
    ):
        super(ImageModel, self).__init__()
        if counts is not None:
            if isinstance(counts[0], list):
                counts = np.stack(counts, axis=0).sum(axis=0)
                print('counts: ', counts)
                total = np.sum(counts)
                print(total/counts)
                self.weight = total/torch.FloatTensor(counts)
            else:
                total = sum(counts)
                self.weight = torch.FloatTensor([total/c for c in counts])
        else:
            self.weight = None
        print('weight: ', self.weight)
        # device = torch.device(f'cuda:{args.gpu_id}' if args.gpu_id is not None else 'cpu')
        self.criterion = nn.CrossEntropyLoss(weight=self.weight)
        # if ssl:
        #     # add contrastive loss
        #     # self.ssl_criterion =
        #     pass

        self.featurizer = UNet3DBase(n_class=num_classes, attention=True, pretrained=load_img_ckpt)
        self.classifier = feat_classifier(
            num_classes, self.featurizer.in_features, classifier)

        self.network = nn.Sequential(
            self.featurizer, self.classifier)
        self.accum_iter = accum_iter
        self.acc_steps = 0
        self.save_embedding = save_emb

    def update(self, minibatches, opt, sch, scaler):
        print('--------------def update----------------')
        device = list(self.parameters())[0].device
        all_x = torch.cat([data[1].to(device).float() for data in minibatches])
        all_y = torch.cat([data[2].to(device).long() for data in minibatches])
        print('all_x: ', all_x.size())
        # all_p = self.predict(all_x)
        # all_probs =
        label_list = all_y.tolist()
        count = float(len(label_list))
        ic(count)

        uniques = sorted(list(set(label_list)))
        ic(uniques)
        counts = [float(label_list.count(i)) for i in uniques]
        ic(counts)

        weights = [count / c for c in counts]
        ic(weights)

        with autocast():
            loss = self.criterion(self.predict(all_x), all_y)
        self.acc_steps += 1
        print('class: ', loss.item())

        scaler.scale(loss / self.accum_iter).backward()

        if self.acc_steps == self.accum_iter:
            scaler.step(opt)
            if sch:
                sch.step()
            scaler.update()
            self.zero_grad()
            self.acc_steps = 0
            torch.cuda.empty_cache()

        del all_x
        del all_y
        return {'class': loss.item()}, sch

    def forward(self, *args, **kwargs):
        return self.network(*args, **kwargs)

    def predict(self, x, stage='normal', attention=False):
        # print('network device: ', list(self.network.parameters())[0].device)
        # print('x device: ', x.device)
        if stage == 'get_features' or self.save_embedding:
            feats = self.network[0](x, attention=attention)
            output = self.network[1](feats[-1] if attention else feats)
            return feats, output
        else:
            return self.network(x)

    def extract_features(self, x, attention=False):
        feats = self.network[0](x, attention=attention)
        return feats

    def load_checkpoint(self, state_dict):
        try:
            self.load_checkpoint_helper(state_dict)
        except:
            featurizer_dict = {}
            net_dict = {}
            for key, val in state_dict.items():
                if 'featurizer' in key:
                    featurizer_dict[key] = val
                elif 'network' in key:
                    net_dict[key] = val
            self.featurizer.load_state_dict(featurizer_dict)
            self.classifier.load_state_dict(net_dict)

    def load_checkpoint_helper(self, state_dict):
        try:
            self.load_state_dict(state_dict)
            print('try: loaded')
        except RuntimeError as e:
            print('--> except')
            if 'Missing key(s) in state_dict:' in str(e):
                state_dict = {
                    key.replace('module.', '', 1): value
                    for key, value in state_dict.items()
                }
                state_dict = {
                    key.replace('featurizer.', '', 1).replace('classifier.','',1): value
                    for key, value in state_dict.items()
                }
                state_dict = {
                    re.sub('network.[0-9].', '', key): value
                    for key, value in state_dict.items()
                }
                try:
                    del state_dict['criterion.weight']
                except:
                    pass
                self.load_state_dict(state_dict)

                print('except: loaded')
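A small sketch of the classifier head in isolation (the feature-map shape below is hypothetical, chosen only to mimic a 3D backbone output): with type='conv' the head is a single Conv3d whose kernel spans the whole feature volume, so it collapses (N, C, D, H, W) features to (N, class_num) logits.

    import torch
    from adrd.nn.unet_img_model import feat_classifier

    head = feat_classifier(class_num=2, bottleneck_dim=[1, 512, 6, 7, 6], type='conv')
    feats = torch.randn(4, 512, 6, 7, 6)      # hypothetical backbone feature volume
    logits = head(feats)
    print(logits.shape)                       # torch.Size([4, 2])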
adrd/nn/vitautoenc.py
ADDED
@@ -0,0 +1,163 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.layers import Conv
from monai.utils import ensure_tuple_rep

from typing import Sequence, Union
import torch
import torch.nn as nn

from ..nn.blocks import TransformerBlock
from icecream import ic
ic.disable()

__all__ = ["ViTAutoEnc"]


class ViTAutoEnc(nn.Module):
    """
    Vision Transformer (ViT), based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    Modified to also give same dimension outputs as the input size of the image
    """

    def __init__(
        self,
        in_channels: int,
        img_size: Union[Sequence[int], int],
        patch_size: Union[Sequence[int], int],
        out_channels: int = 1,
        deconv_chns: int = 16,
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        pos_embed: str = "conv",
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels or the number of channels for input
            img_size: dimension of input image.
            patch_size: dimension of patch size.
            hidden_size: dimension of hidden layer.
            out_channels: number of output channels.
            deconv_chns: number of channels for the deconvolution layers.
            mlp_dim: dimension of feedforward layer.
            num_layers: number of transformer blocks.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type.
            dropout_rate: fraction of the input units to drop.
            spatial_dims: number of spatial dimensions.

        Examples::

            # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
            # It will provide an output of same size as that of the input
            >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), pos_embed='conv')

            # for 3-channel with image size of (128,128,128), output will be same size as of input
            >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), pos_embed='conv')

        """

        super().__init__()

        self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
        self.spatial_dims = spatial_dims
        self.hidden_size = hidden_size

        self.patch_embedding = PatchEmbeddingBlock(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            num_heads=num_heads,
            pos_embed=pos_embed,
            dropout_rate=dropout_rate,
            spatial_dims=self.spatial_dims,
        )
        self.blocks = nn.ModuleList(
            [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)]
        )
        self.norm = nn.LayerNorm(hidden_size)

        new_patch_size = [4] * self.spatial_dims
        conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims]
        # self.conv3d_transpose* is to be compatible with existing 3d model weights.
        self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=new_patch_size, stride=new_patch_size)
        self.conv3d_transpose_1 = conv_trans(
            in_channels=deconv_chns, out_channels=out_channels, kernel_size=new_patch_size, stride=new_patch_size
        )

    def forward(self, x, return_emb=False, return_hiddens=False):
        """
        Args:
            x: input tensor must have isotropic spatial dimensions,
                such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.
        """
        spatial_size = x.shape[2:]
        x = self.patch_embedding(x)
        hidden_states_out = []
        for blk in self.blocks:
            x = blk(x)
            hidden_states_out.append(x)
        x = self.norm(x)
        x = x.transpose(1, 2)
        if return_emb:
            return x
        d = [s // p for s, p in zip(spatial_size, self.patch_size)]
        x = torch.reshape(x, [x.shape[0], x.shape[1], *d])
        x = self.conv3d_transpose(x)
        x = self.conv3d_transpose_1(x)
        if return_hiddens:
            return x, hidden_states_out
        return x

    def get_last_selfattention(self, x):
        """
        Args:
            x: input tensor must have isotropic spatial dimensions,
                such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.
        """
        x = self.patch_embedding(x)
        ic(x.size())
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
                x.size()
            else:
                return blk(x, return_attention=True)

    def load(self, ckpt_path, map_location='cpu', checkpoint_key='state_dict'):
        """
        Args:
            ckpt_path: path to the pretrained weights
            map_location: device to load the checkpoint on
        """
        state_dict = torch.load(ckpt_path, map_location=map_location)
        ic(state_dict['epoch'], state_dict['train_loss'])
        if checkpoint_key in state_dict:
            print(f"Take key {checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = self.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(ckpt_path, msg))
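A short usage sketch (this assumes a MONAI version whose PatchEmbeddingBlock still accepts the pos_embed keyword used above; sizes follow the docstring example): the forward pass reconstructs a volume of the input size, while return_emb=True returns the patch-token embeddings instead.

    import torch
    from adrd.nn.vitautoenc import ViTAutoEnc

    net = ViTAutoEnc(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16), pos_embed='conv')
    x = torch.randn(1, 1, 96, 96, 96)
    with torch.no_grad():
        recon, hiddens = net(x, return_hiddens=True)
        emb = net(x, return_emb=True)
    print(recon.shape)                        # torch.Size([1, 1, 96, 96, 96])
    print(len(hiddens))                       # 12 transformer blocks by default
    print(emb.shape)                          # torch.Size([1, 768, 216]); (96/16)^3 = 216 patches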