diff --git a/adrd/__init__.py b/adrd/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..7afccb9bc61ec8cedbaca6d99029132533a10e7b
--- /dev/null
+++ b/adrd/__init__.py
@@ -0,0 +1,22 @@
+__version__ = '0.0.1'
+
+from . import nn
+from . import model
+
+# # load pretrained transformer
+# pretrained_transformer = model.Transformer.from_ckpt('{}/ckpt/ckpt.pt'.format(__path__[0]))
+# from . import shap_adrd
+# from .model import DynamicCalibratedClassifier
+# from .model import StaticCalibratedClassifier
+
+# load fitted transformer and calibrated wrapper
+# try:
+#     fitted_resnet3d = model.CNNResNet3DWithLinearClassifier.from_ckpt('{}/ckpt/ckpt_img_072523.pt'.format(__path__[0]))
+#     fitted_calibrated_classifier_nonimg = StaticCalibratedClassifier.from_ckpt(
+#         filepath_state_dict = '{}/ckpt/static_calibrated_classifier_073023.pkl'.format(__path__[0]),
+#         filepath_wrapped_model = '{}/ckpt/ckpt_080823.pt'.format(__path__[0]),
+#     )
+#     fitted_transformer_nonimg = fitted_calibrated_classifier_nonimg.model
+#     shap_explainer = shap_adrd.SamplingExplainer(fitted_transformer_nonimg)
+# except:
+#     print('Failed to load checkpoints.')
diff --git a/adrd/__pycache__/__init__.cpython-311.pyc b/adrd/__pycache__/__init__.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..2844b66d32ba9cfc55a2ca94857041960ff55c6d
Binary files /dev/null and b/adrd/__pycache__/__init__.cpython-311.pyc differ
diff --git a/adrd/_ds/__init__.py b/adrd/_ds/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/adrd/_ds/lddl.py b/adrd/_ds/lddl.py
new file mode 100755
index 0000000000000000000000000000000000000000..c789d45e4fed0bb0a1494539fe963e54b844bda4
--- /dev/null
+++ b/adrd/_ds/lddl.py
@@ -0,0 +1,71 @@
+from typing import Any, Self, overload
+
+
+class lddl:
+    ''' ... '''
+    def __init__(self) -> None:
+        ''' ... '''
+        self.dat_ld: list[dict[str, Any]] | None = None
+        self.dat_dl: dict[str, list[Any]] | None = None
+
+    @overload
+    def __getitem__(self, idx: int) -> dict[str, Any]: ...
+
+    @overload
+    def __getitem__(self, idx: str) -> list[Any]: ...
+
+    def __getitem__(self, idx: str | int) -> list[Any] | dict[str, Any]:
+        ''' ... '''
+        if isinstance(idx, str):
+            return self.dat_dl[idx]
+        elif isinstance(idx, int):
+            return self.dat_ld[idx]
+        else:
+            raise TypeError('Unexpected key type: {}'.format(type(idx)))
+
+    @classmethod
+    def from_ld(cls, dat: list[dict[str, Any]]) -> Self:
+        ''' Construct from list of dicts. '''
+        obj = cls()
+        obj.dat_ld = dat
+        obj.dat_dl = {k: [dat[i][k] for i in range(len(dat))] for k in dat[0]}
+        return obj
+
+    @classmethod
+    def from_dl(cls, dat: dict[str, list[Any]]) -> Self:
+        ''' Construct from dict of lists. '''
+        obj = cls()
+        obj.dat_ld = [dict(zip(dat, v)) for v in zip(*dat.values())]
+        obj.dat_dl = dat
+        return obj
+
+
+if __name__ == '__main__':
+    ''' for testing purposes only '''
+    dl = {
+        'a': [0, 1, 2],
+        'b': [3, 4, 5],
+    }
+
+    ld = [
+        {'a': 0, 'b': 1, 'c': 2},
+        {'a': 3, 'b': 4, 'c': 5},
+    ]
+
+    # test constructing from ld
+    dat_ld = lddl.from_ld(ld)
+    print(dat_ld.dat_ld)
+    print(dat_ld.dat_dl)
+
+    # test constructing from dl
+    dat_dl = lddl.from_dl(dl)
+    print(dat_dl.dat_ld)
+    print(dat_dl.dat_dl)
+
+    # test __getitem__
+    print(dat_dl['a'])
+    print(dat_dl[0])
+
+    # mouse hover to check if type hints are correct
+    v = dat_dl['a']
+    v = dat_dl[0]
\ No newline at end of file
diff --git a/adrd/model/__init__.py b/adrd/model/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..09dd8f3b2e5a5ba3208c0a6459905902785abe26
--- /dev/null
+++ b/adrd/model/__init__.py
@@ -0,0 +1,6 @@
+from .adrd_model import ADRDModel
+from .imaging_model import ImagingModel
+from .train_resnet import TrainResNet
+# from .transformer import Transformer
+from .calibration import DynamicCalibratedClassifier
+from .calibration import StaticCalibratedClassifier
diff --git a/adrd/model/__pycache__/__init__.cpython-311.pyc b/adrd/model/__pycache__/__init__.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..46057032643d12330bb7988809e9a06426b31d02
Binary files /dev/null and b/adrd/model/__pycache__/__init__.cpython-311.pyc differ
diff --git a/adrd/model/__pycache__/adrd_model.cpython-311.pyc b/adrd/model/__pycache__/adrd_model.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..fe74a0cee9fc564111c6a4d9f2201299996a3327
Binary files /dev/null and b/adrd/model/__pycache__/adrd_model.cpython-311.pyc differ
diff --git a/adrd/model/__pycache__/calibration.cpython-311.pyc b/adrd/model/__pycache__/calibration.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..0b7cb8cb299e4a7af36f268943a2ff847a929f50
Binary files /dev/null and b/adrd/model/__pycache__/calibration.cpython-311.pyc differ
diff --git a/adrd/model/__pycache__/imaging_model.cpython-311.pyc b/adrd/model/__pycache__/imaging_model.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..a174bc181b5fc869aec5639c246cfd8e11d02f8a
Binary files /dev/null and b/adrd/model/__pycache__/imaging_model.cpython-311.pyc differ
diff --git a/adrd/model/__pycache__/train_resnet.cpython-311.pyc b/adrd/model/__pycache__/train_resnet.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..85755b8cf6f4b5bfa7f9adbc5e3c3723cf5da6c0
Binary files /dev/null and b/adrd/model/__pycache__/train_resnet.cpython-311.pyc differ
diff --git a/adrd/model/adrd_model.py b/adrd/model/adrd_model.py
new file mode 100755
index 0000000000000000000000000000000000000000..d5ece009566eedee3f8cff19562f7474382bae2a
--- /dev/null
+++ b/adrd/model/adrd_model.py
@@ -0,0 +1,976 @@
+__all__ = ['ADRDModel']
+
+import wandb
+import torch
+from torch.utils.data import DataLoader
+import numpy as np
+import tqdm
+import multiprocessing
+from sklearn.base import BaseEstimator
+from sklearn.utils.validation import check_is_fitted
+from sklearn.model_selection import train_test_split
+from scipy.special import expit
+from copy import deepcopy
+from contextlib import suppress
+from typing import Any, Self, Type
+from functools import wraps
+from tqdm import tqdm
+Tensor = Type[torch.Tensor]
+Module = Type[torch.nn.Module]
+
+# for 
DistributedDataParallel +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP + +from .. import nn +from ..nn import Transformer +from ..utils import TransformerTrainingDataset, TransformerBalancedTrainingDataset, TransformerValidationDataset, TransformerTestingDataset, Transformer2ndOrderBalancedTrainingDataset +from ..utils.misc import ProgressBar +from ..utils.misc import get_metrics_multitask, print_metrics_multitask +from ..utils.misc import convert_args_kwargs_to_kwargs + + +def _manage_ctx_fit(func): + ''' ... ''' + @wraps(func) + def wrapper(*args, **kwargs): + # format arguments + kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs) + + if kwargs['self']._device_ids is None: + return func(**kwargs) + else: + # change primary device + default_device = kwargs['self'].device + kwargs['self'].device = kwargs['self']._device_ids[0] + rtn = func(**kwargs) + kwargs['self'].to(default_device) + return rtn + return wrapper + + +class ADRDModel(BaseEstimator): + """Primary model class for ADRD framework. + + The ADRDModel encapsulates the core pipeline of the ADRD framework, + permitting users to train and validate with the provided data. Designed for + user-friendly operation, the ADRDModel is derived from + ``sklearn.base.BaseEstimator``, ensuring compliance with the well-established + API design conventions of scikit-learn. + """ + def __init__(self, + src_modalities: dict[str, dict[str, Any]], + tgt_modalities: dict[str, dict[str, Any]], + label_fractions: dict[str, float], + d_model: int = 32, + nhead: int = 1, + num_encoder_layers: int = 1, + num_decoder_layers: int = 1, + num_epochs: int = 32, + batch_size: int = 8, + batch_size_multiplier: int = 1, + lr: float = 1e-2, + weight_decay: float = 0.0, + beta: float = 0.9999, + gamma: float = 2.0, + criterion: str | None = None, + device: str = 'cpu', + cuda_devices: list = [1], + img_net: str | None = None, + imgnet_layers: int | None = 2, + img_size: int | None = 128, + fusion_stage: str = 'middle', + patch_size: int | None = 16, + imgnet_ckpt: str | None = None, + train_imgnet: bool = False, + ckpt_path: str = './adrd_tool/adrd/dev/ckpt/ckpt.pt', + load_from_ckpt: bool = False, + save_intermediate_ckpts: bool = False, + data_parallel: bool = False, + verbose: int = 0, + wandb_ = 0, + balanced_sampling: bool = False, + label_distribution: dict = {}, + ranking_loss: bool = False, + _device_ids: list | None = None, + + _dataloader_num_workers: int = 4, + _amp_enabled: bool = False, + ) -> None: + """Create a new ADRD model. 
+ + :param src_modalities: _description_ + :type src_modalities: dict[str, dict[str, Any]] + :param tgt_modalities: _description_ + :type tgt_modalities: dict[str, dict[str, Any]] + :param label_fractions: _description_ + :type label_fractions: dict[str, float] + :param d_model: _description_, defaults to 32 + :type d_model: int, optional + :param nhead: number of transformer heads, defaults to 1 + :type nhead: int, optional + :param num_encoder_layers: number of encoder layers, defaults to 1 + :type num_encoder_layers: int, optional + :param num_decoder_layers: number of decoder layers, defaults to 1 + :type num_decoder_layers: int, optional + :param num_epochs: number of training epochs, defaults to 32 + :type num_epochs: int, optional + :param batch_size: batch size, defaults to 8 + :type batch_size: int, optional + :param batch_size_multiplier: _description_, defaults to 1 + :type batch_size_multiplier: int, optional + :param lr: learning rate, defaults to 1e-2 + :type lr: float, optional + :param weight_decay: _description_, defaults to 0.0 + :type weight_decay: float, optional + :param beta: _description_, defaults to 0.9999 + :type beta: float, optional + :param gamma: The focusing parameter for the focal loss. Higher values of gamma make easy-to-classify examples contribute less to the loss relative to hard-to-classify examples. Must be non-negative., defaults to 2.0 + :type gamma: float, optional + :param criterion: The criterion to select the best model, defaults to None + :type criterion: str | None, optional + :param device: 'cuda' or 'cpu', defaults to 'cpu' + :type device: str, optional + :param cuda_devices: A list of gpu numbers to data parallel training. The device must be set to 'cuda' and data_parallel must be set to True, defaults to [1] + :type cuda_devices: list, optional + :param img_net: _description_, defaults to None + :type img_net: str | None, optional + :param imgnet_layers: _description_, defaults to 2 + :type imgnet_layers: int | None, optional + :param img_size: _description_, defaults to 128 + :type img_size: int | None, optional + :param fusion_stage: _description_, defaults to 'middle' + :type fusion_stage: str, optional + :param patch_size: _description_, defaults to 16 + :type patch_size: int | None, optional + :param imgnet_ckpt: _description_, defaults to None + :type imgnet_ckpt: str | None, optional + :param train_imgnet: Set to True to finetune the img_net backbone, defaults to False + :type train_imgnet: bool, optional + :param ckpt_path: The model checkpoint point path, defaults to './adrd_tool/adrd/dev/ckpt/ckpt.pt' + :type ckpt_path: str, optional + :param load_from_ckpt: Set to True to load the model weights from checkpoint ckpt_path, defaults to False + :type load_from_ckpt: bool, optional + :param save_intermediate_ckpts: Set to True to save intermediate model checkpoints, defaults to False + :type save_intermediate_ckpts: bool, optional + :param data_parallel: Set to True to enable data parallel trsining, defaults to False + :type data_parallel: bool, optional + :param verbose: _description_, defaults to 0 + :type verbose: int, optional + :param wandb_: Set to 1 to track the loss and accuracy curves on wandb, defaults to 0 + :type wandb_: int, optional + :param balanced_sampling: _description_, defaults to False + :type balanced_sampling: bool, optional + :param label_distribution: _description_, defaults to {} + :type label_distribution: dict, optional + :param ranking_loss: _description_, defaults to False + :type ranking_loss: bool, 
optional + :param _device_ids: _description_, defaults to None + :type _device_ids: list | None, optional + :param _dataloader_num_workers: _description_, defaults to 4 + :type _dataloader_num_workers: int, optional + :param _amp_enabled: _description_, defaults to False + :type _amp_enabled: bool, optional + """ + # for multiprocessing + self._rank = 0 + self._lock = None + + # positional parameters + self.src_modalities = src_modalities + self.tgt_modalities = tgt_modalities + + # training parameters + self.label_fractions = label_fractions + self.d_model = d_model + self.nhead = nhead + self.num_encoder_layers = num_encoder_layers + self.num_decoder_layers = num_decoder_layers + self.num_epochs = num_epochs + self.batch_size = batch_size + self.batch_size_multiplier = batch_size_multiplier + self.lr = lr + self.weight_decay = weight_decay + self.beta = beta + self.gamma = gamma + self.criterion = criterion + self.device = device + self.cuda_devices = cuda_devices + self.img_net = img_net + self.patch_size = patch_size + self.img_size = img_size + self.fusion_stage = fusion_stage + self.imgnet_ckpt = imgnet_ckpt + self.imgnet_layers = imgnet_layers + self.train_imgnet = train_imgnet + self.ckpt_path = ckpt_path + self.load_from_ckpt = load_from_ckpt + self.save_intermediate_ckpts = save_intermediate_ckpts + self.data_parallel = data_parallel + self.verbose = verbose + self.label_distribution = label_distribution + self.wandb_ = wandb_ + self.balanced_sampling = balanced_sampling + self.ranking_loss = ranking_loss + self._device_ids = _device_ids + self._dataloader_num_workers = _dataloader_num_workers + self._amp_enabled = _amp_enabled + self.scaler = torch.cuda.amp.GradScaler() + # self._init_net() + + @_manage_ctx_fit + def fit(self, x_trn, x_vld, y_trn, y_vld, img_train_trans=None, img_vld_trans=None, img_mode=0) -> Self: + # def fit(self, x, y) -> Self: + ''' ... 
''' + + # start a new wandb run to track this script + if self.wandb_ == 1: + wandb.init( + # set the wandb project where this run will be logged + project="ADRD_main", + + # track hyperparameters and run metadata + config={ + "Loss": 'Focalloss', + "ranking_loss": self.ranking_loss, + "img architecture": self.img_net, + "EMB": "ALL_SEQ", + "epochs": self.num_epochs, + "d_model": self.d_model, + # 'positional encoding': 'Diff PE', + 'Balanced Sampling': self.balanced_sampling, + 'Shared CNN': 'Yes', + } + ) + wandb.run.log_code(".") + else: + wandb.init(mode="disabled") + # for PyTorch computational efficiency + torch.set_num_threads(1) + # print(img_train_trans) + # initialize neural network + print(self.criterion) + print(f"Ranking loss: {self.ranking_loss}") + print(f"Batch size: {self.batch_size}") + print(f"Batch size multiplier: {self.batch_size_multiplier}") + + if img_mode in [0,1,2]: + for k, info in self.src_modalities.items(): + if info['type'] == 'imaging': + if 'densenet' in self.img_net.lower() and 'emb' not in self.img_net.lower(): + info['shape'] = (1,) + self.img_size + info['img_shape'] = (1,) + self.img_size + elif 'emb' not in self.img_net.lower(): + info['shape'] = (1,) + (self.img_size,) * 3 + info['img_shape'] = (1,) + (self.img_size,) * 3 + elif 'swinunetr' in self.img_net.lower(): + info['shape'] = (1, 768, 4, 4, 4) + info['img_shape'] = (1, 768, 4, 4, 4) + + + + self._init_net() + ldr_trn, ldr_vld = self._init_dataloader(x_trn, x_vld, y_trn, y_vld, img_train_trans=img_train_trans, img_vld_trans=img_vld_trans) + + # initialize optimizer and scheduler + if not self.load_from_ckpt: + self.optimizer = self._init_optimizer() + self.scheduler = self._init_scheduler(self.optimizer) + + # gradient scaler for AMP + if self._amp_enabled: + self.scaler = torch.cuda.amp.GradScaler() + + # initialize the focal losses + self.loss_fn = {} + + for k in self.tgt_modalities: + if self.label_fractions[k] >= 0.3: + alpha = -1 + else: + alpha = pow((1 - self.label_fractions[k]), 2) + # alpha = -1 + self.loss_fn[k] = nn.SigmoidFocalLoss( + alpha = alpha, + gamma = self.gamma, + reduction = 'none' + ) + + # to record the best validation performance criterion + if self.criterion is not None: + best_crit = None + best_crit_AUPR = None + + # progress bar for epoch loops + if self.verbose == 1: + with self._lock if self._lock is not None else suppress(): + pbr_epoch = tqdm.tqdm( + desc = 'Rank {:02d}'.format(self._rank), + total = self.num_epochs, + position = self._rank, + ascii = True, + leave = False, + bar_format='{l_bar}{r_bar}' + ) + + self.skip_embedding = {} + for k, info in self.src_modalities.items(): + # if info['type'] == 'imaging': + # if not self.img_net: + # self.skip_embedding[k] = True + # else: + self.skip_embedding[k] = False + + self.grad_list = [] + # Define a hook function to print and store the gradient of a layer + def print_and_store_grad(grad): + self.grad_list.append(grad) + # print(grad) + + + # initialize the ranking loss + self.lambda_coeff = 0.005 + self.margin = 0.25 + self.margin_loss = torch.nn.MarginRankingLoss(reduction='sum', margin=self.margin) + + # training loop + for epoch in range(self.start_epoch, self.num_epochs): + met_trn = self.train_one_epoch(ldr_trn, epoch) + met_vld = self.validate_one_epoch(ldr_vld, epoch) + + print(self.ckpt_path.split('/')[-1]) + + # save the model if it has the best validation performance criterion by far + if self.criterion is None: continue + + # is current criterion better than previous best? 
+ curr_crit = np.mean([met_vld[i][self.criterion] for i in range(len(self.tgt_modalities))]) + curr_crit_AUPR = np.mean([met_vld[i]["AUC (PR)"] for i in range(len(self.tgt_modalities))]) + # AUROC + if best_crit is None or np.isnan(best_crit): + is_better = True + elif self.criterion == 'Loss' and best_crit >= curr_crit: + is_better = True + elif self.criterion != 'Loss' and best_crit <= curr_crit : + is_better = True + else: + is_better = False + + # AUPR + if best_crit_AUPR is None or np.isnan(best_crit_AUPR): + is_better_AUPR = True + elif best_crit_AUPR <= curr_crit_AUPR : + is_better_AUPR = True + else: + is_better_AUPR = False + # update best criterion + if is_better_AUPR: + best_crit_AUPR = curr_crit_AUPR + if self.save_intermediate_ckpts: + print(f"Saving the model to {self.ckpt_path[:-3]}_AUPR.pt...") + self.save(self.ckpt_path[:-3]+"_AUPR.pt", epoch) + if is_better: + best_crit = curr_crit + best_state_dict = deepcopy(self.net_.state_dict()) + if self.save_intermediate_ckpts: + print(f"Saving the model to {self.ckpt_path}...") + self.save(self.ckpt_path, epoch) + + if self.verbose > 2: + print('Best {}: {}'.format(self.criterion, best_crit)) + print('Best {}: {}'.format('AUC (PR)', best_crit_AUPR)) + + if self.verbose == 1: + with self._lock if self._lock is not None else suppress(): + pbr_epoch.update(1) + pbr_epoch.refresh() + + if self.verbose == 1: + with self._lock if self._lock is not None else suppress(): + pbr_epoch.close() + + return self + + def train_one_epoch(self, ldr_trn, epoch): + # progress bar for batch loops + if self.verbose > 1: + pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch)) + + # set model to train mode + torch.set_grad_enabled(True) + self.net_.train() + + scores_trn, y_true_trn, y_mask_trn = [], [], [] + losses_trn = [[] for _ in self.tgt_modalities] + iters = len(ldr_trn) + for n_iter, (x_batch, y_batch, mask, y_mask) in enumerate(ldr_trn): + + # mount data to the proper device + x_batch = {k: x_batch[k].to(self.device) for k in x_batch} + y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch} + mask = {k: mask[k].to(self.device) for k in mask} + y_mask = {k: y_mask[k].to(self.device) for k in y_mask} + + with torch.autocast( + device_type = 'cpu' if self.device == 'cpu' else 'cuda', + dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16, + enabled = self._amp_enabled, + ): + + outputs = self.net_(x_batch, mask, skip_embedding=self.skip_embedding) + + # calculate multitask loss + loss = 0 + + # for initial 10 epochs, only the focal loss is used for stable training + if self.ranking_loss: + if epoch < 10: + loss = 0 + else: + for i, k in enumerate(self.tgt_modalities): + for ii, kk in enumerate(self.tgt_modalities): + if ii>i: + pairs = (y_mask[k] == 1) & (y_mask[kk] == 1) + total_elements = (torch.abs(y_batch[k][pairs]-y_batch[kk][pairs])).sum() + if total_elements != 0: + loss += self.lambda_coeff * (self.margin_loss(torch.sigmoid(outputs[k])[pairs],torch.sigmoid(outputs[kk][pairs]),y_batch[k][pairs]-y_batch[kk][pairs]))/total_elements + + for i, k in enumerate(self.tgt_modalities): + loss_task = self.loss_fn[k](outputs[k], y_batch[k]) + msk_loss_task = loss_task * y_mask[k] + msk_loss_mean = msk_loss_task.sum() / y_mask[k].sum() + # msk_loss_mean = msk_loss_task.sum() + loss += msk_loss_mean + losses_trn[i] += msk_loss_task.detach().cpu().numpy().tolist() + + # backward + loss = loss / self.batch_size_multiplier + if self._amp_enabled: + self.scaler.scale(loss).backward() + else: + 
loss.backward() + + if len(self.grad_list) > 0: + print(len(self.grad_list), len(self.grad_list[-1])) + print(f"Gradient at {n_iter}: {self.grad_list[-1][0]}") + + # print("img_MRI_T1_1 ", self.net_.modules_emb_src.img_MRI_T1_1.img_model.features[0].weight) + # print("img_MRI_T1_1 ", self.net_.modules_emb_src.img_MRI_T1_1.downsample[0].weight) + + # update parameters + if n_iter != 0 and n_iter % self.batch_size_multiplier == 0: + if self._amp_enabled: + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.zero_grad() + else: + self.optimizer.step() + self.optimizer.zero_grad() + + # set self.scheduler + self.scheduler.step(epoch + n_iter / iters) + + ''' TODO: change array to dictionary later ''' + outputs = torch.stack(list(outputs.values()), dim=1) + y_batch = torch.stack(list(y_batch.values()), dim=1) + y_mask = torch.stack(list(y_mask.values()), dim=1) + + # save outputs to evaluate performance later + scores_trn.append(outputs.detach().to(torch.float).cpu()) + y_true_trn.append(y_batch.cpu()) + y_mask_trn.append(y_mask.cpu()) + + # update progress bar + if self.verbose > 1: + batch_size = len(next(iter(x_batch.values()))) + pbr_batch.update(batch_size, {}) + pbr_batch.refresh() + + # clear cuda cache + if "cuda" in self.device: + torch.cuda.empty_cache() + + # for better tqdm progress bar display + if self.verbose > 1: + pbr_batch.close() + + # calculate and print training performance metrics + scores_trn = torch.cat(scores_trn) + y_true_trn = torch.cat(y_true_trn) + y_mask_trn = torch.cat(y_mask_trn) + y_pred_trn = (scores_trn > 0).to(torch.int) + y_prob_trn = torch.sigmoid(scores_trn) + met_trn = get_metrics_multitask( + y_true_trn.numpy(), + y_pred_trn.numpy(), + y_prob_trn.numpy(), + y_mask_trn.numpy() + ) + + # add loss to metrics + for i in range(len(self.tgt_modalities)): + met_trn[i]['Loss'] = np.mean(losses_trn[i]) + + # log metrics to wandb + wandb.log({f"Train loss {list(self.tgt_modalities)[i]}": met_trn[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch) + wandb.log({f"Train Balanced Accuracy {list(self.tgt_modalities)[i]}": met_trn[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch) + + wandb.log({f"Train AUC (ROC) {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch) + wandb.log({f"Train AUPR {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch) + + if self.verbose > 2: + print_metrics_multitask(met_trn) + + return met_trn + + def validate_one_epoch(self, ldr_vld, epoch): + # # progress bar for validation + if self.verbose > 1: + pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch)) + + # set model to validation mode + torch.set_grad_enabled(False) + self.net_.eval() + + scores_vld, y_true_vld, y_mask_vld = [], [], [] + losses_vld = [[] for _ in self.tgt_modalities] + for x_batch, y_batch, mask, y_mask in ldr_vld: + # if len(next(iter(x_batch.values()))) < self.batch_size: + # break + # mount data to the proper device + x_batch = {k: x_batch[k].to(self.device) for k in x_batch} # if 'img' not in k} + # x_img_batch = {k: x_img_batch[k].to(self.device) for k in x_img_batch} + y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch} + mask = {k: mask[k].to(self.device) for k in mask} + y_mask = {k: y_mask[k].to(self.device) for k in y_mask} + + # forward + with torch.autocast( + device_type = 'cpu' if self.device == 'cpu' else 'cuda', + dtype = 
torch.bfloat16 if self.device == 'cpu' else torch.float16, + enabled = self._amp_enabled + ): + + outputs = self.net_(x_batch, mask, skip_embedding=self.skip_embedding) + + # calculate multitask loss + for i, k in enumerate(self.tgt_modalities): + loss_task = self.loss_fn[k](outputs[k], y_batch[k]) + msk_loss_task = loss_task * y_mask[k] + losses_vld[i] += msk_loss_task.detach().cpu().numpy().tolist() + + ''' TODO: change array to dictionary later ''' + outputs = torch.stack(list(outputs.values()), dim=1) + y_batch = torch.stack(list(y_batch.values()), dim=1) + y_mask = torch.stack(list(y_mask.values()), dim=1) + + # save outputs to evaluate performance later + scores_vld.append(outputs.detach().to(torch.float).cpu()) + y_true_vld.append(y_batch.cpu()) + y_mask_vld.append(y_mask.cpu()) + + # update progress bar + if self.verbose > 1: + batch_size = len(next(iter(x_batch.values()))) + pbr_batch.update(batch_size, {}) + pbr_batch.refresh() + + # clear cuda cache + if "cuda" in self.device: + torch.cuda.empty_cache() + + # for better tqdm progress bar display + if self.verbose > 1: + pbr_batch.close() + + # calculate and print validation performance metrics + scores_vld = torch.cat(scores_vld) + y_true_vld = torch.cat(y_true_vld) + y_mask_vld = torch.cat(y_mask_vld) + y_pred_vld = (scores_vld > 0).to(torch.int) + y_prob_vld = torch.sigmoid(scores_vld) + met_vld = get_metrics_multitask( + y_true_vld.numpy(), + y_pred_vld.numpy(), + y_prob_vld.numpy(), + y_mask_vld.numpy() + ) + + # add loss to metrics + for i in range(len(self.tgt_modalities)): + met_vld[i]['Loss'] = np.mean(losses_vld[i]) + + wandb.log({f"Validation loss {list(self.tgt_modalities)[i]}": met_vld[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch) + wandb.log({f"Validation Balanced Accuracy {list(self.tgt_modalities)[i]}": met_vld[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch) + + wandb.log({f"Validation AUC (ROC) {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch) + wandb.log({f"Validation AUPR {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch) + + if self.verbose > 2: + print_metrics_multitask(met_vld) + + return met_vld + + + def predict_logits(self, + x: list[dict[str, Any]], + _batch_size: int | None = None, + skip_embedding: dict | None = None, + img_transform: Any | None = None, + ) -> list[dict[str, float]]: + ''' + The input x can be a single sample or a list of samples. 
+ ''' + # input validation + check_is_fitted(self) + print(self.device) + + # for PyTorch computational efficiency + torch.set_num_threads(1) + + # set model to eval mode + torch.set_grad_enabled(False) + self.net_.eval() + + # intialize dataset and dataloader object + dat = TransformerTestingDataset(x, self.src_modalities, img_transform=img_transform) + ldr = DataLoader( + dataset = dat, + batch_size = _batch_size if _batch_size is not None else len(x), + shuffle = False, + drop_last = False, + num_workers = 0, + collate_fn = TransformerTestingDataset.collate_fn, + ) + # print("dataloader done") + + # run model and collect results + logits: list[dict[str, float]] = [] + for x_batch, mask in tqdm(ldr): + # mount data to the proper device + # print(x_batch['his_SEX']) + x_batch = {k: x_batch[k].to(self.device) for k in x_batch} + mask = {k: mask[k].to(self.device) for k in mask} + + # forward + output: dict[str, Tensor] = self.net_(x_batch, mask, skip_embedding) + + # convert output from dict-of-list to list of dict, then append + tmp = {k: output[k].tolist() for k in self.tgt_modalities} + tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))] + logits += tmp + + return logits + + def predict_proba(self, + x: list[dict[str, Any]], + skip_embedding: dict | None = None, + temperature: float = 1.0, + _batch_size: int | None = None, + img_transform: Any | None = None, + ) -> list[dict[str, float]]: + ''' ... ''' + logits = self.predict_logits(x=x, _batch_size=_batch_size, img_transform=img_transform, skip_embedding=skip_embedding) + print("got logits") + return logits, [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits] + + def predict(self, + x: list[dict[str, Any]], + skip_embedding: dict | None = None, + fpr: dict[str, Any] | None = None, + tpr: dict[str, Any] | None = None, + thresholds: dict[str, Any] | None = None, + _batch_size: int | None = None, + img_transform: Any | None = None, + ) -> list[dict[str, int]]: + ''' ... ''' + if fpr is None or tpr is None or thresholds is None: + logits, proba = self.predict_proba(x, _batch_size=_batch_size, img_transform=img_transform, skip_embedding=skip_embedding) + print("got proba") + return logits, proba, [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba] + else: + logits, proba = self.predict_proba(x, _batch_size=_batch_size, img_transform=img_transform, skip_embedding=skip_embedding) + print("got proba") + youden_index = {} + thr = {} + for i, k in enumerate(self.tgt_modalities): + youden_index[k] = tpr[i] - fpr[i] + thr[k] = thresholds[i][np.argmax(youden_index[k])] + # print(thr[k]) + # print(thr) + return logits, proba, [{k: int(smp[k] > thr[k]) for k in self.tgt_modalities} for smp in proba] + + def save(self, filepath: str, epoch: int) -> None: + """Save the model to the given file stream. 
+ + :param filepath: _description_ + :type filepath: str + :param epoch: _description_ + :type epoch: int + """ + check_is_fitted(self) + if self.data_parallel: + state_dict = self.net_.module.state_dict() + else: + state_dict = self.net_.state_dict() + + # attach model hyper parameters + state_dict['src_modalities'] = self.src_modalities + state_dict['tgt_modalities'] = self.tgt_modalities + state_dict['d_model'] = self.d_model + state_dict['nhead'] = self.nhead + state_dict['num_encoder_layers'] = self.num_encoder_layers + state_dict['num_decoder_layers'] = self.num_decoder_layers + state_dict['optimizer'] = self.optimizer + state_dict['img_net'] = self.img_net + state_dict['imgnet_layers'] = self.imgnet_layers + state_dict['img_size'] = self.img_size + state_dict['patch_size'] = self.patch_size + state_dict['imgnet_ckpt'] = self.imgnet_ckpt + state_dict['train_imgnet'] = self.train_imgnet + state_dict['epoch'] = epoch + + if self.scaler is not None: + state_dict['scaler'] = self.scaler.state_dict() + if self.label_distribution: + state_dict['label_distribution'] = self.label_distribution + + torch.save(state_dict, filepath) + + def load(self, filepath: str, map_location: str = 'cpu', img_dict=None) -> None: + """Load a model from the given file stream. + + :param filepath: _description_ + :type filepath: str + :param map_location: _description_, defaults to 'cpu' + :type map_location: str, optional + :param img_dict: _description_, defaults to None + :type img_dict: _type_, optional + """ + # load state_dict + state_dict = torch.load(filepath, map_location=map_location) + + # load data modalities + self.src_modalities: dict[str, dict[str, Any]] = state_dict.pop('src_modalities') + self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities') + if 'label_distribution' in state_dict: + self.label_distribution: dict[str, dict[int, int]] = state_dict.pop('label_distribution') + if 'optimizer' in state_dict: + self.optimizer = state_dict.pop('optimizer') + + # initialize model + self.d_model = state_dict.pop('d_model') + self.nhead = state_dict.pop('nhead') + self.num_encoder_layers = state_dict.pop('num_encoder_layers') + self.num_decoder_layers = state_dict.pop('num_decoder_layers') + if 'epoch' in state_dict.keys(): + self.start_epoch = state_dict.pop('epoch') + if img_dict is None: + self.img_net = state_dict.pop('img_net') + self.imgnet_layers = state_dict.pop('imgnet_layers') + self.img_size = state_dict.pop('img_size') + self.patch_size = state_dict.pop('patch_size') + self.imgnet_ckpt = state_dict.pop('imgnet_ckpt') + self.train_imgnet = state_dict.pop('train_imgnet') + else: + self.img_net = img_dict['img_net'] + self.imgnet_layers = img_dict['imgnet_layers'] + self.img_size = img_dict['img_size'] + self.patch_size = img_dict['patch_size'] + self.imgnet_ckpt = img_dict['imgnet_ckpt'] + self.train_imgnet = img_dict['train_imgnet'] + state_dict.pop('img_net') + state_dict.pop('imgnet_layers') + state_dict.pop('img_size') + state_dict.pop('patch_size') + state_dict.pop('imgnet_ckpt') + state_dict.pop('train_imgnet') + + for k, info in self.src_modalities.items(): + if info['type'] == 'imaging': + if 'emb' not in self.img_net.lower(): + info['shape'] = (1,) + (self.img_size,) * 3 + info['img_shape'] = (1,) + (self.img_size,) * 3 + elif 'swinunetr' in self.img_net.lower(): + info['shape'] = (1, 768, 4, 4, 4) + info['img_shape'] = (1, 768, 4, 4, 4) + # print(info['shape']) + + self.net_ = Transformer(self.src_modalities, self.tgt_modalities, self.d_model, self.nhead, 
self.num_encoder_layers, self.num_decoder_layers, self.device, self.cuda_devices, self.img_net, self.imgnet_layers, self.img_size, self.patch_size, self.imgnet_ckpt, self.train_imgnet, self.fusion_stage) + + + if 'scaler' in state_dict and state_dict['scaler']: + self.scaler.load_state_dict(state_dict.pop('scaler')) + self.net_.load_state_dict(state_dict) + check_is_fitted(self) + self.net_.to(self.device) + + def to(self, device: str) -> Self: + """Mount the model to the given device. + + :param device: _description_ + :type device: str + :return: _description_ + :rtype: Self + """ + self.device = device + if hasattr(self, 'model'): self.net_ = self.net_.to(device) + if hasattr(self, 'img_model'): self.img_model = self.img_model.to(device) + return self + + @classmethod + def from_ckpt(cls, filepath: str, device='cpu', img_dict=None) -> Self: + """Create a new ADRD model and load parameters from the checkpoint. + + This is an alternative constructor. + + :param filepath: _description_ + :type filepath: str + :param device: _description_, defaults to 'cpu' + :type device: str, optional + :param img_dict: _description_, defaults to None + :type img_dict: _type_, optional + :return: _description_ + :rtype: Self + """ + obj = cls(None, None, None,device=device) + if device == 'cuda': + obj.device = "{}:{}".format(obj.device, str(obj.cuda_devices[0])) + print(obj.device) + obj.load(filepath, map_location=obj.device, img_dict=img_dict) + return obj + + def _init_net(self): + """ ... """ + # set the device for use + if self.device == 'cuda': + self.device = "{}:{}".format(self.device, str(self.cuda_devices[0])) + print("Device: " + self.device) + + self.start_epoch = 0 + if self.load_from_ckpt: + try: + print("Loading model from checkpoint...") + self.load(self.ckpt_path, map_location=self.device) + except: + print("Cannot load from checkpoint. 
Initializing new model...") + self.load_from_ckpt = False + + if not self.load_from_ckpt: + self.net_ = nn.Transformer( + src_modalities = self.src_modalities, + tgt_modalities = self.tgt_modalities, + d_model = self.d_model, + nhead = self.nhead, + num_encoder_layers = self.num_encoder_layers, + num_decoder_layers = self.num_decoder_layers, + device = self.device, + cuda_devices = self.cuda_devices, + img_net = self.img_net, + layers = self.imgnet_layers, + img_size = self.img_size, + patch_size = self.patch_size, + imgnet_ckpt = self.imgnet_ckpt, + train_imgnet = self.train_imgnet, + fusion_stage = self.fusion_stage, + ) + + # intialize model parameters using xavier_uniform + for name, p in self.net_.named_parameters(): + if p.dim() > 1: + torch.nn.init.xavier_uniform_(p) + + self.net_.to(self.device) + + # Initialize the number of GPUs + if self.data_parallel and torch.cuda.device_count() > 1: + print("Available", torch.cuda.device_count(), "GPUs!") + self.net_ = torch.nn.DataParallel(self.net_, device_ids=self.cuda_devices) + + # return net + + def _init_dataloader(self, x_trn, x_vld, y_trn, y_vld, img_train_trans=None, img_vld_trans=None): + # initialize dataset and dataloader + if self.balanced_sampling: + dat_trn = Transformer2ndOrderBalancedTrainingDataset( + x_trn, y_trn, + self.src_modalities, + self.tgt_modalities, + dropout_rate = .5, + dropout_strategy = 'permutation', + img_transform=img_train_trans, + ) + else: + dat_trn = TransformerTrainingDataset( + x_trn, y_trn, + self.src_modalities, + self.tgt_modalities, + dropout_rate = .5, + dropout_strategy = 'permutation', + img_transform=img_train_trans, + ) + + dat_vld = TransformerValidationDataset( + x_vld, y_vld, + self.src_modalities, + self.tgt_modalities, + img_transform=img_vld_trans, + ) + + ldr_trn = DataLoader( + dataset = dat_trn, + batch_size = self.batch_size, + shuffle = True, + drop_last = False, + num_workers = self._dataloader_num_workers, + collate_fn = TransformerTrainingDataset.collate_fn, + # pin_memory = True + ) + + ldr_vld = DataLoader( + dataset = dat_vld, + batch_size = self.batch_size, + shuffle = False, + drop_last = False, + num_workers = self._dataloader_num_workers, + collate_fn = TransformerValidationDataset.collate_fn, + # pin_memory = True + ) + + return ldr_trn, ldr_vld + + def _init_optimizer(self): + """ ... """ + params = list(self.net_.parameters()) + return torch.optim.AdamW( + params, + lr = self.lr, + betas = (0.9, 0.98), + weight_decay = self.weight_decay + ) + + def _init_scheduler(self, optimizer): + """ ... """ + + return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer=optimizer, + T_0=64, + T_mult=2, + eta_min = 0, + verbose=(self.verbose > 2) + ) + + def _init_loss_func(self, + num_per_cls: dict[str, tuple[int, int]], + ) -> dict[str, Module]: + """ ... """ + return {k: nn.SigmoidFocalLossBeta( + beta = self.beta, + gamma = self.gamma, + num_per_cls = num_per_cls[k], + reduction = 'none', + ) for k in self.tgt_modalities} + + def _proc_fit(self): + """ ... 
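As a quick orientation to the scikit-learn-style API defined above, a minimal usage sketch follows (illustrative only, not part of the diff): the feature/label dictionaries and the x_trn/x_vld/y_trn/y_vld/x_tst variables are hypothetical placeholders, and the exact per-feature schema must match what the underlying Transformer embeddings expect.

import adrd

# hypothetical modality configuration; the per-feature schema shown here is an assumption
src_modalities = {'his_SEX': {'type': 'categorical'}, 'his_AGE': {'type': 'numerical'}}
tgt_modalities = {'AD': {}, 'MCI': {}}          # hypothetical prediction targets
label_fractions = {'AD': 0.4, 'MCI': 0.2}       # fraction of positive labels per task

model = adrd.model.ADRDModel(
    src_modalities, tgt_modalities, label_fractions,
    num_epochs=8, batch_size=16, device='cpu', criterion='AUC (ROC)',
)
model.fit(x_trn, x_vld, y_trn, y_vld)           # lists of per-sample feature/label dicts
logits, proba = model.predict_proba(x_tst)      # proba: list[dict[str, float]]

# a trained model can also be restored from a checkpoint
model = adrd.model.ADRDModel.from_ckpt('ckpt.pt', device='cpu')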
""" diff --git a/adrd/model/calibration.py b/adrd/model/calibration.py new file mode 100755 index 0000000000000000000000000000000000000000..26998b9513a65079c91ff023ec62daec73de2b8e --- /dev/null +++ b/adrd/model/calibration.py @@ -0,0 +1,450 @@ +import numpy as np +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted +from sklearn.linear_model import LogisticRegression +from sklearn.isotonic import IsotonicRegression +from functools import lru_cache +from functools import cached_property +from typing import Self, Any +from pickle import dump +from pickle import load +from abc import ABC, abstractmethod + +from . import ADRDModel +from ..utils import Formatter +from ..utils import MissingMasker + + +def calibration_curve( + y_true: list[int], + y_pred: list[float], + n_bins: int = 10, + ratio: float = 1.0, +) -> tuple[list[float], list[float]]: + """ + Compute true and predicted probabilities for a calibration curve. The method + assumes the inputs come from a binary classifier, and discretize the [0, 1] + interval into bins. + + Note that this function is an alternative to + sklearn.calibration.calibration_curve() which can only estimate the absolute + proportion of positive cases in each bin. + + Parameters + ---------- + y_true : list[int] + True targets. + y_pred : list[float] + Probabilities of the positive class. + n_bins : int, default=10 + Number of bins to discretize the [0, 1] interval. A bigger number + requires more data. Bins with no samples (i.e. without corresponding + values in y_prob) will not be returned, thus the returned arrays may + have less than n_bins values. + ratio : float, default=1.0 + Used to adjust the class balance. + + Returns + ------- + prob_true : list[float] + The proportion of positive samples in each bin. + prob_pred : list[float] + The mean predicted probability in each bin. + """ + # generate "n_bin" intervals + tmp = np.around(np.linspace(0, 1, n_bins + 1), decimals=6) + intvs = [(tmp[i - 1], tmp[i]) for i in range(1, len(tmp))] + + # pair up (pred, true) and group them by intervals + tmp = list(zip(y_pred, y_true)) + intv_pairs = {(l, r): [p for p in tmp if l <= p[0] < r] for l, r in intvs} + + # calculate balanced proportion of POSITIVE cases for each intervel + # along with the balanced averaged predictions + intv_prob_true: dict[tuple, float] = dict() + intv_prob_pred: dict[tuple, float] = dict() + for intv, pairs in intv_pairs.items(): + # number of cases that fall into the interval + n_pairs = len(pairs) + + # it's likely that no predictions fall into the interval + if n_pairs == 0: continue + + # count number of positives and negatives in the interval + n_pos = sum([p[1] for p in pairs]) + n_neg = n_pairs - n_pos + + # calculate adjusted proportion of positives + intv_prob_true[intv] = n_pos / (n_pos + n_neg * ratio) + + # calculate adjusted avg. predictions + sum_pred_pos = sum([p[0] for p in pairs if p[1] == 1]) + sum_pred_neg = sum([p[0] for p in pairs if p[1] == 0]) + intv_prob_pred[intv] = (sum_pred_pos + sum_pred_neg * ratio) + intv_prob_pred[intv] /= (n_pos + n_neg * ratio) + + prob_true = list(intv_prob_true.values()) + prob_pred = list(intv_prob_pred.values()) + return prob_true, prob_pred + + +class CalibrationCore(BaseEstimator): + """ + A wrapper class of multiple regressors to predict the proportions of + positive samples from the predicted probabilities. The method for + calibration can be 'sigmoid' which corresponds to Platt's method (i.e. 
a + logistic regression model) or 'isotonic' which is a non-parametric approach. + It is not advised to use isotonic calibration with too few calibration + samples (<<1000) since it tends to overfit. + + TODO + ---- + - 'sigmoid' method is not trivial to implement. + """ + def __init__(self, + method: str = 'isotonic', + ) -> None: + """ + Initialization function of CalibrationCore class. + + Parameters + ---------- + method : {'sigmoid', 'isotonic'}, default='isotonic' + The method to use for calibration. can be 'sigmoid' which + corresponds to Platt's method (i.e. a logistic regression model) or + 'isotonic' which is a non-parametric approach. It is not advised to + use isotonic calibration with too few calibration samples (<<1000) + since it tends to overfit. + + Raises + ------ + ValueError + Sigmoid approach has not been implemented. + """ + assert method in ('sigmoid', 'isotonic') + if method == 'sigmoid': + raise ValueError('Sigmoid approach has not been implemented.') + self.method = method + + def fit(self, + prob_pred: list[float], + prob_true: list[float], + ) -> Self: + """ + Fit the underlying regressor using prob_pred, prob_true as training + data. + + Parameters + ---------- + prob_pred : list[float] + Probabilities predicted directly by a model. + prob_true : list[float] + Target probabilities to calibrate to. + + Returns + ------- + Self + CalibrationCore object. + """ + # using Platt's method for calibration + if self.method == 'sigmoid': + self.model_ = LogisticRegression() + self.model_.fit(prob_pred, prob_true) + + # using isotonic calibration + elif self.method == 'isotonic': + self.model_ = IsotonicRegression(y_min=0, y_max=1, out_of_bounds='clip') + self.model_.fit(prob_pred, prob_true) + + return self + + def predict(self, + prob_pred: list[float], + ) -> list[float]: + """ + Calibrate the input probabilities using the fitted regressor. + + Parameters + ---------- + prob_pred : list[float] + Probabilities predicted directly by a model. + + Returns + ------- + prob_cali : list[float] + Calibrated probabilities. + """ + # as usual, the core needs to be fitted + check_is_fitted(self) + + # note that logistic regression is classification model, we need to call + # 'predict_proba' instead of 'predict' to get the calibrated results + if self.method == 'sigmoid': + prob_cali = self.model_.predict_proba(prob_pred) + elif self.method == 'isotonic': + prob_cali = self.model_.predict(prob_pred) + + return prob_cali + + +class CalibratedClassifier(ABC): + """ + Abstract class of calibrated classifier. + """ + def __init__(self, + model: ADRDModel, + background_src: list[dict[str, Any]], + background_tgt: list[dict[str, Any]], + background_is_embedding: dict[str, bool] | None = None, + method: str = 'isotonic', + ) -> None: + """ + Constructor of Calibrator class. + + Parameters + ---------- + model : ADRDModel + Fitted model to calibrate. + background_src : list[dict[str, Any]] + Features of the background dataset. + background_tgt : list[dict[str, Any]] + Labels of the background dataset. + method : {'sigmoid', 'isotonic'}, default='isotonic' + Method used by the underlying regressor. 
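To make the calibration pipeline above concrete, here is a small self-contained sketch with toy numbers (not part of the diff): calibration_curve() bins raw model probabilities against observed labels, and CalibrationCore in its isotonic mode learns a monotone map from the binned predictions to the binned true proportions.

# assuming the definitions above, e.g.
# from adrd.model.calibration import calibration_curve, CalibrationCore
y_true = [0, 0, 0, 1, 0, 1, 1, 1, 0, 1]                                 # toy binary labels
y_pred = [0.10, 0.20, 0.35, 0.40, 0.55, 0.60, 0.70, 0.80, 0.85, 0.95]   # raw probabilities

# per-bin observed proportion of positives vs. mean predicted probability
prob_true, prob_pred = calibration_curve(y_true, y_pred, n_bins=5)

# fit the isotonic regressor on (predicted, true) pairs ...
core = CalibrationCore(method='isotonic').fit(prob_pred, prob_true)

# ... and use it to remap new raw probabilities to calibrated ones
print(core.predict([0.15, 0.50, 0.90]))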
+ """ + self.method = method + self.model = model + self.src_modalities = model.src_modalities + self.tgt_modalities = model.tgt_modalities + self.background_is_embedding = background_is_embedding + + # format background data + fmt_src = Formatter(self.src_modalities) + fmt_tgt = Formatter(self.tgt_modalities) + self.background_src = [fmt_src(smp) for smp in background_src] + self.background_tgt = [fmt_tgt(smp) for smp in background_tgt] + + @abstractmethod + def predict_proba(self, + src: list[dict[str, Any]], + is_embedding: dict[str, bool] | None = None, + ) -> list[dict[str, float]]: + """ + This method returns calibrated probabilities of classification. + + Parameters + ---------- + src : list[dict[str, Any]] + Features of the input samples. + + Returns + ------- + list[dict[str, float]] + Calibrated probabilities. + """ + pass + + def predict(self, + src: list[dict[str, Any]], + is_embedding: dict[str, bool] | None = None, + ) -> list[dict[str, int]]: + """ + Make predictions based on the results of predict_proba(). + + Parameters + ---------- + x : list[dict[str, Any]] + Input features. + + Returns + ------- + list[dict[str, int]] + Calibrated predictions. + """ + proba = self.predict_proba(src, is_embedding) + return [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba] + + def save(self, + filepath_state_dict: str, + ) -> None: + """ + Save the state dict and the underlying model to the given paths. + + Parameters + ---------- + filepath_state_dict : str + File path to save the state_dict which includes the background + dataset and the regressor information. + filepath_wrapped_model : str | None, default=None + File path to save the wrapped model. If None, the model won't be + saved. + """ + # save state dict + state_dict = { + 'background_src': self.background_src, + 'background_tgt': self.background_tgt, + 'background_is_embedding': self.background_is_embedding, + 'method': self.method, + } + with open(filepath_state_dict, 'wb') as f: + dump(state_dict, f) + + @classmethod + def from_ckpt(cls, + filepath_state_dict: str, + filepath_wrapped_model: str, + ) -> Self: + """ + Alternative constructor which loads from checkpoint. + + Parameters + ---------- + filepath_state_dict : str + File path to load the state_dict which includes the background + dataset and the regressor information. + filepath_wrapped_model : str + File path of the wrapped model. + + Returns + ------- + Self + CalibratedClassifier class object. + """ + with open(filepath_state_dict, 'rb') as f: + kwargs = load(f) + kwargs['model'] = ADRDModel.from_ckpt(filepath_wrapped_model) + return cls(**kwargs) + + +class DynamicCalibratedClassifier(CalibratedClassifier): + """ + The dynamic approach generates background predictions based on the + missingness pattern of each input. With an astronomical number of + missingness patterns, calibrating each sample requires a comprehensive + process that involves running the ADRDModel on the majority of the + background data and training a corresponding regressor. This results in a + computationally intensive calculation. 
+ """ + def predict_proba(self, + src: list[dict[str, Any]], + is_embedding: dict[str, bool] | None = None, + ) -> list[dict[str, float]]: + + # initialize mask generator and format inputs + msk_gen = MissingMasker(self.src_modalities) + fmt_src = Formatter(self.src_modalities) + src = [fmt_src(smp) for smp in src] + + # calculate calibrated probabilities + calibrated_prob: list[dict[str, float]] = [] + for smp in src: + # model output and missingness pattern + prob = self.model.predict_proba([smp], is_embedding)[0] + mask = tuple(msk_gen(smp).values()) + + # get/fit core and calculate calibrated probabilities + core = self._fit_core(mask) + calibrated_prob.append({k: core[k].predict([prob[k]])[0] for k in self.tgt_modalities}) + + return calibrated_prob + + # @lru_cache(maxsize = None) + def _fit_core(self, + missingness_pattern: tuple[bool], + ) -> dict[str, CalibrationCore]: + ''' ... ''' + # remove features from all background samples accordingly + background_src, background_tgt = [], [] + for src, tgt in zip(self.background_src, self.background_tgt): + src = {k: v for j, (k, v) in enumerate(src.items()) if missingness_pattern[j] == False} + + # make sure there is at least one feature available + if len([v is not None for v in src.values()]) == 0: continue + background_src.append(src) + background_tgt.append(tgt) + + # run model on background samples and collection predictions + background_prob = self.model.predict_proba(background_src, self.background_is_embedding, _batch_size=1024) + + # list[dict] -> dict[list] + N = len(background_src) + background_prob = {k: [background_prob[i][k] for i in range(N)] for k in self.tgt_modalities} + background_true = {k: [background_tgt[i][k] for i in range(N)] for k in self.tgt_modalities} + + # now, fit cores + core: dict[str, CalibrationCore] = dict() + for k in self.tgt_modalities: + prob_true, prob_pred = calibration_curve( + background_true[k], background_prob[k], + ratio = self.background_ratio[k], + ) + core[k] = CalibrationCore(self.method).fit(prob_pred, prob_true) + + return core + + @cached_property + def background_ratio(self) -> dict[str, float]: + ''' The ratio of positives over negatives in the background dataset. ''' + return {k: self.background_n_pos[k] / self.background_n_neg[k] for k in self.tgt_modalities} + + @cached_property + def background_n_pos(self) -> dict[str, int]: + ''' Number of positives w.r.t each target in the background dataset. ''' + return {k: sum([d[k] for d in self.background_tgt]) for k in self.tgt_modalities} + + @cached_property + def background_n_neg(self) -> dict[str, int]: + ''' Number of negatives w.r.t each target in the background dataset. ''' + return {k: len(self.background_tgt) - self.background_n_pos[k] for k in self.tgt_modalities} + + +class StaticCalibratedClassifier(CalibratedClassifier): + """ + The static approach generates background predictions without considering the + missingness patterns. 
+ """ + def predict_proba(self, + src: list[dict[str, Any]], + is_embedding: dict[str, bool] | None = None, + ) -> list[dict[str, float]]: + + # number of input samples + N = len(src) + + # format inputs, and run ADRDModel, and convert to dict[list] + fmt_src = Formatter(self.src_modalities) + src = [fmt_src(smp) for smp in src] + prob = self.model.predict_proba(src, is_embedding) + prob = {k: [prob[i][k] for i in range(N)] for k in self.tgt_modalities} + + # calibrate probabilities + core = self._fit_core() + calibrated_prob = {k: core[k].predict(prob[k]) for k in self.tgt_modalities} + + # convert back to list[dict] + calibrated_prob: list[dict[str, float]] = [ + {k: calibrated_prob[k][i] for k in self.tgt_modalities} for i in range(N) + ] + return calibrated_prob + + @lru_cache(maxsize = None) + def _fit_core(self) -> dict[str, CalibrationCore]: + ''' ... ''' + # run model on background samples and collection predictions + background_prob = self.model.predict_proba(self.background_src, self.background_is_embedding, _batch_size=1024) + + # list[dict] -> dict[list] + N = len(self.background_src) + background_prob = {k: [background_prob[i][k] for i in range(N)] for k in self.tgt_modalities} + background_true = {k: [self.background_tgt[i][k] for i in range(N)] for k in self.tgt_modalities} + + # now, fit cores + core: dict[str, CalibrationCore] = dict() + for k in self.tgt_modalities: + prob_true, prob_pred = calibration_curve( + background_true[k], background_prob[k], + ratio = 1.0, + ) + core[k] = CalibrationCore(self.method).fit(prob_pred, prob_true) + + return core \ No newline at end of file diff --git a/adrd/model/cnn_resnet3d_with_linear_classifier.py b/adrd/model/cnn_resnet3d_with_linear_classifier.py new file mode 100755 index 0000000000000000000000000000000000000000..51d5cc227b6de56d7e6d97e233ba2bfcc75b6a40 --- /dev/null +++ b/adrd/model/cnn_resnet3d_with_linear_classifier.py @@ -0,0 +1,533 @@ +__all__ = ['CNNResNet3DWithLinearClassifier'] + +import torch +from torch.utils.data import DataLoader +import numpy as np +import tqdm +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted +from sklearn.model_selection import train_test_split +from scipy.special import expit +from copy import deepcopy +from contextlib import suppress +from typing import Any, Self, Type +from functools import wraps +Tensor = Type[torch.Tensor] +Module = Type[torch.nn.Module] + +from ..utils.misc import ProgressBar +from ..utils.misc import get_metrics_multitask, print_metrics_multitask + +from .. import nn +from ..utils import TransformerTrainingDataset +from ..utils import Transformer2ndOrderBalancedTrainingDataset +from ..utils import TransformerValidationDataset +from ..utils import TransformerTestingDataset +from ..utils.misc import ProgressBar +from ..utils.misc import get_metrics_multitask, print_metrics_multitask +from ..utils.misc import convert_args_kwargs_to_kwargs + + +def _manage_ctx_fit(func): + ''' ... 
''' + @wraps(func) + def wrapper(*args, **kwargs): + # format arguments + kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs) + + if kwargs['self']._device_ids is None: + return func(**kwargs) + else: + # change primary device + default_device = kwargs['self'].device + kwargs['self'].device = kwargs['self']._device_ids[0] + rtn = func(**kwargs) + + # the actual module is wrapped + kwargs['self'].net_ = kwargs['self'].net_.module + kwargs['self'].to(default_device) + return rtn + return wrapper + + +class CNNResNet3DWithLinearClassifier(BaseEstimator): + + def __init__(self, + src_modalities: dict[str, dict[str, Any]], + tgt_modalities: dict[str, dict[str, Any]], + num_epochs: int = 32, + batch_size: int = 8, + batch_size_multiplier: int = 1, + lr: float = 1e-2, + weight_decay: float = 0.0, + beta: float = 0.9999, + gamma: float = 2.0, + scale: float = 1.0, + criterion: str | None = None, + device: str = 'cpu', + verbose: int = 0, + _device_ids: list | None = None, + _dataloader_num_workers: int = 0, + _amp_enabled: bool = False, + _tmp_ckpt_filepath: str | None = None, + ) -> None: + ''' ... ''' + # for multiprocessing + self._rank = 0 + self._lock = None + + # positional parameters + self.src_modalities = src_modalities + self.tgt_modalities = tgt_modalities + + # training parameters + self.num_epochs = num_epochs + self.batch_size = batch_size + self.batch_size_multiplier = batch_size_multiplier + self.lr = lr + self.weight_decay = weight_decay + self.beta = beta + self.gamma = gamma + self.scale = scale + self.criterion = criterion + self.device = device + self.verbose = verbose + self._device_ids = _device_ids + self._dataloader_num_workers = _dataloader_num_workers + self._amp_enabled = _amp_enabled + self._tmp_ckpt_filepath = _tmp_ckpt_filepath + + + @_manage_ctx_fit + def fit(self, x, y) -> Self: + ''' ... 
''' + # for PyTorch computational efficiency + torch.set_num_threads(1) + + # initialize neural network + self.net_ = self._init_net() + + # initialize dataloaders + ldr_trn, ldr_vld = self._init_dataloader(x, y) + + # initialize optimizer and scheduler + optimizer = self._init_optimizer() + scheduler = self._init_scheduler(optimizer) + + # gradient scaler for AMP + if self._amp_enabled: scaler = torch.cuda.amp.GradScaler() + + # initialize loss function (binary cross entropy) + loss_func = self._init_loss_func({ + k: ( + sum([_[k] == 0 for _ in ldr_trn.dataset.tgt]), + sum([_[k] == 1 for _ in ldr_trn.dataset.tgt]), + ) for k in self.tgt_modalities + }) + + # to record the best validation performance criterion + if self.criterion is not None: best_crit = None + + # progress bar for epoch loops + if self.verbose == 1: + with self._lock if self._lock is not None else suppress(): + pbr_epoch = tqdm.tqdm( + desc = 'Rank {:02d}'.format(self._rank), + total = self.num_epochs, + position = self._rank, + ascii = True, + leave = False, + bar_format='{l_bar}{r_bar}' + ) + + # training loop + for epoch in range(self.num_epochs): + # progress bar for batch loops + if self.verbose > 1: + pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch)) + + # set model to train mode + torch.set_grad_enabled(True) + self.net_.train() + + scores_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities} + y_true_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities} + losses_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities} + for n_iter, (x_batch, y_batch, _, mask_y) in enumerate(ldr_trn): + # mount data to the proper device + x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities} + y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities} + # mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities} + mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities} + + # forward + with torch.autocast( + device_type = 'cpu' if self.device == 'cpu' else 'cuda', + dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16, + enabled = self._amp_enabled, + ): + outputs = self.net_(x_batch) + + # calculate multitask loss + loss = 0 + for i, tgt_k in enumerate(self.tgt_modalities): + loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k]) + loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze())) + loss += loss_k.mean() + losses_trn[tgt_k] += loss_k.detach().cpu().numpy().tolist() + + # backward + if self._amp_enabled: + scaler.scale(loss).backward() + else: + loss.backward() + + # update parameters + if n_iter != 0 and n_iter % self.batch_size_multiplier == 0: + if self._amp_enabled: + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + else: + optimizer.step() + optimizer.zero_grad() + + # save outputs to evaluate performance later + for tgt_k in self.tgt_modalities: + tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze())) + scores_trn[tgt_k] += tmp.detach().cpu().numpy().tolist() + tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze())) + y_true_trn[tgt_k] += tmp.cpu().numpy().tolist() + + # update progress bar + if self.verbose > 1: + batch_size = len(next(iter(x_batch.values()))) + pbr_batch.update(batch_size, {}) + pbr_batch.refresh() + + # for better tqdm progress bar display + if self.verbose > 1: + pbr_batch.close() + + # set scheduler + scheduler.step() + + # calculate and print 
training performance metrics + y_pred_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities} + y_prob_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities} + for tgt_k in self.tgt_modalities: + for i in range(len(scores_trn[tgt_k])): + y_pred_trn[tgt_k].append(1 if scores_trn[tgt_k][i] > 0 else 0) + y_prob_trn[tgt_k].append(expit(scores_trn[tgt_k][i])) + met_trn = get_metrics_multitask(y_true_trn, y_pred_trn, y_prob_trn) + + # add loss to metrics + for tgt_k in self.tgt_modalities: + met_trn[tgt_k]['Loss'] = np.mean(losses_trn[tgt_k]) + + if self.verbose > 2: + print_metrics_multitask(met_trn) + + # progress bar for validation + if self.verbose > 1: + pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch)) + + # set model to validation mode + torch.set_grad_enabled(False) + self.net_.eval() + + scores_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities} + y_true_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities} + losses_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities} + for x_batch, y_batch, _, mask_y in ldr_vld: + # mount data to the proper device + x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities} + y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities} + # mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities} + mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities} + + # forward + with torch.autocast( + device_type = 'cpu' if self.device == 'cpu' else 'cuda', + dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16, + enabled = self._amp_enabled + ): + outputs = self.net_(x_batch) + + # calculate multitask loss + for i, tgt_k in enumerate(self.tgt_modalities): + loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k]) + loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze())) + losses_vld[tgt_k] += loss_k.detach().cpu().numpy().tolist() + + # save outputs to evaluate performance later + for tgt_k in self.tgt_modalities: + tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze())) + scores_vld[tgt_k] += tmp.detach().cpu().numpy().tolist() + tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze())) + y_true_vld[tgt_k] += tmp.cpu().numpy().tolist() + + # update progress bar + if self.verbose > 1: + batch_size = len(next(iter(x_batch.values()))) + pbr_batch.update(batch_size, {}) + pbr_batch.refresh() + + # for better tqdm progress bar display + if self.verbose > 1: + pbr_batch.close() + + # calculate and print validation performance metrics + y_pred_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities} + y_prob_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities} + for tgt_k in self.tgt_modalities: + for i in range(len(scores_vld[tgt_k])): + y_pred_vld[tgt_k].append(1 if scores_vld[tgt_k][i] > 0 else 0) + y_prob_vld[tgt_k].append(expit(scores_vld[tgt_k][i])) + met_vld = get_metrics_multitask(y_true_vld, y_pred_vld, y_prob_vld) + + # add loss to metrics + for tgt_k in self.tgt_modalities: + met_vld[tgt_k]['Loss'] = np.mean(losses_vld[tgt_k]) + + if self.verbose > 2: + print_metrics_multitask(met_vld) + + # save the model if it has the best validation performance criterion by far + if self.criterion is None: continue + + # is current criterion better than previous best? 
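The comparison implemented just below can be read as a small predicate. This sketch mirrors the branching that follows (names are illustrative, not part of the package):

import numpy as np

def is_better_sketch(criterion, best, curr):
    # First epoch (or a NaN best) always wins; afterwards 'Loss' is minimized
    # while every other criterion (e.g. an AUC) is maximized.
    # Ties favor the newer checkpoint.
    if best is None or np.isnan(best):
        return True
    if criterion == 'Loss':
        return best >= curr
    return best <= curr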
+ curr_crit = np.mean([met_vld[k][self.criterion] for k in self.tgt_modalities]) + if best_crit is None or np.isnan(best_crit): + is_better = True + elif self.criterion == 'Loss' and best_crit >= curr_crit: + is_better = True + elif self.criterion != 'Loss' and best_crit <= curr_crit: + is_better = True + else: + is_better = False + + # update best criterion + if is_better: + best_crit = curr_crit + best_state_dict = deepcopy(self.net_.state_dict()) + + if self._tmp_ckpt_filepath is not None: + self.save(self._tmp_ckpt_filepath) + + if self.verbose > 2: + print('Best {}: {}'.format(self.criterion, best_crit)) + + if self.verbose == 1: + with self._lock if self._lock is not None else suppress(): + pbr_epoch.update(1) + pbr_epoch.refresh() + + if self.verbose == 1: + with self._lock if self._lock is not None else suppress(): + pbr_epoch.close() + + # restore the model of the best validation performance across all epoches + if ldr_vld is not None and self.criterion is not None: + self.net_.load_state_dict(best_state_dict) + + return self + + def predict_logits(self, + x: list[dict[str, Any]], + _batch_size: int | None = None, + ) -> list[dict[str, float]]: + """ + The input x can be a single sample or a list of samples. + """ + # input validation + check_is_fitted(self) + + # for PyTorch computational efficiency + torch.set_num_threads(1) + + # set model to eval mode + torch.set_grad_enabled(False) + self.net_.eval() + + # intialize dataset and dataloader object + dat = TransformerTestingDataset(x, self.src_modalities) + ldr = DataLoader( + dataset = dat, + batch_size = _batch_size if _batch_size is not None else len(x), + shuffle = False, + drop_last = False, + num_workers = 0, + collate_fn = TransformerTestingDataset.collate_fn, + ) + + # run model and collect results + logits: list[dict[str, float]] = [] + for x_batch, _ in ldr: + # mount data to the proper device + x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities} + + # forward + output: dict[str, Tensor] = self.net_(x_batch) + + # convert output from dict-of-list to list of dict, then append + tmp = {k: output[k].tolist() for k in self.tgt_modalities} + tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))] + logits += tmp + + return logits + + def predict_proba(self, + x: list[dict[str, Any]], + temperature: float = 1.0, + _batch_size: int | None = None, + ) -> list[dict[str, float]]: + ''' ... ''' + logits = self.predict_logits(x, _batch_size) + return [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits] + + def predict(self, + x: list[dict[str, Any]], + _batch_size: int | None = None, + ) -> list[dict[str, int]]: + ''' ... ''' + logits = self.predict_logits(x, _batch_size) + return [{k: int(smp[k] > 0.0) for k in self.tgt_modalities} for smp in logits] + + def save(self, filepath: str) -> None: + ''' ... ''' + check_is_fitted(self) + state_dict = self.net_.state_dict() + + # attach model hyper parameters + state_dict['src_modalities'] = self.src_modalities + state_dict['tgt_modalities'] = self.tgt_modalities + print('Saving model checkpoint to {} ... '.format(filepath), end='') + torch.save(state_dict, filepath) + print('Done.') + + def load(self, filepath: str) -> None: + ''' ... 
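`predict_logits` above collects one output list per target and then transposes that dict-of-lists into a list of per-sample dicts. The reshaping on its own looks like this sketch (target names and values are illustrative):

# per-target model outputs for a batch of 2 samples
output = {'AD': [0.3, -1.2], 'PD': [2.0, 0.1]}

# dict-of-lists -> list-of-dicts, one dict per sample
n_samples = len(next(iter(output.values())))
per_sample = [{k: output[k][i] for k in output} for i in range(n_samples)]
print(per_sample)   # [{'AD': 0.3, 'PD': 2.0}, {'AD': -1.2, 'PD': 0.1}]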
''' + # load state_dict + state_dict = torch.load(filepath, map_location='cpu') + + # load essential parameters + self.src_modalities: dict[str, dict[str, Any]] = state_dict.pop('src_modalities') + self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities') + + # initialize model + self.net_ = nn.CNNResNet3DWithLinearClassifier( + self.src_modalities, + self.tgt_modalities, + ) + + # load model parameters + self.net_.load_state_dict(state_dict) + self.to(self.device) + + def to(self, device: str) -> Self: + ''' Mount model to the given device. ''' + self.device = device + if hasattr(self, 'net_'): self.net_ = self.net_.to(device) + return self + + @classmethod + def from_ckpt(cls, filepath: str) -> Self: + ''' ... ''' + obj = cls(None, None) + obj.load(filepath) + return obj + + def _init_net(self): + """ ... """ + net = nn.CNNResNet3DWithLinearClassifier( + self.src_modalities, + self.tgt_modalities, + ).to(self.device) + + # train on multiple GPUs using torch.nn.DataParallel + if self._device_ids is not None: + net = torch.nn.DataParallel(net, device_ids=self._device_ids) + + # intialize model parameters using xavier_uniform + for p in net.parameters(): + if p.dim() > 1: + torch.nn.init.xavier_uniform_(p) + + return net + + def _init_dataloader(self, x, y): + """ ... """ + # split dataset + x_trn, x_vld, y_trn, y_vld = train_test_split( + x, y, test_size = 0.2, random_state = 0, + ) + + # initialize dataset and dataloader + # dat_trn = TransformerTrainingDataset( + dat_trn = Transformer2ndOrderBalancedTrainingDataset( + x_trn, y_trn, + self.src_modalities, + self.tgt_modalities, + dropout_rate = .5, + # dropout_strategy = 'compensated', + dropout_strategy = 'permutation', + ) + + dat_vld = TransformerValidationDataset( + x_vld, y_vld, + self.src_modalities, + self.tgt_modalities, + ) + + ldr_trn = DataLoader( + dataset = dat_trn, + batch_size = self.batch_size, + shuffle = True, + drop_last = False, + num_workers = self._dataloader_num_workers, + collate_fn = TransformerTrainingDataset.collate_fn, + # pin_memory = True + ) + + ldr_vld = DataLoader( + dataset = dat_vld, + batch_size = self.batch_size, + shuffle = False, + drop_last = False, + num_workers = self._dataloader_num_workers, + collate_fn = TransformerValidationDataset.collate_fn, + # pin_memory = True + ) + + return ldr_trn, ldr_vld + + def _init_optimizer(self): + """ ... """ + return torch.optim.AdamW( + self.net_.parameters(), + lr = self.lr, + betas = (0.9, 0.98), + weight_decay = self.weight_decay + ) + + def _init_scheduler(self, optimizer): + """ ... """ + return torch.optim.lr_scheduler.OneCycleLR( + optimizer = optimizer, + max_lr = self.lr, + total_steps = self.num_epochs, + verbose = (self.verbose > 2) + ) + + def _init_loss_func(self, + num_per_cls: dict[str, tuple[int, int]], + ) -> dict[str, Module]: + """ ... 
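`_init_loss_func` (its body follows) hands per-target `(n_negative, n_positive)` counts together with `beta` to `nn.SigmoidFocalLoss`, which is the usual interface for class-balanced weighting by the effective number of samples. The sketch below shows that weighting in isolation; it is an assumption about the loss module, not a copy of its implementation:

import numpy as np

def class_balanced_weights_sketch(num_per_cls, beta=0.9999):
    # Effective number of samples per class (Cui et al., 2019):
    #   E_n = (1 - beta^n) / (1 - beta), weight ~ 1 / E_n
    n = np.asarray(num_per_cls, dtype=float)
    effective_num = (1.0 - np.power(beta, n)) / (1.0 - beta)
    weights = 1.0 / effective_num
    return weights / weights.sum() * len(n)   # rescaled so the weights average to 1

# e.g. 900 negatives vs. 100 positives: the rarer positive class gets the larger weight
print(class_balanced_weights_sketch((900, 100)))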
""" + return {k: nn.SigmoidFocalLoss( + beta = self.beta, + gamma = self.gamma, + scale = self.scale, + num_per_cls = num_per_cls[k], + reduction = 'none', + ) for k in self.tgt_modalities} \ No newline at end of file diff --git a/adrd/model/imaging_model.py b/adrd/model/imaging_model.py new file mode 100755 index 0000000000000000000000000000000000000000..238896ec079ce23c21ac4cb18cb5916bd28fc828 --- /dev/null +++ b/adrd/model/imaging_model.py @@ -0,0 +1,843 @@ +__all__ = ['Transformer'] + +import wandb +import torch +import numpy as np +import functools +import inspect +import monai +import random + +from tqdm import tqdm +from functools import wraps +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted +from sklearn.model_selection import train_test_split +from scipy.special import expit +from copy import deepcopy +from contextlib import suppress +from typing import Any, Self, Type +Tensor = Type[torch.Tensor] +Module = Type[torch.nn.Module] +from torch.utils.data import DataLoader +from monai.utils.type_conversion import convert_to_tensor +from monai.transforms import ( + LoadImaged, + Compose, + CropForegroundd, + CopyItemsd, + SpatialPadd, + EnsureChannelFirstd, + Spacingd, + OneOf, + ScaleIntensityRanged, + HistogramNormalized, + RandSpatialCropSamplesd, + RandSpatialCropd, + CenterSpatialCropd, + RandCoarseDropoutd, + RandCoarseShuffled, + Resized, +) + +# for DistributedDataParallel +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP + +from .. import nn +from ..utils.misc import ProgressBar +from ..utils.misc import get_metrics_multitask, print_metrics_multitask +from ..utils.misc import convert_args_kwargs_to_kwargs + +import warnings +warnings.filterwarnings("ignore") + + +def _manage_ctx_fit(func): + ''' ... 
''' + @wraps(func) + def wrapper(*args, **kwargs): + # format arguments + kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs) + + if kwargs['self']._device_ids is None: + return func(**kwargs) + else: + # change primary device + default_device = kwargs['self'].device + kwargs['self'].device = kwargs['self']._device_ids[0] + rtn = func(**kwargs) + kwargs['self'].to(default_device) + return rtn + return wrapper + +def collate_handle_corrupted(samples_list, dataset, labels, dtype=torch.half): + # print(len(samples_list)) + orig_len = len(samples_list) + # for the loss to be consistent, we drop samples with NaN values in any of their corresponding crops + for i, s in enumerate(samples_list): + ic(s is None) + if s is None: + continue + samples_list = list(filter(lambda x: x is not None, samples_list)) + + if len(samples_list) == 0: + ic('recursive call') + return collate_handle_corrupted([dataset[random.randint(0, len(dataset)-1)] for _ in range(orig_len)], dataset, labels) + + # collated_images = torch.stack([convert_to_tensor(s["image"]) for s in samples_list]) + try: + if "image" in samples_list[0]: + samples_list = [s for s in samples_list if not torch.isnan(s["image"]).any()] + # print('samples list: ', len(samples_list)) + collated_images = torch.stack([convert_to_tensor(s["image"]) for s in samples_list]) + # print("here1") + collated_labels = {k: torch.Tensor([s["label"][k] if s["label"][k] is not None else 0 for s in samples_list]) for k in labels} + # print("here2") + collated_mask = {k: torch.Tensor([1 if s["label"][k] is not None else 0 for s in samples_list]) for k in labels} + # print("here3") + return {"image": collated_images, + "label": collated_labels, + "mask": collated_mask} + except: + return collate_handle_corrupted([dataset[random.randint(0, len(dataset)-1)] for _ in range(orig_len)], dataset, labels) + + + +def get_backend(img_backend): + if img_backend == 'C3D': + return nn.C3D + elif img_backend == 'DenseNet': + return nn.DenseNet + + +class ImagingModel(BaseEstimator): + ''' ... ''' + def __init__(self, + tgt_modalities: list[str], + label_fractions: dict[str, float], + num_epochs: int = 32, + batch_size: int = 8, + batch_size_multiplier: int = 1, + lr: float = 1e-2, + weight_decay: float = 0.0, + beta: float = 0.9999, + gamma: float = 2.0, + bn_size: int = 4, + growth_rate: int = 12, + block_config: tuple = (3, 3, 3), + compression: float = 0.5, + num_init_features: int = 16, + drop_rate: float = 0.2, + criterion: str | None = None, + device: str = 'cpu', + cuda_devices: list = [1], + ckpt_path: str = '/home/skowshik/ADRD_repo/adrd_tool/dev/ckpt/ckpt.pt', + load_from_ckpt: bool = True, + save_intermediate_ckpts: bool = False, + data_parallel: bool = False, + verbose: int = 0, + img_backend: str | None = None, + label_distribution: dict = {}, + wandb_ = 1, + _device_ids: list | None = None, + _dataloader_num_workers: int = 4, + _amp_enabled: bool = False, + ) -> None: + ''' ... 
''' + # for multiprocessing + self._rank = 0 + self._lock = None + + # positional parameters + self.tgt_modalities = tgt_modalities + + # training parameters + self.label_fractions = label_fractions + self.num_epochs = num_epochs + self.batch_size = batch_size + self.batch_size_multiplier = batch_size_multiplier + self.lr = lr + self.weight_decay = weight_decay + self.beta = beta + self.gamma = gamma + self.bn_size = bn_size + self.growth_rate = growth_rate + self.block_config = block_config + self.compression = compression + self.num_init_features = num_init_features + self.drop_rate = drop_rate + self.criterion = criterion + self.device = device + self.cuda_devices = cuda_devices + self.ckpt_path = ckpt_path + self.load_from_ckpt = load_from_ckpt + self.save_intermediate_ckpts = save_intermediate_ckpts + self.data_parallel = data_parallel + self.verbose = verbose + self.img_backend = img_backend + self.label_distribution = label_distribution + self.wandb_ = wandb_ + self._device_ids = _device_ids + self._dataloader_num_workers = _dataloader_num_workers + self._amp_enabled = _amp_enabled + self.scaler = torch.cuda.amp.GradScaler() + + @_manage_ctx_fit + def fit(self, trn_list, vld_list, img_train_trans=None, img_vld_trans=None) -> Self: + # def fit(self, x, y) -> Self: + ''' ... ''' + + # start a new wandb run to track this script + if self.wandb_ == 1: + wandb.init( + # set the wandb project where this run will be logged + project="ADRD_main", + + # track hyperparameters and run metadata + config={ + "Model": "DenseNet", + "Loss": 'Focalloss', + "EMB": "ALL_EMB", + "epochs": 256, + } + ) + wandb.run.log_code("/home/skowshik/ADRD_repo/pipeline_v1_main/adrd_tool") + else: + wandb.init(mode="disabled") + # for PyTorch computational efficiency + torch.set_num_threads(1) + print(self.criterion) + + # initialize neural network + self._init_net() + + # for k, info in self.src_modalities.items(): + # if info['type'] == 'imaging' and self.img_net != 'EMB': + # info['shape'] = (1,) + (self.img_size,) * 3 + # info['img_shape'] = (1,) + (self.img_size,) * 3 + # print(info['shape']) + + # initialize dataloaders + # ldr_trn, ldr_vld = self._init_dataloader(x, y) + # ldr_trn, ldr_vld = self._init_dataloader(x_trn, x_vld, y_trn, y_vld) + ldr_trn, ldr_vld = self._init_dataloader(trn_list, vld_list, img_train_trans=img_train_trans, img_vld_trans=img_vld_trans) + + # initialize optimizer and scheduler + if not self.load_from_ckpt: + self.optimizer = self._init_optimizer() + self.scheduler = self._init_scheduler(self.optimizer) + + # gradient scaler for AMP + if self._amp_enabled: + self.scaler = torch.cuda.amp.GradScaler() + + # initialize focal loss function + self.loss_fn = {} + + for k in self.tgt_modalities: + if self.label_fractions[k] >= 0.3: + alpha = -1 + else: + alpha = pow((1 - self.label_fractions[k]), 2) + # alpha = -1 + self.loss_fn[k] = nn.SigmoidFocalLoss( + alpha = alpha, + gamma = self.gamma, + reduction = 'none' + ) + + # to record the best validation performance criterion + if self.criterion is not None: + best_crit = None + best_crit_AUPR = None + + # progress bar for epoch loops + if self.verbose == 1: + with self._lock if self._lock is not None else suppress(): + pbr_epoch = tqdm( + desc = 'Rank {:02d}'.format(self._rank), + total = self.num_epochs, + position = self._rank, + ascii = True, + leave = False, + bar_format='{l_bar}{r_bar}' + ) + + # Define a hook function to print and store the gradient of a layer + def print_and_store_grad(grad, grad_list): + grad_list.append(grad) + # 
print(grad) + + # grad_list = [] + # self.net_.modules_emb_src['img_MRI_T1'].downsample[0].weight.register_hook(lambda grad: print_and_store_grad(grad, grad_list)) + + # lambda_coeff = 0.0001 + # margin_loss = torch.nn.MarginRankingLoss(reduction='sum', margin=0.05) + + # training loop + for epoch in range(self.start_epoch, self.num_epochs): + met_trn = self.train_one_epoch(ldr_trn, epoch) + met_vld = self.validate_one_epoch(ldr_vld, epoch) + + print(self.ckpt_path.split('/')[-1]) + + # save the model if it has the best validation performance criterion by far + if self.criterion is None: continue + + + # is current criterion better than previous best? + curr_crit = np.mean([met_vld[i][self.criterion] for i in range(len(self.tgt_modalities))]) + curr_crit_AUPR = np.mean([met_vld[i]["AUC (PR)"] for i in range(len(self.tgt_modalities))]) + # AUROC + if best_crit is None or np.isnan(best_crit): + is_better = True + elif self.criterion == 'Loss' and best_crit >= curr_crit: + is_better = True + elif self.criterion != 'Loss' and best_crit <= curr_crit : + is_better = True + else: + is_better = False + + # AUPR + if best_crit_AUPR is None or np.isnan(best_crit_AUPR): + is_better_AUPR = True + elif best_crit_AUPR <= curr_crit_AUPR : + is_better_AUPR = True + else: + is_better_AUPR = False + + # update best criterion + if is_better_AUPR: + best_crit_AUPR = curr_crit_AUPR + if self.save_intermediate_ckpts: + print(f"Saving the model to {self.ckpt_path[:-3]}_AUPR.pt...") + self.save(self.ckpt_path[:-3]+"_AUPR.pt", epoch) + + if is_better: + best_crit = curr_crit + best_state_dict = deepcopy(self.net_.state_dict()) + if self.save_intermediate_ckpts: + print(f"Saving the model to {self.ckpt_path}...") + self.save(self.ckpt_path, epoch) + + if self.verbose > 2: + print('Best {}: {}'.format(self.criterion, best_crit)) + print('Best {}: {}'.format('AUC (PR)', best_crit_AUPR)) + + if self.verbose == 1: + with self._lock if self._lock is not None else suppress(): + pbr_epoch.update(1) + pbr_epoch.refresh() + + return self + + def train_one_epoch(self, ldr_trn, epoch): + + # progress bar for batch loops + if self.verbose > 1: + pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch)) + + torch.set_grad_enabled(True) + self.net_.train() + + scores_trn, y_true_trn, y_mask_trn = [], [], [] + losses_trn = [[] for _ in self.tgt_modalities] + iters = len(ldr_trn) + print(iters) + for n_iter, batch_data in enumerate(ldr_trn): + # if len(batch_data["image"]) < self.batch_size: + # continue + + x_batch = batch_data["image"].to(self.device, non_blocking=True) + y_batch = {k: v.to(self.device, non_blocking=True) for k,v in batch_data["label"].items()} + y_mask = {k: v.to(self.device, non_blocking=True) for k,v in batch_data["mask"].items()} + + with torch.autocast( + device_type = 'cpu' if self.device == 'cpu' else 'cuda', + dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16, + enabled = self._amp_enabled, + ): + + outputs = self.net_(x_batch, shap=False) + # print(outputs.shape) + # calculate multitask loss + loss = 0 + for i, k in enumerate(self.tgt_modalities): + loss_task = self.loss_fn[k](outputs[k], y_batch[k]) + msk_loss_task = loss_task * y_mask[k] + msk_loss_mean = msk_loss_task.sum() / y_mask[k].sum() + loss += msk_loss_mean + losses_trn[i] += msk_loss_task.detach().cpu().numpy().tolist() + + # backward + if self._amp_enabled: + self.scaler.scale(loss).backward() + else: + loss.backward() + + # print(len(grad_list), len(grad_list[-1])) + # print(f"Gradient at {n_iter}: 
{grad_list[-1][0]}") + + # update parameters + if n_iter != 0 and n_iter % self.batch_size_multiplier == 0: + if self._amp_enabled: + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.zero_grad() + else: + self.optimizer.step() + self.optimizer.zero_grad() + # set self.scheduler + self.scheduler.step(epoch + n_iter / iters) + # print(f"Weight: {self.net_.module.features[0].weight[0]}") + + ''' TODO: change array to dictionary later ''' + outputs = torch.stack(list(outputs.values()), dim=1) + y_batch = torch.stack(list(y_batch.values()), dim=1) + y_mask = torch.stack(list(y_mask.values()), dim=1) + + # save outputs to evaluate performance later + scores_trn.append(outputs.detach().to(torch.float).cpu()) + y_true_trn.append(y_batch.cpu()) + y_mask_trn.append(y_mask.cpu()) + + # log metrics to wandb + + # update progress bar + if self.verbose > 1: + batch_size = len(x_batch) + pbr_batch.update(batch_size, {}) + pbr_batch.refresh() + + # clear cuda cache + if "cuda" in self.device: + torch.cuda.empty_cache() + + # for better tqdm progress bar display + if self.verbose > 1: + pbr_batch.close() + + # # set self.scheduler + # self.scheduler.step() + + # calculate and print training performance metrics + scores_trn = torch.cat(scores_trn) + y_true_trn = torch.cat(y_true_trn) + y_mask_trn = torch.cat(y_mask_trn) + y_pred_trn = (scores_trn > 0).to(torch.int) + y_prob_trn = torch.sigmoid(scores_trn) + met_trn = get_metrics_multitask( + y_true_trn.numpy(), + y_pred_trn.numpy(), + y_prob_trn.numpy(), + y_mask_trn.numpy() + ) + + # add loss to metrics + for i in range(len(self.tgt_modalities)): + met_trn[i]['Loss'] = np.mean(losses_trn[i]) + + wandb.log({f"Train loss {list(self.tgt_modalities)[i]}": met_trn[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch) + wandb.log({f"Train Balanced Accuracy {list(self.tgt_modalities)[i]}": met_trn[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch) + + wandb.log({f"Train AUC (ROC) {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch) + wandb.log({f"Train AUPR {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch) + + if self.verbose > 2: + print_metrics_multitask(met_trn) + + return met_trn + + # @torch.no_grad() + def validate_one_epoch(self, ldr_vld, epoch): + # progress bar for validation + if self.verbose > 1: + pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch)) + + # set model to validation mode + torch.set_grad_enabled(False) + self.net_.eval() + + scores_vld, y_true_vld, y_mask_vld = [], [], [] + losses_vld = [[] for _ in self.tgt_modalities] + for batch_data in ldr_vld: + # if len(batch_data["image"]) < self.batch_size: + # continue + x_batch = batch_data["image"].to(self.device, non_blocking=True) + y_batch = {k: v.to(self.device, non_blocking=True) for k,v in batch_data["label"].items()} + y_mask = {k: v.to(self.device, non_blocking=True) for k,v in batch_data["mask"].items()} + + # forward + with torch.autocast( + device_type = 'cpu' if self.device == 'cpu' else 'cuda', + dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16, + enabled = self._amp_enabled + ): + + outputs = self.net_(x_batch, shap=False) + + # calculate multitask loss + for i, k in enumerate(self.tgt_modalities): + loss_task = self.loss_fn[k](outputs[k], y_batch[k]) + msk_loss_task = loss_task * y_mask[k] + losses_vld[i] += msk_loss_task.detach().cpu().numpy().tolist() 
+ + ''' TODO: change array to dictionary later ''' + outputs = torch.stack(list(outputs.values()), dim=1) + y_batch = torch.stack(list(y_batch.values()), dim=1) + y_mask = torch.stack(list(y_mask.values()), dim=1) + + # save outputs to evaluate performance later + scores_vld.append(outputs.detach().to(torch.float).cpu()) + y_true_vld.append(y_batch.cpu()) + y_mask_vld.append(y_mask.cpu()) + + # update progress bar + if self.verbose > 1: + batch_size = len(x_batch) + pbr_batch.update(batch_size, {}) + pbr_batch.refresh() + + # clear cuda cache + if "cuda" in self.device: + torch.cuda.empty_cache() + + # for better tqdm progress bar display + if self.verbose > 1: + pbr_batch.close() + + # calculate and print validation performance metrics + scores_vld = torch.cat(scores_vld) + y_true_vld = torch.cat(y_true_vld) + y_mask_vld = torch.cat(y_mask_vld) + y_pred_vld = (scores_vld > 0).to(torch.int) + y_prob_vld = torch.sigmoid(scores_vld) + met_vld = get_metrics_multitask( + y_true_vld.numpy(), + y_pred_vld.numpy(), + y_prob_vld.numpy(), + y_mask_vld.numpy() + ) + + # add loss to metrics + for i in range(len(self.tgt_modalities)): + met_vld[i]['Loss'] = np.mean(losses_vld[i]) + + wandb.log({f"Validation loss {list(self.tgt_modalities)[i]}": met_vld[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch) + wandb.log({f"Validation Balanced Accuracy {list(self.tgt_modalities)[i]}": met_vld[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch) + + wandb.log({f"Validation AUC (ROC) {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch) + wandb.log({f"Validation AUPR {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch) + + if self.verbose > 2: + print_metrics_multitask(met_vld) + + return met_vld + + + def save(self, filepath: str, epoch: int = 0) -> None: + ''' ... ''' + check_is_fitted(self) + if self.data_parallel: + state_dict = self.net_.module.state_dict() + else: + state_dict = self.net_.state_dict() + + # attach model hyper parameters + state_dict['tgt_modalities'] = self.tgt_modalities + state_dict['optimizer'] = self.optimizer + state_dict['bn_size'] = self.bn_size + state_dict['growth_rate'] = self.growth_rate + state_dict['block_config'] = self.block_config + state_dict['compression'] = self.compression + state_dict['num_init_features'] = self.num_init_features + state_dict['drop_rate'] = self.drop_rate + state_dict['epoch'] = epoch + + if self.scaler is not None: + state_dict['scaler'] = self.scaler.state_dict() + if self.label_distribution: + state_dict['label_distribution'] = self.label_distribution + + torch.save(state_dict, filepath) + + def load(self, filepath: str, map_location: str = 'cpu', how='latest') -> None: + ''' ... 
''' + # load state_dict + if how == 'latest': + if torch.load(filepath)['epoch'] > torch.load(f'{filepath[:-3]}_AUPR.pt')['epoch']: + print("Loading model saved using AUROC") + state_dict = torch.load(filepath, map_location=map_location) + else: + print("Loading model saved using AUPR") + state_dict = torch.load(f'{filepath[:-3]}_AUPR.pt', map_location=map_location) + else: + state_dict = torch.load(filepath, map_location=map_location) + + # load data modalities + self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities') + if 'label_distribution' in state_dict: + self.label_distribution: dict[str, dict[int, int]] = state_dict.pop('label_distribution') + if 'optimizer' in state_dict: + self.optimizer = state_dict.pop('optimizer') + if 'bn_size' in state_dict: + self.bn_size = state_dict.pop('bn_size') + if 'growth_rate' in state_dict: + self.growth_rate = state_dict.pop('growth_rate') + if 'block_config' in state_dict: + self.block_config = state_dict.pop('block_config') + if 'compression' in state_dict: + self.compression = state_dict.pop('compression') + if 'num_init_features' in state_dict: + self.num_init_features = state_dict.pop('num_init_features') + if 'drop_rate' in state_dict: + self.drop_rate = state_dict.pop('drop_rate') + if 'epoch' in state_dict: + self.start_epoch = state_dict.pop('epoch') + print(f'Epoch: {self.start_epoch}') + + # initialize model + + self.net_ = get_backend(self.img_backend)( + tgt_modalities = self.tgt_modalities, + bn_size = self.bn_size, + growth_rate=self.growth_rate, + block_config=self.block_config, + compression=self.compression, + num_init_features=self.num_init_features, + drop_rate=self.drop_rate, + load_from_ckpt=self.load_from_ckpt + ) + print(self.net_) + + if 'scaler' in state_dict and state_dict['scaler']: + self.scaler.load_state_dict(state_dict.pop('scaler')) + self.net_.load_state_dict(state_dict) + check_is_fitted(self) + self.net_.to(self.device) + + def to(self, device: str) -> Self: + ''' Mount model to the given device. ''' + self.device = device + if hasattr(self, 'model'): self.net_ = self.net_.to(device) + return self + + @classmethod + def from_ckpt(cls, filepath: str, device='cpu', img_backend=None, load_from_ckpt=True, how='latest') -> Self: + ''' ... ''' + obj = cls(None, None, None,device=device) + if device == 'cuda': + obj.device = "{}:{}".format(obj.device, str(obj.cuda_devices[0])) + print(obj.device) + obj.img_backend=img_backend + obj.load_from_ckpt = load_from_ckpt + obj.load(filepath, map_location=obj.device, how=how) + return obj + + def _init_net(self): + """ ... """ + self.start_epoch = 0 + # set the device for use + if self.device == 'cuda': + self.device = "{}:{}".format(self.device, str(self.cuda_devices[0])) + # self.load(self.ckpt_path, map_location=self.device) + # print("Loading model from checkpoint...") + # self.load(self.ckpt_path, map_location=self.device) + + if self.load_from_ckpt: + try: + print("Loading model from checkpoint...") + self.load(self.ckpt_path, map_location=self.device) + except: + print("Cannot load from checkpoint. 
Initializing new model...") + self.load_from_ckpt = False + + if not self.load_from_ckpt: + self.net_ = get_backend(self.img_backend)( + tgt_modalities = self.tgt_modalities, + bn_size = self.bn_size, + growth_rate=self.growth_rate, + block_config=self.block_config, + compression=self.compression, + num_init_features=self.num_init_features, + drop_rate=self.drop_rate, + load_from_ckpt=self.load_from_ckpt + ) + + # # intialize model parameters using xavier_uniform + # for p in self.net_.parameters(): + # if p.dim() > 1: + # torch.nn.init.xavier_uniform_(p) + + self.net_.to(self.device) + + # Initialize the number of GPUs + if self.data_parallel and torch.cuda.device_count() > 1: + print("Available", torch.cuda.device_count(), "GPUs!") + self.net_ = torch.nn.DataParallel(self.net_, device_ids=self.cuda_devices) + + # return net + + def _init_dataloader(self, trn_list, vld_list, img_train_trans=None, img_vld_trans=None): + # def _init_dataloader(self, x, y): + """ ... """ + # # split dataset + # x_trn, x_vld, y_trn, y_vld = train_test_split( + # x, y, test_size = 0.2, random_state = 0, + # ) + + # # initialize dataset and dataloader + # dat_trn = CNNTrainingValidationDataset( + # x_trn, y_trn, + # self.tgt_modalities, + # img_transform=img_train_trans, + # ) + + # dat_vld = CNNTrainingValidationDataset( + # x_vld, y_vld, + # self.tgt_modalities, + # img_transform=img_vld_trans, + # ) + + dat_trn = monai.data.Dataset(data=trn_list, transform=img_train_trans) + dat_vld = monai.data.Dataset(data=vld_list, transform=img_vld_trans) + collate_fn_trn = functools.partial(collate_handle_corrupted, dataset=dat_trn, dtype=torch.FloatTensor, labels=self.tgt_modalities) + collate_fn_vld = functools.partial(collate_handle_corrupted, dataset=dat_vld, dtype=torch.FloatTensor, labels=self.tgt_modalities) + + ldr_trn = DataLoader( + dataset = dat_trn, + batch_size = self.batch_size, + shuffle = True, + drop_last = False, + num_workers = self._dataloader_num_workers, + collate_fn = collate_fn_trn, + # pin_memory = True + ) + + ldr_vld = DataLoader( + dataset = dat_vld, + batch_size = self.batch_size, + shuffle = False, + drop_last = False, + num_workers = self._dataloader_num_workers, + collate_fn = collate_fn_vld, + # pin_memory = True + ) + + return ldr_trn, ldr_vld + + def _init_optimizer(self): + """ ... """ + params = list(self.net_.parameters()) + # for p in params: + # print(p.requires_grad) + return torch.optim.AdamW( + params, + lr = self.lr, + betas = (0.9, 0.98), + weight_decay = self.weight_decay + ) + + def _init_scheduler(self, optimizer): + """ ... """ + # return torch.optim.lr_scheduler.OneCycleLR( + # optimizer = optimizer, + # max_lr = self.lr, + # total_steps = self.num_epochs, + # verbose = (self.verbose > 2) + # ) + + # return torch.optim.lr_scheduler.CosineAnnealingLR( + # optimizer=optimizer, + # T_max=64, + # verbose=(self.verbose > 2) + # ) + + return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer=optimizer, + T_0=64, + T_mult=2, + eta_min = 0, + verbose=(self.verbose > 2) + ) + + def _init_loss_func(self, + num_per_cls: dict[str, tuple[int, int]], + ) -> dict[str, Module]: + """ ... """ + return {k: nn.SigmoidFocalLossBeta( + beta = self.beta, + gamma = self.gamma, + num_per_cls = num_per_cls[k], + reduction = 'none', + ) for k in self.tgt_modalities} + + def _proc_fit(self): + """ ... 
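`_init_scheduler` above returns `CosineAnnealingWarmRestarts`, which `train_one_epoch` advances with a fractional epoch (`scheduler.step(epoch + n_iter / iters)`) so the learning rate is annealed per batch rather than once per epoch. A self-contained sketch of that interaction, using a throwaway parameter and optimizer:

import torch

param = torch.nn.Parameter(torch.zeros(1))          # dummy parameter to drive the optimizer
optimizer = torch.optim.AdamW([param], lr=1e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=64, T_mult=2, eta_min=0.0,
)

iters = 10                                           # batches per epoch
for epoch in range(2):
    for n_iter in range(iters):
        optimizer.step()                             # would normally follow loss.backward()
        scheduler.step(epoch + n_iter / iters)       # fractional epoch -> smooth per-batch decay
    print(epoch, optimizer.param_groups[0]['lr'])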
""" + + def _init_test_dataloader(self, batch_size, tst_list, img_tst_trans=None): + # input validation + check_is_fitted(self) + print(self.device) + + # for PyTorch computational efficiency + torch.set_num_threads(1) + + # set model to eval mode + torch.set_grad_enabled(False) + self.net_.eval() + + dat_tst = monai.data.Dataset(data=tst_list, transform=img_tst_trans) + collate_fn_tst = functools.partial(collate_handle_corrupted, dataset=dat_tst, dtype=torch.FloatTensor, labels=self.tgt_modalities) + # print(collate_fn_tst) + + ldr_tst = DataLoader( + dataset = dat_tst, + batch_size = batch_size, + shuffle = False, + drop_last = False, + num_workers = self._dataloader_num_workers, + collate_fn = collate_fn_tst, + # pin_memory = True + ) + return ldr_tst + + + def predict_logits(self, + ldr_tst: Any | None = None, + ) -> list[dict[str, float]]: + + # run model and collect results + logits: list[dict[str, float]] = [] + for batch_data in tqdm(ldr_tst): + # print(batch_data["image"]) + if len(batch_data) == 0: + continue + x_batch = batch_data["image"].to(self.device, non_blocking=True) + outputs = self.net_(x_batch, shap=False) + + # convert output from dict-of-list to list of dict, then append + tmp = {k: outputs[k].tolist() for k in self.tgt_modalities} + tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))] + logits += tmp + + return logits + + def predict_proba(self, + ldr_tst: Any | None = None, + temperature: float = 1.0, + ) -> list[dict[str, float]]: + ''' ... ''' + logits = self.predict_logits(ldr_tst) + print("got logits") + return logits, [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits] + + def predict(self, + ldr_tst: Any | None = None, + ) -> list[dict[str, int]]: + ''' ... ''' + logits, proba = self.predict_proba(ldr_tst) + print("got proba") + return logits, proba, [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba] \ No newline at end of file diff --git a/adrd/model/train_resnet.py b/adrd/model/train_resnet.py new file mode 100755 index 0000000000000000000000000000000000000000..c9cddac9c15ec26228c48ebb20b75117086b1890 --- /dev/null +++ b/adrd/model/train_resnet.py @@ -0,0 +1,484 @@ +import torch +import numpy as np +import tqdm +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted +from sklearn.model_selection import train_test_split +from scipy.special import expit +from copy import deepcopy +from contextlib import suppress +from typing import Any, Self +from icecream import ic + +from .. import nn +from ..utils import TransformerTrainingDataset +from ..utils import TransformerValidationDataset +from ..utils import MissingMasker +from ..utils import ConstantImputer +from ..utils import Formatter +from ..utils.misc import ProgressBar +from ..utils.misc import get_metrics_multitask, print_metrics_multitask + + +class TrainResNet(BaseEstimator): + ''' ... 
''' + def __init__(self, + src_modalities: dict[str, dict[str, Any]], + tgt_modalities: dict[str, dict[str, Any]], + label_fractions: dict[str, float], + num_epochs: int = 32, + batch_size: int = 8, + lr: float = 1e-2, + weight_decay: float = 0.0, + gamma: float = 0.0, + criterion: str | None = None, + device: str = 'cpu', + cuda_devices: list = [1,2], + mri_feature: str = 'img_MRI_T1', + ckpt_path: str = '/home/skowshik/ADRD_repo/adrd_tool/adrd/dev/ckpt/ckpt.pt', + load_from_ckpt: bool = True, + save_intermediate_ckpts: bool = False, + data_parallel: bool = False, + verbose: int = 0, + ): + ''' ... ''' + # for multiprocessing + self._rank = 0 + self._lock = None + + # positional parameters + self.src_modalities = src_modalities + self.tgt_modalities = tgt_modalities + + # training parameters + self.label_fractions = label_fractions + self.num_epochs = num_epochs + self.batch_size = batch_size + self.lr = lr + self.weight_decay = weight_decay + self.gamma = gamma + self.criterion = criterion + self.device = device + self.cuda_devices = cuda_devices + self.mri_feature = mri_feature + self.ckpt_path = ckpt_path + self.load_from_ckpt = load_from_ckpt + self.save_intermediate_ckpts = save_intermediate_ckpts + self.data_parallel = data_parallel + self.verbose = verbose + + def fit(self, x, y): + ''' ... ''' + # for PyTorch computational efficiency + torch.set_num_threads(1) + + # set the device for use + if self.device == 'cuda': + self.device = "{}:{}".format(self.device, str(self.cuda_devices[0])) + + # initialize model + if self.load_from_ckpt: + try: + print("Loading model from checkpoint...") + self.load(self.ckpt_path, map_location=self.device) + except: + print("Cannot load from checkpoint. Initializing new model...") + self.load_from_ckpt = False + + # initialize model + if not self.load_from_ckpt: + self.net_ = nn.ResNetModel( + self.tgt_modalities, + mri_feature = self.mri_feature + ) + # intialize model parameters using xavier_uniform + for p in self.net_.parameters(): + if p.dim() > 1: + torch.nn.init.xavier_uniform_(p) + + self.net_.to(self.device) + + # Initialize the number of GPUs + if self.data_parallel and torch.cuda.device_count() > 1: + print("Available", torch.cuda.device_count(), "GPUs!") + self.net_ = torch.nn.DataParallel(self.net_, device_ids=self.cuda_devices) + + + # split dataset + x_trn, x_vld, y_trn, y_vld = train_test_split( + x, y, test_size = 0.2, random_state = 0, + ) + + # initialize dataset and dataloader + dat_trn = TransformerTrainingDataset( + x_trn, y_trn, + self.src_modalities, + self.tgt_modalities, + dropout_rate = .5, + dropout_strategy = 'compensated', + mri_feature = self.mri_feature, + ) + + dat_vld = TransformerValidationDataset( + x_vld, y_vld, + self.src_modalities, + self.tgt_modalities, + mri_feature = self.mri_feature, + ) + + # ic(dat_trn[0]) + + ldr_trn = torch.utils.data.DataLoader( + dat_trn, + batch_size = self.batch_size, + shuffle = True, + drop_last = False, + num_workers = 0, + collate_fn = TransformerTrainingDataset.collate_fn, + # pin_memory = True + ) + + ldr_vld = torch.utils.data.DataLoader( + dat_vld, + batch_size = self.batch_size, + shuffle = False, + drop_last = False, + num_workers = 0, + collate_fn = TransformerTrainingDataset.collate_fn, + # pin_memory = True + ) + + # initialize optimizer + optimizer = torch.optim.AdamW( + self.net_.parameters(), + lr = self.lr, + betas = (0.9, 0.98), + weight_decay = self.weight_decay + ) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=64, verbose=(self.verbose 
> 2)) + + # initialize loss function (binary cross entropy) + loss_fn = {} + + for k in self.tgt_modalities: + alpha = pow((1 - self.label_fractions[k]), self.gamma) + # if alpha < 0.5: + # alpha = -1 + loss_fn[k] = nn.SigmoidFocalLoss( + alpha = alpha, + gamma = self.gamma, + reduction = 'none' + ) + + # to record the best validation performance criterion + if self.criterion is not None: + best_crit = None + + # progress bar for epoch loops + if self.verbose == 1: + with self._lock if self._lock is not None else suppress(): + pbr_epoch = tqdm.tqdm( + desc = 'Rank {:02d}'.format(self._rank), + total = self.num_epochs, + position = self._rank, + ascii = True, + leave = False, + bar_format='{l_bar}{r_bar}' + ) + + # Define a hook function to print and store the gradient of a layer + def print_and_store_grad(grad, grad_list): + grad_list.append(grad) + # print(grad) + + # grad_list = [] + # self.net_.module.img_net_.featurizer.down_tr64.ops[0].conv1.weight.register_hook(lambda grad: print_and_store_grad(grad, grad_list)) + # self.net_.module.modules_emb_src['gender'].weight.register_hook(lambda grad: print_and_store_grad(grad, grad_list)) + + + # training loop + for epoch in range(self.num_epochs): + # progress bar for batch loops + if self.verbose > 1: + pbr_batch = ProgressBar(len(dat_trn), 'Epoch {:03d} (TRN)'.format(epoch)) + + # set model to train mode + torch.set_grad_enabled(True) + self.net_.train() + + scores_trn, y_true_trn = [], [] + losses_trn = [[] for _ in self.tgt_modalities] + for x_batch, y_batch, mask in ldr_trn: + + # mount data to the proper device + x_batch = {k: x_batch[k].to(self.device) for k in x_batch} + y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch} + + # forward + outputs = self.net_(x_batch) + + # calculate multitask loss + loss = 0 + for i, k in enumerate(self.tgt_modalities): + loss_task = loss_fn[k](outputs[k], y_batch[k]) + loss += loss_task.mean() + losses_trn[i] += loss_task.detach().cpu().numpy().tolist() + + # backward + optimizer.zero_grad(set_to_none=True) + loss.backward() + optimizer.step() + + ''' TODO: change array to dictionary later ''' + outputs = torch.stack(list(outputs.values()), dim=1) + y_batch = torch.stack(list(y_batch.values()), dim=1) + + # save outputs to evaluate performance later + scores_trn.append(outputs.detach().to(torch.float).cpu()) + y_true_trn.append(y_batch.cpu()) + + # update progress bar + if self.verbose > 1: + batch_size = len(next(iter(x_batch.values()))) + pbr_batch.update(batch_size, {}) + pbr_batch.refresh() + + # clear cuda cache + if "cuda" in self.device: + torch.cuda.empty_cache() + + # for better tqdm progress bar display + if self.verbose > 1: + pbr_batch.close() + + # set scheduler + scheduler.step() + + # calculate and print training performance metrics + scores_trn = torch.cat(scores_trn) + y_true_trn = torch.cat(y_true_trn) + y_pred_trn = (scores_trn > 0).to(torch.int) + y_prob_trn = torch.sigmoid(scores_trn) + met_trn = get_metrics_multitask( + y_true_trn.numpy(), + y_pred_trn.numpy(), + y_prob_trn.numpy() + ) + + # add loss to metrics + for i in range(len(self.tgt_modalities)): + met_trn[i]['Loss'] = np.mean(losses_trn[i]) + + if self.verbose > 2: + print_metrics_multitask(met_trn) + + # progress bar for validation + if self.verbose > 1: + pbr_batch = ProgressBar(len(dat_vld), 'Epoch {:03d} (VLD)'.format(epoch)) + + # set model to validation mode + torch.set_grad_enabled(False) + self.net_.eval() + + scores_vld, y_true_vld = [], [] + losses_vld = [[] for _ in self.tgt_modalities] 
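Earlier in `fit`, each target's focal-loss `alpha` is derived from how often its positive label occurs, `alpha = (1 - label_fraction) ** gamma`, so rarer labels end up with a larger `alpha` (which in the usual focal-loss convention up-weights the positive class). A few worked values under that rule, assuming `gamma = 2` purely for illustration:

# illustrative only: alpha = (1 - label_fraction) ** gamma, here with gamma = 2
for label_fraction in (0.5, 0.3, 0.1, 0.01):
    alpha = (1 - label_fraction) ** 2
    print(f'fraction={label_fraction:.2f} -> alpha={alpha:.3f}')
# fraction=0.50 -> alpha=0.250
# fraction=0.30 -> alpha=0.490
# fraction=0.10 -> alpha=0.810
# fraction=0.01 -> alpha=0.980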
+ for x_batch, y_batch, mask in ldr_vld: + # mount data to the proper device + x_batch = {k: x_batch[k].to(self.device) for k in x_batch} + y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch} + + # forward + outputs = self.net_(x_batch) + + # calculate multitask loss + for i, k in enumerate(self.tgt_modalities): + loss_task = loss_fn[k](outputs[k], y_batch[k]) + losses_vld[i] += loss_task.detach().cpu().numpy().tolist() + + ''' TODO: change array to dictionary later ''' + outputs = torch.stack(list(outputs.values()), dim=1) + y_batch = torch.stack(list(y_batch.values()), dim=1) + + # save outputs to evaluate performance later + scores_vld.append(outputs.detach().to(torch.float).cpu()) + y_true_vld.append(y_batch.cpu()) + + # update progress bar + if self.verbose > 1: + batch_size = len(next(iter(x_batch.values()))) + pbr_batch.update(batch_size, {}) + pbr_batch.refresh() + + # clear cuda cache + if "cuda" in self.device: + torch.cuda.empty_cache() + + # for better tqdm progress bar display + if self.verbose > 1: + pbr_batch.close() + + # calculate and print validation performance metrics + scores_vld = torch.cat(scores_vld) + y_true_vld = torch.cat(y_true_vld) + y_pred_vld = (scores_vld > 0).to(torch.int) + y_prob_vld = torch.sigmoid(scores_vld) + met_vld = get_metrics_multitask( + y_true_vld.numpy(), + y_pred_vld.numpy(), + y_prob_vld.numpy() + ) + + # add loss to metrics + for i in range(len(self.tgt_modalities)): + met_vld[i]['Loss'] = np.mean(losses_vld[i]) + + if self.verbose > 2: + print_metrics_multitask(met_vld) + + # save the model if it has the best validation performance criterion by far + if self.criterion is None: continue + + # is current criterion better than previous best? + curr_crit = np.mean([met_vld[i][self.criterion] for i in range(len(self.tgt_modalities))]) + if best_crit is None or np.isnan(best_crit): + is_better = True + elif self.criterion == 'Loss' and best_crit >= curr_crit: + is_better = True + elif self.criterion != 'Loss' and best_crit <= curr_crit: + is_better = True + else: + is_better = False + + # update best criterion + if is_better: + best_crit = curr_crit + best_state_dict = deepcopy(self.net_.state_dict()) + if self.save_intermediate_ckpts: + print("Saving the model...") + self.save(self.ckpt_path) + + if self.verbose > 2: + print('Best {}: {}'.format(self.criterion, best_crit)) + + if self.verbose == 1: + with self._lock if self._lock is not None else suppress(): + pbr_epoch.update(1) + pbr_epoch.refresh() + + if self.verbose == 1: + with self._lock if self._lock is not None else suppress(): + pbr_epoch.close() + + # restore the model of the best validation performance across all epoches + if ldr_vld is not None and self.criterion is not None: + self.net_.load_state_dict(best_state_dict) + + return self + + def predict_logits(self, + x: list[dict[str, Any]], + ) -> list[dict[str, float]]: + ''' + The input x can be a single sample or a list of samples. 
+ ''' + # input validation + check_is_fitted(self) + + # for PyTorch computational efficiency + torch.set_num_threads(1) + + # set model to eval mode + torch.set_grad_enabled(False) + self.net_.eval() + + # number of samples to evaluate + n_samples = len(x) + + # format x + fmt = Formatter(self.src_modalities) + x = [fmt(smp) for smp in x] + + # generate missing mask (BEFORE IMPUTATION) + msk = MissingMasker(self.src_modalities) + mask = [msk(smp) for smp in x] + + # reformat x and then impute by 0s + imp = ConstantImputer(self.src_modalities) + x = [imp(smp) for smp in x] + + # convert list-of-dict to dict-of-list + x = {k: [smp[k] for smp in x] for k in self.src_modalities} + mask = {k: [smp[k] for smp in mask] for k in self.src_modalities} + + # to tensor + x = {k: torch.as_tensor(np.array(v)).to(self.device) for k, v in x.items()} + mask = {k: torch.as_tensor(np.array(v)).to(self.device) for k, v in mask.items()} + + # calculate logits + logits = self.net_(x) + + # convert dict-of-list to list-of-dict + logits = {k: logits[k].tolist() for k in self.tgt_modalities} + logits = [{k: logits[k][i] for k in self.tgt_modalities} for i in range(n_samples)] + + return logits + + def predict_proba(self, + x: list[dict[str, Any]], + temperature: float = 1.0 + ) -> list[dict[str, float]]: + ''' ... ''' + # calculate logits + logits = self.predict_logits(x) + + # convert logits to probabilities and + proba = [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits] + return proba + + def predict(self, + x: list[dict[str, Any]], + ) -> list[dict[str, int]]: + ''' ... ''' + proba = self.predict_proba(x) + return [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba] + + def save(self, filepath: str) -> None: + ''' ... ''' + check_is_fitted(self) + if self.data_parallel: + state_dict = self.net_.module.state_dict() + else: + state_dict = self.net_.state_dict() + + # attach model hyper parameters + state_dict['src_modalities'] = self.src_modalities + state_dict['tgt_modalities'] = self.tgt_modalities + state_dict['mri_feature'] = self.mri_feature + + torch.save(state_dict, filepath) + + def load(self, filepath: str, map_location: str='cpu') -> None: + ''' ... ''' + # load state_dict + state_dict = torch.load(filepath, map_location=map_location) + + # load data modalities + self.src_modalities = state_dict.pop('src_modalities') + self.tgt_modalities = state_dict.pop('tgt_modalities') + + # initialize model + self.net_ = nn.ResNetModel( + self.tgt_modalities, + mri_feature = state_dict.pop('mri_feature') + ) + + # load model parameters + self.net_.load_state_dict(state_dict) + self.net_.to(self.device) + + @classmethod + def from_ckpt(cls, filepath: str, device='cpu') -> Self: + ''' ... 
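`predict` above thresholds the temperature-scaled probabilities at 0.5, which is the same decision as thresholding the raw logits at 0 (the convention used by `CNNResNet3DWithLinearClassifier.predict` earlier), since `sigmoid(0) = 0.5` and a positive temperature never changes the sign of a logit. A quick check of that equivalence:

from scipy.special import expit

logits = [-2.0, -0.1, 0.0, 0.3, 4.0]
temperature = 1.0
for z in logits:
    p = expit(z / temperature)
    assert (p > 0.5) == (z > 0.0)        # identical decisions for any positive temperature
    print(f'logit={z:+.1f}  proba={p:.3f}  predicted={int(p > 0.5)}')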
''' + obj = cls(None, None, None,device=device) + obj.load(filepath) + return obj \ No newline at end of file diff --git a/adrd/model/transformer.py b/adrd/model/transformer.py new file mode 100755 index 0000000000000000000000000000000000000000..24c3a5d9ae6db9be0247662bb4009870b04aac57 --- /dev/null +++ b/adrd/model/transformer.py @@ -0,0 +1,600 @@ +__all__ = ['Transformer'] + +import torch +from torch.utils.data import DataLoader +import numpy as np +import tqdm +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted +from sklearn.model_selection import train_test_split +from scipy.special import expit +from copy import deepcopy +from contextlib import suppress +from typing import Any, Self, Type +from functools import wraps +Tensor = Type[torch.Tensor] +Module = Type[torch.nn.Module] + +from .. import nn +from ..utils import TransformerTrainingDataset +from ..utils import TransformerBalancedTrainingDataset +from ..utils import Transformer2ndOrderBalancedTrainingDataset +from ..utils import TransformerValidationDataset +from ..utils import TransformerTestingDataset +from ..utils.misc import ProgressBar +from ..utils.misc import get_metrics_multitask, print_metrics_multitask +from ..utils.misc import convert_args_kwargs_to_kwargs + + +def _manage_ctx_fit(func): + ''' ... ''' + @wraps(func) + def wrapper(*args, **kwargs): + # format arguments + kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs) + + if kwargs['self']._device_ids is None: + return func(**kwargs) + else: + # change primary device + default_device = kwargs['self'].device + kwargs['self'].device = kwargs['self']._device_ids[0] + rtn = func(**kwargs) + kwargs['self'].to(default_device) + return rtn + return wrapper + + +class Transformer(BaseEstimator): + ''' ... ''' + def __init__(self, + src_modalities: dict[str, dict[str, Any]], + tgt_modalities: dict[str, dict[str, Any]], + d_model: int = 32, + nhead: int = 1, + num_layers: int = 1, + num_epochs: int = 32, + batch_size: int = 8, + batch_size_multiplier: int = 1, + lr: float = 1e-2, + weight_decay: float = 0.0, + beta: float = 0.9999, + gamma: float = 2.0, + scale: float = 1.0, + lambd: float = 0.0, + criterion: str | None = None, + device: str = 'cpu', + verbose: int = 0, + _device_ids: list | None = None, + _dataloader_num_workers: int = 0, + _amp_enabled: bool = False, + ) -> None: + ''' ... ''' + # for multiprocessing + self._rank = 0 + self._lock = None + + # positional parameters + self.src_modalities = src_modalities + self.tgt_modalities = tgt_modalities + + # training parameters + self.d_model = d_model + self.nhead = nhead + self.num_layers = num_layers + self.num_epochs = num_epochs + self.batch_size = batch_size + self.batch_size_multiplier = batch_size_multiplier + self.lr = lr + self.weight_decay = weight_decay + self.beta = beta + self.gamma = gamma + self.scale = scale + self.lambd = lambd + self.criterion = criterion + self.device = device + self.verbose = verbose + self._device_ids = _device_ids + self._dataloader_num_workers = _dataloader_num_workers + self._amp_enabled = _amp_enabled + + @_manage_ctx_fit + def fit(self, + x, y, + is_embedding: dict[str, bool] | None = None, + ) -> Self: + ''' ... 
''' + # for PyTorch computational efficiency + torch.set_num_threads(1) + + # initialize neural network + self.net_ = self._init_net() + + # initialize dataloaders + ldr_trn, ldr_vld = self._init_dataloader(x, y, is_embedding) + + # initialize optimizer and scheduler + optimizer = self._init_optimizer() + scheduler = self._init_scheduler(optimizer) + + # gradient scaler for AMP + if self._amp_enabled: scaler = torch.cuda.amp.GradScaler() + + # initialize loss function (binary cross entropy) + loss_func = self._init_loss_func({ + k: ( + sum([_[k] == 0 for _ in ldr_trn.dataset.tgt]), + sum([_[k] == 1 for _ in ldr_trn.dataset.tgt]), + ) for k in self.tgt_modalities + }) + + # to record the best validation performance criterion + if self.criterion is not None: best_crit = None + + # progress bar for epoch loops + if self.verbose == 1: + with self._lock if self._lock is not None else suppress(): + pbr_epoch = tqdm.tqdm( + desc = 'Rank {:02d}'.format(self._rank), + total = self.num_epochs, + position = self._rank, + ascii = True, + leave = False, + bar_format='{l_bar}{r_bar}' + ) + + # training loop + for epoch in range(self.num_epochs): + # progress bar for batch loops + if self.verbose > 1: + pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch)) + + # set model to train mode + torch.set_grad_enabled(True) + self.net_.train() + + scores_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities} + y_true_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities} + losses_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities} + for n_iter, (x_batch, y_batch, mask_x, mask_y) in enumerate(ldr_trn): + # mount data to the proper device + x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities} + y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities} + mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities} + mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities} + + # forward + with torch.autocast( + device_type = 'cpu' if self.device == 'cpu' else 'cuda', + dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16, + enabled = self._amp_enabled, + ): + outputs = self.net_(x_batch, mask_x, is_embedding) + + # calculate multitask loss + loss = 0 + for i, tgt_k in enumerate(self.tgt_modalities): + loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k]) + loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze())) + loss += loss_k.mean() + losses_trn[tgt_k] += loss_k.detach().cpu().numpy().tolist() + + # if self.lambd != 0: + + # backward + if self._amp_enabled: + scaler.scale(loss).backward() + else: + loss.backward() + + # update parameters + if n_iter != 0 and n_iter % self.batch_size_multiplier == 0: + if self._amp_enabled: + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + else: + optimizer.step() + optimizer.zero_grad() + + # save outputs to evaluate performance later + for tgt_k in self.tgt_modalities: + tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze())) + scores_trn[tgt_k] += tmp.detach().cpu().numpy().tolist() + tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze())) + y_true_trn[tgt_k] += tmp.cpu().numpy().tolist() + + # update progress bar + if self.verbose > 1: + batch_size = len(next(iter(x_batch.values()))) + pbr_batch.update(batch_size, {}) + pbr_batch.refresh() + + # for better tqdm progress bar display + if self.verbose > 1: + pbr_batch.close() + + # 
set scheduler + scheduler.step() + + # calculate and print training performance metrics + y_pred_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities} + y_prob_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities} + for tgt_k in self.tgt_modalities: + for i in range(len(scores_trn[tgt_k])): + y_pred_trn[tgt_k].append(1 if scores_trn[tgt_k][i] > 0 else 0) + y_prob_trn[tgt_k].append(expit(scores_trn[tgt_k][i])) + met_trn = get_metrics_multitask(y_true_trn, y_pred_trn, y_prob_trn) + + # add loss to metrics + for tgt_k in self.tgt_modalities: + met_trn[tgt_k]['Loss'] = np.mean(losses_trn[tgt_k]) + + if self.verbose > 2: + print_metrics_multitask(met_trn) + + # progress bar for validation + if self.verbose > 1: + pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch)) + + # set model to validation mode + torch.set_grad_enabled(False) + self.net_.eval() + + scores_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities} + y_true_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities} + losses_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities} + for x_batch, y_batch, mask_x, mask_y in ldr_vld: + # mount data to the proper device + x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities} + y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities} + mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities} + mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities} + + # forward + with torch.autocast( + device_type = 'cpu' if self.device == 'cpu' else 'cuda', + dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16, + enabled = self._amp_enabled + ): + outputs = self.net_(x_batch, mask_x, is_embedding) + + # calculate multitask loss + for i, tgt_k in enumerate(self.tgt_modalities): + loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k]) + loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze())) + losses_vld[tgt_k] += loss_k.detach().cpu().numpy().tolist() + + # save outputs to evaluate performance later + for tgt_k in self.tgt_modalities: + tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze())) + scores_vld[tgt_k] += tmp.detach().cpu().numpy().tolist() + tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze())) + y_true_vld[tgt_k] += tmp.cpu().numpy().tolist() + + # update progress bar + if self.verbose > 1: + batch_size = len(next(iter(x_batch.values()))) + pbr_batch.update(batch_size, {}) + pbr_batch.refresh() + + # for better tqdm progress bar display + if self.verbose > 1: + pbr_batch.close() + + # calculate and print validation performance metrics + y_pred_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities} + y_prob_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities} + for tgt_k in self.tgt_modalities: + for i in range(len(scores_vld[tgt_k])): + y_pred_vld[tgt_k].append(1 if scores_vld[tgt_k][i] > 0 else 0) + y_prob_vld[tgt_k].append(expit(scores_vld[tgt_k][i])) + met_vld = get_metrics_multitask(y_true_vld, y_pred_vld, y_prob_vld) + + # add loss to metrics + for tgt_k in self.tgt_modalities: + met_vld[tgt_k]['Loss'] = np.mean(losses_vld[tgt_k]) + + if self.verbose > 2: + print_metrics_multitask(met_vld) + + # save the model if it has the best validation performance criterion by far + if self.criterion is None: continue + + # is current criterion better than previous best? 
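+ # The rule applied below, stated compactly: for the 'Loss' criterion a smaller
+ # validation value counts as an improvement; for any other metric (averaged
+ # over the target heads) a larger value does. A None/NaN best is treated as
+ # "no best yet". A minimal sketch of the same logic as a standalone helper:
+ #   def is_improvement(criterion, best, curr):
+ #       if best is None or np.isnan(best):
+ #           return True
+ #       return curr <= best if criterion == 'Loss' else curr >= best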
+ curr_crit = np.mean([met_vld[k][self.criterion] for k in self.tgt_modalities]) + if best_crit is None or np.isnan(best_crit): + is_better = True + elif self.criterion == 'Loss' and best_crit >= curr_crit: + is_better = True + elif self.criterion != 'Loss' and best_crit <= curr_crit: + is_better = True + else: + is_better = False + + # update best criterion + if is_better: + best_crit = curr_crit + best_state_dict = deepcopy(self.net_.state_dict()) + + if self.verbose > 2: + print('Best {}: {}'.format(self.criterion, best_crit)) + + if self.verbose == 1: + with self._lock if self._lock is not None else suppress(): + pbr_epoch.update(1) + pbr_epoch.refresh() + + if self.verbose == 1: + with self._lock if self._lock is not None else suppress(): + pbr_epoch.close() + + # restore the model of the best validation performance across all epoches + if ldr_vld is not None and self.criterion is not None: + self.net_.load_state_dict(best_state_dict) + + return self + + def predict_logits(self, + x: list[dict[str, Any]], + is_embedding: dict[str, bool] | None = None, + _batch_size: int | None = None, + ) -> list[dict[str, float]]: + ''' + The input x can be a single sample or a list of samples. + ''' + # input validation + check_is_fitted(self) + + # for PyTorch computational efficiency + torch.set_num_threads(1) + + # set model to eval mode + torch.set_grad_enabled(False) + self.net_.eval() + + # intialize dataset and dataloader object + dat = TransformerTestingDataset(x, self.src_modalities, is_embedding) + ldr = DataLoader( + dataset = dat, + batch_size = _batch_size if _batch_size is not None else len(x), + shuffle = False, + drop_last = False, + num_workers = 0, + collate_fn = TransformerTestingDataset.collate_fn, + ) + + # run model and collect results + logits: list[dict[str, float]] = [] + for x_batch, mask_x in ldr: + # mount data to the proper device + x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities} + mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities} + + # forward + output: dict[str, Tensor] = self.net_(x_batch, mask_x, is_embedding) + + # convert output from dict-of-list to list of dict, then append + tmp = {k: output[k].tolist() for k in self.tgt_modalities} + tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))] + logits += tmp + + return logits + + def predict_proba(self, + x: list[dict[str, Any]], + is_embedding: dict[str, bool] | None = None, + temperature: float = 1.0, + _batch_size: int | None = None, + ) -> list[dict[str, float]]: + ''' ... ''' + logits = self.predict_logits(x, is_embedding, _batch_size) + return [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits] + + def predict(self, + x: list[dict[str, Any]], + is_embedding: dict[str, bool] | None = None, + _batch_size: int | None = None, + ) -> list[dict[str, int]]: + ''' ... ''' + logits = self.predict_logits(x, is_embedding, _batch_size) + return [{k: int(smp[k] > 0.0) for k in self.tgt_modalities} for smp in logits] + + def save(self, filepath: str) -> None: + ''' ... ''' + check_is_fitted(self) + state_dict = self.net_.state_dict() + + # attach model hyper parameters + state_dict['src_modalities'] = self.src_modalities + state_dict['tgt_modalities'] = self.tgt_modalities + state_dict['d_model'] = self.d_model + state_dict['nhead'] = self.nhead + state_dict['num_layers'] = self.num_layers + torch.save(state_dict, filepath) + + def load(self, filepath: str) -> None: + ''' ... 
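+ Load a checkpoint produced by `save()`: the hyperparameters stored alongside
+ the weights (src_modalities, tgt_modalities, d_model, nhead, num_layers) are
+ popped from the state dict first, a fresh nn.Transformer is built from them,
+ and the remaining entries are loaded as the model weights.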
''' + # load state_dict + state_dict = torch.load(filepath, map_location='cpu') + + # load essential parameters + self.src_modalities: dict[str, dict[str, Any]] = state_dict.pop('src_modalities') + self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities') + self.d_model = state_dict.pop('d_model') + self.nhead = state_dict.pop('nhead') + self.num_layers = state_dict.pop('num_layers') + + # initialize model + self.net_ = nn.Transformer( + self.src_modalities, + self.tgt_modalities, + self.d_model, + self.nhead, + self.num_layers, + ) + + # load model parameters + self.net_.load_state_dict(state_dict) + self.to(self.device) + + def to(self, device: str) -> Self: + ''' Mount model to the given device. ''' + self.device = device + if hasattr(self, 'net_'): self.net_ = self.net_.to(device) + return self + + @classmethod + def from_ckpt(cls, filepath: str) -> Self: + ''' ... ''' + obj = cls(None, None) + obj.load(filepath) + return obj + + def _init_net(self): + """ ... """ + net = nn.Transformer( + self.src_modalities, + self.tgt_modalities, + self.d_model, + self.nhead, + self.num_layers, + ).to(self.device) + + # train on multiple GPUs using torch.nn.DataParallel + if self._device_ids is not None: + net = torch.nn.DataParallel(net, device_ids=self._device_ids) + + # intialize model parameters using xavier_uniform + for p in net.parameters(): + if p.dim() > 1: + torch.nn.init.xavier_uniform_(p) + + return net + + def _init_dataloader(self, x, y, is_embedding): + """ ... """ + # split dataset + x_trn, x_vld, y_trn, y_vld = train_test_split( + x, y, test_size = 0.2, random_state = 0, + ) + + # initialize dataset and dataloader + # dat_trn = TransformerTrainingDataset( + # dat_trn = TransformerBalancedTrainingDataset( + dat_trn = Transformer2ndOrderBalancedTrainingDataset( + x_trn, y_trn, + self.src_modalities, + self.tgt_modalities, + dropout_rate = .5, + # dropout_strategy = 'compensated', + dropout_strategy = 'permutation', + ) + + dat_vld = TransformerValidationDataset( + x_vld, y_vld, + self.src_modalities, + self.tgt_modalities, + is_embedding, + ) + + ldr_trn = DataLoader( + dataset = dat_trn, + batch_size = self.batch_size, + shuffle = True, + drop_last = False, + num_workers = self._dataloader_num_workers, + collate_fn = TransformerTrainingDataset.collate_fn, + # pin_memory = True + ) + + ldr_vld = DataLoader( + dataset = dat_vld, + batch_size = self.batch_size, + shuffle = False, + drop_last = False, + num_workers = self._dataloader_num_workers, + collate_fn = TransformerValidationDataset.collate_fn, + # pin_memory = True + ) + + return ldr_trn, ldr_vld + + def _init_optimizer(self): + """ ... """ + return torch.optim.AdamW( + self.net_.parameters(), + lr = self.lr, + betas = (0.9, 0.98), + weight_decay = self.weight_decay + ) + + def _init_scheduler(self, optimizer): + """ ... """ + return torch.optim.lr_scheduler.OneCycleLR( + optimizer = optimizer, + max_lr = self.lr, + total_steps = self.num_epochs, + verbose = (self.verbose > 2) + ) + + def _init_loss_func(self, + num_per_cls: dict[str, tuple[int, int]], + ) -> dict[str, Module]: + """ ... """ + return {k: nn.SigmoidFocalLoss( + beta = self.beta, + gamma = self.gamma, + scale = self.scale, + num_per_cls = num_per_cls[k], + reduction = 'none', + ) for k in self.tgt_modalities} + + def _extract_embedding(self, + x: list[dict[str, Any]], + is_embedding: dict[str, bool] | None = None, + _batch_size: int | None = None, + ) -> list[dict[str, Any]]: + """ ... 
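+ Run the fitted backbone in eval mode over the given samples and return one
+ embedding dict per sample; modalities that were missing in the input (and
+ hence imputed) are dropped from each returned dict.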
""" + # input validation + check_is_fitted(self) + + # for PyTorch computational efficiency + torch.set_num_threads(1) + + # set model to eval mode + torch.set_grad_enabled(False) + self.net_.eval() + + # intialize dataset and dataloader object + dat = TransformerTestingDataset(x, self.src_modalities, is_embedding) + ldr = DataLoader( + dataset = dat, + batch_size = _batch_size if _batch_size is not None else len(x), + shuffle = False, + drop_last = False, + num_workers = 0, + collate_fn = TransformerTestingDataset.collate_fn, + ) + + # run model and extract embeddings + embeddings: list[dict[str, Any]] = [] + for x_batch, _ in ldr: + # mount data to the proper device + x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities} + + # forward + out: dict[str, Tensor] = self.net_.forward_emb(x_batch, is_embedding) + + # convert output from dict-of-list to list of dict, then append + tmp = {k: out[k].detach().cpu().numpy() for k in self.src_modalities} + tmp = [{k: tmp[k][i] for k in self.src_modalities} for i in range(len(next(iter(tmp.values()))))] + embeddings += tmp + + # remove imputed embeddings + for i in range(len(x)): + avail = [k for k, v in x[i].items() if v is not None] + embeddings[i] = {k: embeddings[i][k] for k in avail} + + return embeddings + diff --git a/adrd/nn/__init__.py b/adrd/nn/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..3ac9efb960c28e7fdf5782305faa6dae9f3e6d32 --- /dev/null +++ b/adrd/nn/__init__.py @@ -0,0 +1,12 @@ +from .transformer import Transformer +from .vitautoenc import ViTAutoEnc +from .unet import UNet3D +from .unet_3d import UNet3DBase +from .focal_loss import SigmoidFocalLoss +from .unet_img_model import ImageModel +from .img_model_wrapper import ImagingModelWrapper +from .resnet_img_model import ResNetModel +from .c3d import C3D +from .dense_net import DenseNet +from .cnn_resnet3d import CNNResNet3D +from .cnn_resnet3d_with_linear_classifier import CNNResNet3DWithLinearClassifier diff --git a/adrd/nn/__pycache__/__init__.cpython-311.pyc b/adrd/nn/__pycache__/__init__.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..78839d2e5fef229e4d88fbc8b7926a2de5887964 Binary files /dev/null and b/adrd/nn/__pycache__/__init__.cpython-311.pyc differ diff --git a/adrd/nn/__pycache__/blocks.cpython-311.pyc b/adrd/nn/__pycache__/blocks.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..f9d90dcd95ab6ddc9812868543410bf0b546961b Binary files /dev/null and b/adrd/nn/__pycache__/blocks.cpython-311.pyc differ diff --git a/adrd/nn/__pycache__/c3d.cpython-311.pyc b/adrd/nn/__pycache__/c3d.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..c4cfd404397423c814b9cd53566c62125e0e6071 Binary files /dev/null and b/adrd/nn/__pycache__/c3d.cpython-311.pyc differ diff --git a/adrd/nn/__pycache__/cnn_resnet3d.cpython-311.pyc b/adrd/nn/__pycache__/cnn_resnet3d.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..edf5b81d436102717a237b0d377715a12d1037ec Binary files /dev/null and b/adrd/nn/__pycache__/cnn_resnet3d.cpython-311.pyc differ diff --git a/adrd/nn/__pycache__/cnn_resnet3d_with_linear_classifier.cpython-311.pyc b/adrd/nn/__pycache__/cnn_resnet3d_with_linear_classifier.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..09e2bcb9897771fbe70c26006a97bce753560800 Binary files /dev/null and b/adrd/nn/__pycache__/cnn_resnet3d_with_linear_classifier.cpython-311.pyc differ diff 
--git a/adrd/nn/__pycache__/dense_net.cpython-311.pyc b/adrd/nn/__pycache__/dense_net.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..f99d1b6f9605865c72ec0cdbb23e270f4ffdc72e Binary files /dev/null and b/adrd/nn/__pycache__/dense_net.cpython-311.pyc differ diff --git a/adrd/nn/__pycache__/focal_loss.cpython-311.pyc b/adrd/nn/__pycache__/focal_loss.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..4cc786c33c4ef40e982c84a85dec544e22904e4e Binary files /dev/null and b/adrd/nn/__pycache__/focal_loss.cpython-311.pyc differ diff --git a/adrd/nn/__pycache__/img_model_wrapper.cpython-311.pyc b/adrd/nn/__pycache__/img_model_wrapper.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..75353211ec098f630a302e0614a8920fe0287de9 Binary files /dev/null and b/adrd/nn/__pycache__/img_model_wrapper.cpython-311.pyc differ diff --git a/adrd/nn/__pycache__/net_resnet3d.cpython-311.pyc b/adrd/nn/__pycache__/net_resnet3d.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..a6e85b50362245eb09976b541a9075df4c29aeeb Binary files /dev/null and b/adrd/nn/__pycache__/net_resnet3d.cpython-311.pyc differ diff --git a/adrd/nn/__pycache__/resnet3d.cpython-311.pyc b/adrd/nn/__pycache__/resnet3d.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..b08318600408da439cda1da6c6ce67b4c344eeaa Binary files /dev/null and b/adrd/nn/__pycache__/resnet3d.cpython-311.pyc differ diff --git a/adrd/nn/__pycache__/resnet_img_model.cpython-311.pyc b/adrd/nn/__pycache__/resnet_img_model.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..f88d0ab7b1eb808219bbd1364e45993ba788fc2e Binary files /dev/null and b/adrd/nn/__pycache__/resnet_img_model.cpython-311.pyc differ diff --git a/adrd/nn/__pycache__/selfattention.cpython-311.pyc b/adrd/nn/__pycache__/selfattention.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..2a5f32314884db9e2a19c851040dcdc94f8cf734 Binary files /dev/null and b/adrd/nn/__pycache__/selfattention.cpython-311.pyc differ diff --git a/adrd/nn/__pycache__/transformer.cpython-311.pyc b/adrd/nn/__pycache__/transformer.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..f59644dbacdfb5d83672cc5a993ad04ac36d3d0a Binary files /dev/null and b/adrd/nn/__pycache__/transformer.cpython-311.pyc differ diff --git a/adrd/nn/__pycache__/unet.cpython-311.pyc b/adrd/nn/__pycache__/unet.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..8b94e49d19692d451e91e58a5a658c1f89b51d20 Binary files /dev/null and b/adrd/nn/__pycache__/unet.cpython-311.pyc differ diff --git a/adrd/nn/__pycache__/unet_3d.cpython-311.pyc b/adrd/nn/__pycache__/unet_3d.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..6daeead47f9778e327ac1ee2e92985a863cbee8c Binary files /dev/null and b/adrd/nn/__pycache__/unet_3d.cpython-311.pyc differ diff --git a/adrd/nn/__pycache__/unet_img_model.cpython-311.pyc b/adrd/nn/__pycache__/unet_img_model.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..c0d06f20f45118f90ad6e35207013e8457eccbcf Binary files /dev/null and b/adrd/nn/__pycache__/unet_img_model.cpython-311.pyc differ diff --git a/adrd/nn/__pycache__/vitautoenc.cpython-311.pyc b/adrd/nn/__pycache__/vitautoenc.cpython-311.pyc new file mode 100755 index 
0000000000000000000000000000000000000000..cf0b6e970d9a18f43e85208732b1cec05ab8b24c Binary files /dev/null and b/adrd/nn/__pycache__/vitautoenc.cpython-311.pyc differ diff --git a/adrd/nn/blocks.py b/adrd/nn/blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..4a591b41c44f3093486b654b99a3e4c8ccceaac9 --- /dev/null +++ b/adrd/nn/blocks.py @@ -0,0 +1,57 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from monai.networks.blocks.mlp import MLPBlock +from typing import Sequence, Union +import torch +import torch.nn as nn + +from ..nn.selfattention import SABlock + +class TransformerBlock(nn.Module): + """ + A transformer block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + """ + + def __init__( + self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False + ) -> None: + """ + Args: + hidden_size: dimension of hidden layer. + mlp_dim: dimension of feedforward layer. + num_heads: number of attention heads. + dropout_rate: faction of the input units to drop. + qkv_bias: apply bias term for the qkv linear layer + + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise ValueError("hidden_size should be divisible by num_heads.") + + self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) + self.norm1 = nn.LayerNorm(hidden_size) + self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias) + self.norm2 = nn.LayerNorm(hidden_size) + + def forward(self, x, return_attention=False): + y, attn = self.attn(self.norm1(x)) + if return_attention: + return attn + x = x + y + x = x + self.mlp(self.norm2(x)) + return x diff --git a/adrd/nn/c3d.py b/adrd/nn/c3d.py new file mode 100755 index 0000000000000000000000000000000000000000..9cd8a9c380c4163244089ff25be215d1c0599a3c --- /dev/null +++ b/adrd/nn/c3d.py @@ -0,0 +1,99 @@ +# From https://github.com/xmuyzz/3D-CNN-PyTorch/blob/master/models/C3DNet.py + +import torch +import torch.nn as nn +import sys +# from icecream import ic +import math + +class C3D(torch.nn.Module): + + def __init__(self, tgt_modalities, in_channels=1, load_from_ckpt=None): + + super(C3D, self).__init__() + self.conv_group1 = nn.Sequential( + nn.Conv3d(in_channels, 64, kernel_size=3, padding=1), + nn.BatchNorm3d(64), + nn.ReLU(), + nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 2, 2))) + self.conv_group2 = nn.Sequential( + nn.Conv3d(64, 128, kernel_size=3, padding=1), + nn.BatchNorm3d(128), + nn.ReLU(), + nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))) + self.conv_group3 = nn.Sequential( + nn.Conv3d(128, 256, kernel_size=3, padding=1), + nn.BatchNorm3d(256), + nn.ReLU(), + nn.Conv3d(256, 256, kernel_size=3, padding=1), + nn.BatchNorm3d(256), + nn.ReLU(), + nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) + ) + self.conv_group4 = nn.Sequential( + nn.Conv3d(256, 512, kernel_size=3, padding=1), + 
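# note: with the 1 x 1 x 128 x 128 x 128 test input used in __main__ below, the
+ # four pooling stages leave a 15 x 9 x 9 feature map, so the flattened input
+ # to fc1 is 512 * 15 * 9 * 9, which is why that size is hard-coded there. +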
nn.BatchNorm3d(512), + nn.ReLU(), + nn.Conv3d(512, 512, kernel_size=3, padding=1), + nn.BatchNorm3d(512), + nn.ReLU(), + nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1)) + ) + + # last_duration = int(math.floor(128 / 16)) + # last_size = int(math.ceil(128 / 32)) + self.fc1 = nn.Sequential( + nn.Linear((512 * 15 * 9 * 9) , 512), + nn.ReLU(), + nn.Dropout(0.5)) + self.fc2 = nn.Sequential( + nn.Linear(512, 256), + nn.ReLU(), + nn.Dropout(0.5)) + # self.fc = nn.Sequential( + # nn.Linear(4096, num_classes)) + + self.fc = torch.nn.ModuleDict() + for k in tgt_modalities: + self.fc[k] = torch.nn.Linear(256, 1) + + def forward(self, x): + # for k in x.keys(): + # x[k] = x[k].to(torch.float32) + + # x = torch.stack([o for o in x.values()], dim=0)[0] + # print(x.shape) + + out = self.conv_group1(x) + out = self.conv_group2(out) + out = self.conv_group3(out) + out = self.conv_group4(out) + out = out.view(out.size(0), -1) + # print(out.shape) + out = self.fc1(out) + out = self.fc2(out) + # out = self.fc(out) + + tgt_iter = self.fc.keys() + out_tgt = {k: self.fc[k](out).squeeze(1) for k in tgt_iter} + return out_tgt + + +if __name__ == "__main__": + model = C3D(tgt_modalities=['NC', 'MCI', 'DE']) + print(model) + x = torch.rand((1, 1, 128, 128, 128)) + # layers = list(model.features.named_children()) + # features = nn.Sequential(*list(model.features.children()))(x) + # print(features.shape) + print(sum(p.numel() for p in model.parameters())) + # layer_found = False + # features = None + # desired_layer_name = 'transition3' + + # for name, layer in layers: + # if name == desired_layer_name: + # x = layer(x) + # print(x) + # model(x) + # print(features) \ No newline at end of file diff --git a/adrd/nn/cnn_resnet3d.py b/adrd/nn/cnn_resnet3d.py new file mode 100755 index 0000000000000000000000000000000000000000..e4ec443339567704a1b8276be9a0b0be05baead3 --- /dev/null +++ b/adrd/nn/cnn_resnet3d.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +from typing import Any, Type +Tensor = Type[torch.Tensor] + +from .resnet3d import r3d_18 + + +class CNNResNet3D(nn.Module): + + def __init__(self, + src_modalities: dict[str, dict[str, Any]], + tgt_modalities: dict[str, dict[str, Any]] + ) -> None: + """ ... """ + super().__init__() + + # resnet + # embedding modules for source + self.modules_emb_src = nn.ModuleDict() + for k, info in src_modalities.items(): + if info['type'] == 'imaging' and len(info['img_shape']) == 4: + self.modules_emb_src[k] = nn.Sequential( + r3d_18(), + nn.Dropout(0.5) + ) + else: + # unrecognized + raise ValueError('{} is an unrecognized data modality'.format(k)) + + # classifiers (binary only) + self.modules_cls = nn.ModuleDict() + for k, info in tgt_modalities.items(): + if info['type'] == 'categorical' and info['num_categories'] == 2: + # categorical + self.modules_cls[k] = nn.Linear(256, 1) + else: + # unrecognized + raise ValueError + + def forward(self, + x: dict[str, Tensor], + ) -> dict[str, Tensor]: + """ ... """ + out_emb = self.forward_emb(x) + out_emb = out_emb[list(out_emb.keys())[0]] + out_cls = self.forward_cls(out_emb) + return out_cls + + def forward_emb(self, + x: dict[str, Tensor], + ) -> dict[str, Tensor]: + """ ... """ + out_emb = dict() + for k in self.modules_emb_src.keys(): + out_emb[k] = self.modules_emb_src[k](x[k]) + return out_emb + + def forward_cls(self, + out_emb: dict[str, Tensor] + ) -> dict[str, Tensor]: + """ ... 
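+ Apply each binary head (a Linear(256, 1)) to the shared ResNet-3D embedding
+ and squeeze it to a (batch,) logit tensor per target.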
""" + out_cls = dict() + for k in self.modules_cls.keys(): + out_cls[k] = self.modules_cls[k](out_emb).squeeze(1) + return out_cls + + +# for testing purpose only +if __name__ == '__main__': + src_modalities = { + 'img_MRI_T1': {'type': 'imaging', 'img_shape': [1, 182, 218, 182]} + } + tgt_modalities = { + 'AD': {'type': 'categorical', 'num_categories': 2}, + 'PD': {'type': 'categorical', 'num_categories': 2} + } + net = CNNResNet3D(src_modalities, tgt_modalities) + net.eval() + x = {'img_MRI_T1': torch.zeros(2, 1, 182, 218, 182)} + print(net(x)) \ No newline at end of file diff --git a/adrd/nn/cnn_resnet3d_with_linear_classifier.py b/adrd/nn/cnn_resnet3d_with_linear_classifier.py new file mode 100755 index 0000000000000000000000000000000000000000..aeb77defe92ac325dff8a1993b5efb28016bfae5 --- /dev/null +++ b/adrd/nn/cnn_resnet3d_with_linear_classifier.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn +from typing import Any, Type +Tensor = Type[torch.Tensor] + +from .resnet3d import r3d_18 + +class CNNResNet3DWithLinearClassifier(nn.Module): + + def __init__(self, + src_modalities: dict[str, dict[str, Any]], + tgt_modalities: dict[str, dict[str, Any]] + ) -> None: + """ ... """ + super().__init__() + self.core = _CNNResNet3DWithLinearClassifier(len(tgt_modalities)) + self.src_modalities = src_modalities + self.tgt_modalities = tgt_modalities + + def forward(self, + x: dict[str, Tensor], + ) -> dict[str, Tensor]: + """ x is expected to be a singleton dictionary """ + src_k = list(x.keys())[0] + x = x[src_k] + out = self.core(x) + out = {tgt_k: out[:, i] for i, tgt_k in enumerate(self.tgt_modalities)} + return out + + +class _CNNResNet3DWithLinearClassifier(nn.Module): + + def __init__(self, + len_tgt_modalities: int, + ) -> None: + """ ... """ + super().__init__() + self.cnn = r3d_18() + self.cls = nn.Sequential( + nn.Dropout(0.5), + nn.Linear(256, len_tgt_modalities), + ) + + def forward(self, x: Tensor) -> Tensor: + """ ... """ + out_emb = self.forward_emb(x) + out_cls = self.forward_cls(out_emb) + return out_cls + + def forward_emb(self, x: Tensor) -> Tensor: + """ ... """ + return self.cnn(x) + + def forward_cls(self, out_emb: Tensor) -> Tensor: + """ ... 
""" + return self.cls(out_emb) \ No newline at end of file diff --git a/adrd/nn/dense_net.py b/adrd/nn/dense_net.py new file mode 100755 index 0000000000000000000000000000000000000000..e9a26dcce13a173107bbcc5e14d73b9b4fb4ada7 --- /dev/null +++ b/adrd/nn/dense_net.py @@ -0,0 +1,211 @@ +# This implementation is based on the DenseNet-BC implementation in torchvision +# https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py +# https://github.com/gpleiss/efficient_densenet_pytorch/blob/master/models/densenet.py + + +import math +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from collections import OrderedDict + + +def _bn_function_factory(norm, relu, conv): + def bn_function(*inputs): + concated_features = torch.cat(inputs, 1) + bottleneck_output = conv(relu(norm(concated_features))) + return bottleneck_output + + return bn_function + + +class _DenseLayer(nn.Module): + def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False): + super(_DenseLayer, self).__init__() + self.add_module('norm1', nn.BatchNorm3d(num_input_features)), + self.add_module('relu1', nn.ReLU(inplace=True)), + self.add_module('conv1', nn.Conv3d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)), + self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate)), + self.add_module('relu2', nn.ReLU(inplace=True)), + self.add_module('conv2', nn.Conv3d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)), + self.drop_rate = drop_rate + self.efficient = efficient + + def forward(self, *prev_features): + bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) + if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features): + bottleneck_output = cp.checkpoint(bn_function, *prev_features) + else: + bottleneck_output = bn_function(*prev_features) + new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) + if self.drop_rate > 0: + new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) + return new_features + + +class _Transition(nn.Sequential): + def __init__(self, num_input_features, num_output_features): + super(_Transition, self).__init__() + self.add_module('norm', nn.BatchNorm3d(num_input_features)) + self.add_module('relu', nn.ReLU(inplace=True)) + self.add_module('conv', nn.Conv3d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) + self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2)) + + +class _DenseBlock(nn.Module): + def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, efficient=False): + super(_DenseBlock, self).__init__() + for i in range(num_layers): + layer = _DenseLayer( + num_input_features + i * growth_rate, + growth_rate=growth_rate, + bn_size=bn_size, + drop_rate=drop_rate, + efficient=efficient, + ) + self.add_module('denselayer%d' % (i + 1), layer) + + def forward(self, init_features): + features = [init_features] + for name, layer in self.named_children(): + new_features = layer(*features) + features.append(new_features) + return torch.cat(features, 1) + + +class DenseNet(nn.Module): + r"""Densenet-BC model class, based on + `"Densely Connected Convolutional Networks" ` + Args: + growth_rate (int) - how many filters to add each layer (`k` in paper) + block_config (list of 3 or 4 ints) - how many layers in each pooling block + num_init_features (int) - the number of filters to 
learn in the first convolution layer + bn_size (int) - multiplicative factor for number of bottle neck layers + (i.e. bn_size * k features in the bottleneck layer) + drop_rate (float) - dropout rate after each dense layer + tgt_modalities (list) - list of target modalities + efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower. + """ + # def __init__(self, tgt_modalities, growth_rate=12, block_config=(3, 3, 3), compression=0.5, + # num_init_features=16, bn_size=4, drop_rate=0, efficient=False, load_from_ckpt=False): # config 1 + + def __init__(self, tgt_modalities, growth_rate=12, block_config=(3, 3, 3), compression=0.5, + num_init_features=16, bn_size=4, drop_rate=0, efficient=False, load_from_ckpt=False): # config 2 + + super(DenseNet, self).__init__() + + # First convolution + self.features = nn.Sequential(OrderedDict([('conv0', nn.Conv3d(1, num_init_features, kernel_size=7, stride=2, padding=0, bias=False)),])) + self.features.add_module('norm0', nn.BatchNorm3d(num_init_features)) + self.features.add_module('relu0', nn.ReLU(inplace=True)) + self.features.add_module('pool0', nn.MaxPool3d(kernel_size=3, stride=2, padding=0, ceil_mode=False)) + self.tgt_modalities = tgt_modalities + + # Each denseblock + num_features = num_init_features + for i, num_layers in enumerate(block_config): + block = _DenseBlock( + num_layers=num_layers, + num_input_features=num_features, + bn_size=bn_size, + growth_rate=growth_rate, + drop_rate=drop_rate, + efficient=efficient, + ) + self.features.add_module('denseblock%d' % (i + 1), block) + num_features = num_features + num_layers * growth_rate + if i != len(block_config): + trans = _Transition(num_input_features=num_features, + num_output_features=int(num_features * compression)) + self.features.add_module('transition%d' % (i + 1), trans) + num_features = int(num_features * compression) + + # Final batch norm + self.features.add_module('norm_final', nn.BatchNorm3d(num_features)) + + # Classification heads + self.tgt = torch.nn.ModuleDict() + for k in tgt_modalities: + # self.tgt[k] = torch.nn.Linear(621, 1) # config 2 + self.tgt[k] = torch.nn.Sequential( + torch.nn.Linear(self.test_size(), 256), + torch.nn.ReLU(), + torch.nn.Linear(256, 1) + ) + + print(f'load_from_ckpt: {load_from_ckpt}') + # Initialization + if not load_from_ckpt: + for name, param in self.named_parameters(): + if 'conv' in name and 'weight' in name: + n = param.size(0) * param.size(2) * param.size(3) * param.size(4) + param.data.normal_().mul_(math.sqrt(2. 
/ n)) + elif 'norm' in name and 'weight' in name: + param.data.fill_(1) + elif 'norm' in name and 'bias' in name: + param.data.fill_(0) + elif ('classifier' in name or 'tgt' in name) and 'bias' in name: + param.data.fill_(0) + + # self.size = self.test_size() + + def forward(self, x, shap=True): + # print(x.shape) + features = self.features(x) + # print(features.shape) + out = F.relu(features, inplace=True) + # out = F.adaptive_avg_pool3d(out, (1, 1, 1)) + out = torch.flatten(out, 1) + + # print(out.shape) + + # out_tgt = self.tgt(out).squeeze(1) + # print(out_tgt) + # return F.softmax(out_tgt) + + tgt_iter = self.tgt.keys() + out_tgt = {k: self.tgt[k](out).squeeze(1) for k in tgt_iter} + if shap: + out_tgt = torch.stack(list(out_tgt.values())) + return out_tgt.T + else: + return out_tgt + + def test_size(self): + case = torch.ones((1, 1, 182, 218, 182)) + output = self.features(case).view(-1).size(0) + return output + + +if __name__ == "__main__": + model = DenseNet( + tgt_modalities=['NC', 'MCI', 'DE'], + growth_rate=12, + block_config=(2, 3, 2), + compression=0.5, + num_init_features=16, + drop_rate=0.2) + print(model) + torch.manual_seed(42) + x = torch.rand((1, 1, 182, 218, 182)) + # layers = list(model.features.named_children()) + features = nn.Sequential(*list(model.features.children()))(x) + print(features.shape) + print(sum(p.numel() for p in model.parameters())) + # out = mdl.net_(x, shap=False) + # print(out) + + out = model(x, shap=False) + print(out) + # layer_found = False + # features = None + # desired_layer_name = 'transition3' + + # for name, layer in layers: + # if name == desired_layer_name: + # x = layer(x) + # print(x) + # model(x) + # print(features) \ No newline at end of file diff --git a/adrd/nn/focal_loss.py b/adrd/nn/focal_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..f26cb46ece693df2486f7e629242e13b10f024de --- /dev/null +++ b/adrd/nn/focal_loss.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import sys + +class SigmoidFocalLoss(nn.Module): + ''' ... ''' + def __init__( + self, + alpha: float = -1, + gamma: float = 2.0, + reduction: str = 'mean', + ): + ''' ... ''' + super().__init__() + self.alpha = alpha + self.gamma = gamma + self.reduction = reduction + + def forward(self, input, target): + ''' ... ''' + p = torch.sigmoid(input) + ce_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none') + p_t = p * target + (1 - p) * (1 - target) + loss = ce_loss * ((1 - p_t) ** self.gamma) + + if self.alpha >= 0: + alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target) + loss = alpha_t * loss + + if self.reduction == 'mean': + loss = loss.mean() + elif self.reduction == 'sum': + loss = loss.sum() + + return loss + + +class SigmoidFocalLossBeta(nn.Module): + ''' ... ''' + def __init__( + self, + beta: float = 0.9999, + gamma: float = 2.0, + num_per_cls = (1, 1), + reduction: str = 'mean', + ): + ''' ... ''' + super().__init__() + eps = sys.float_info.epsilon + self.gamma = gamma + self.reduction = reduction + + # weights to balance loss + self.weight_neg = ((1 - beta) / (1 - beta ** num_per_cls[0] + eps)) + self.weight_pos = ((1 - beta) / (1 - beta ** num_per_cls[1] + eps)) + # weight_neg = (1 - beta) / (1 - beta ** num_per_cls[0]) + # weight_pos = (1 - beta) / (1 - beta ** num_per_cls[1]) + # self.weight_neg = weight_neg / (weight_neg + weight_pos) + # self.weight_pos = weight_pos / (weight_neg + weight_pos) + + def forward(self, input, target): + ''' ... 
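+ Class-balanced sigmoid focal loss: the per-element BCE is scaled by
+ (1 - p_t) ** gamma and then re-weighted per class with the
+ (1 - beta) / (1 - beta ** n_c) "effective number of samples" weights
+ computed in __init__ from num_per_cls.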
''' + p = torch.sigmoid(input) + p_t = p * target + (1 - p) * (1 - target) + ce_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none') + loss = ce_loss * ((1 - p_t) ** self.gamma) + + alpha_t = self.weight_pos * target + self.weight_neg * (1 - target) + loss = alpha_t * loss + + if self.reduction == 'mean': + loss = loss.mean() + elif self.reduction == 'sum': + loss = loss.sum() + + return loss + +class AsymmetricLoss(nn.Module): + def __init__(self, gamma_neg=4, gamma_pos=1, alpha=0.5, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True): + super(AsymmetricLoss, self).__init__() + self.alpha = alpha + self.gamma_neg = gamma_neg + self.gamma_pos = gamma_pos + self.clip = clip + self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss + self.eps = eps + + + def forward(self, x, y): + """" + Parameters + ---------- + x: input logits + y: targets (multi-label binarized vector) + """ + # Calculating Probabilities + x_sigmoid = torch.sigmoid(x) + xs_pos = x_sigmoid + xs_neg = 1 - x_sigmoid + # Asymmetric Clipping + if self.clip is not None and self.clip > 0: + xs_neg = (xs_neg + self.clip).clamp(max=1) + # Basic CE calculation + los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) + los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) + loss = self.alpha*los_pos + (1-self.alpha)*los_neg + # Asymmetric Focusing + if self.gamma_neg > 0 or self.gamma_pos > 0: + if self.disable_torch_grad_focal_loss: + torch.set_grad_enabled(False) + pt0 = xs_pos * y + pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p + pt = pt0 + pt1 + one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) + one_sided_w = torch.pow(1 - pt, one_sided_gamma) + if self.disable_torch_grad_focal_loss: + torch.set_grad_enabled(True) + loss *= one_sided_w + return -loss#.sum() diff --git a/adrd/nn/img_model_wrapper.py b/adrd/nn/img_model_wrapper.py new file mode 100755 index 0000000000000000000000000000000000000000..88ca17a4db852446d3149dfac157832611b2b3d1 --- /dev/null +++ b/adrd/nn/img_model_wrapper.py @@ -0,0 +1,174 @@ +import torch +from .. import nn +from .. 
import model +import numpy as np +from icecream import ic +from monai.networks.nets.swin_unetr import SwinUNETR +from typing import Any + +class ImagingModelWrapper(torch.nn.Module): + def __init__( + self, + arch: str = 'ViTAutoEnc', + tgt_modalities: dict | None = {}, + img_size: int | None = 128, + patch_size: int | None = 16, + ckpt_path: str | None = None, + train_backbone: bool = False, + out_dim: int = 128, + layers: int | None = 1, + device: str = 'cpu', + fusion_stage: str = 'middle', + ): + super(ImagingModelWrapper, self).__init__() + + self.arch = arch + self.tgt_modalities = tgt_modalities + self.img_size = img_size + self.patch_size = patch_size + self.train_backbone = train_backbone + self.ckpt_path = ckpt_path + self.device = device + self.out_dim = out_dim + self.layers = layers + self.fusion_stage = fusion_stage + + + if "swinunetr" in self.arch.lower(): + if "emb" not in self.arch.lower(): + ckpt_path = '/projectnb/ivc-ml/dlteif/pretrained_models/model_swinvit.pt' + ckpt = torch.load(ckpt_path, map_location='cpu') + self.img_model = SwinUNETR( + in_channels=1, + out_channels=1, + img_size=128, + feature_size=48, + use_checkpoint=True, + ) + ckpt["state_dict"] = {k.replace("swinViT.", "module."): v for k, v in ckpt["state_dict"].items()} + ic(ckpt["state_dict"].keys()) + self.img_model.load_from(ckpt) + self.dim = 768 + + elif "vit" in self.arch.lower(): + if "emb" not in self.arch.lower(): + # Initialize image model + self.img_model = nn.__dict__[self.arch]( + in_channels = 1, + img_size = self.img_size, + patch_size = self.patch_size, + ) + + if self.ckpt_path: + self.img_model.load(self.ckpt_path, map_location=self.device) + self.dim = self.img_model.hidden_size + else: + self.dim = 768 + + if "vit" in self.arch.lower() or "swinunetr" in self.arch.lower(): + dim = self.dim + if self.fusion_stage == 'middle': + downsample = torch.nn.ModuleList() + # print('Number of layers: ', self.layers) + for i in range(self.layers): + if i == self.layers - 1: + dim_out = self.out_dim + # print(layers) + ks = 2 + stride = 2 + else: + dim_out = dim // 2 + ks = 2 + stride = 2 + + downsample.append( + torch.nn.Conv1d(in_channels=dim, out_channels=dim_out, kernel_size=ks, stride=stride) + ) + + dim = dim_out + + downsample.append( + torch.nn.BatchNorm1d(dim) + ) + downsample.append( + torch.nn.ReLU() + ) + + + self.downsample = torch.nn.Sequential(*downsample) + elif self.fusion_stage == 'late': + self.downsample = torch.nn.Identity() + else: + pass + + # print('Downsample layers: ', self.downsample) + + elif "densenet" in self.arch.lower(): + if "emb" not in self.arch.lower(): + self.img_model = model.ImagingModel.from_ckpt(self.ckpt_path, device=self.device, img_backend=self.arch, load_from_ckpt=True).net_ + + self.downsample = torch.nn.Linear(3900, self.out_dim) + + # randomly initialize weights for downsample block + for p in self.downsample.parameters(): + if p.dim() > 1: + torch.nn.init.xavier_uniform_(p) + p.requires_grad = True + + if "emb" not in self.arch.lower(): + # freeze imaging model parameters + if "densenet" in self.arch.lower(): + for n, p in self.img_model.features.named_parameters(): + if not self.train_backbone: + p.requires_grad = False + else: + p.requires_grad = True + for n, p in self.img_model.tgt.named_parameters(): + p.requires_grad = False + else: + for n, p in self.img_model.named_parameters(): + # print(n, p.requires_grad) + if not self.train_backbone: + p.requires_grad = False + else: + p.requires_grad = True + + def forward(self, x): + # 
print("--------ImagingModelWrapper forward--------") + if "emb" not in self.arch.lower(): + if "swinunetr" in self.arch.lower(): + # print(x.size()) + out = self.img_model(x) + # print(out.size()) + out = self.downsample(out) + # print(out.size()) + out = torch.mean(out, dim=-1) + # print(out.size()) + elif "vit" in self.arch.lower(): + out = self.img_model(x, return_emb=True) + ic(out.size()) + out = self.downsample(out) + out = torch.mean(out, dim=-1) + elif "densenet" in self.arch.lower(): + out = torch.nn.Sequential(*list(self.img_model.features.children()))(x) + # print(out.size()) + out = torch.flatten(out, 1) + out = self.downsample(out) + else: + # print(x.size()) + if "swinunetr" in self.arch.lower(): + x = torch.squeeze(x, dim=1) + x = x.view(x.size(0),self.dim, -1) + # print('x: ', x.size()) + out = self.downsample(x) + # print('out: ', out.size()) + if self.fusion_stage == 'middle': + if "vit" in self.arch.lower() or "swinunetr" in self.arch.lower(): + out = torch.mean(out, dim=-1) + else: + out = torch.squeeze(out, dim=1) + elif self.fusion_stage == 'late': + pass + + return out + diff --git a/adrd/nn/net_resnet3d.py b/adrd/nn/net_resnet3d.py new file mode 100755 index 0000000000000000000000000000000000000000..ccb01bcbdb5437d9543083eb8d2c63dfd2516ab0 --- /dev/null +++ b/adrd/nn/net_resnet3d.py @@ -0,0 +1,338 @@ +""" +Created on Sat Nov 21 10:49:39 2021 + +@author: cxue2 +""" + +import torch.nn as nn + + +__all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18'] + + +class Conv3DSimple(nn.Conv3d): + def __init__(self, + in_planes, + out_planes, + midplanes=None, + stride=1, + padding=1): + + super(Conv3DSimple, self).__init__( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=(3, 3, 3), + stride=stride, + padding=padding, + bias=False) + + @staticmethod + def get_downsample_stride(stride): + return stride, stride, stride + + +class Conv2Plus1D(nn.Sequential): + + def __init__(self, + in_planes, + out_planes, + midplanes, + stride=1, + padding=1): + super(Conv2Plus1D, self).__init__( + nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3), + stride=(1, stride, stride), padding=(0, padding, padding), + bias=False), + nn.BatchNorm3d(midplanes), + nn.ReLU(inplace=True), + nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1), + stride=(stride, 1, 1), padding=(padding, 0, 0), + bias=False)) + + @staticmethod + def get_downsample_stride(stride): + return stride, stride, stride + + +class Conv3DNoTemporal(nn.Conv3d): + + def __init__(self, + in_planes, + out_planes, + midplanes=None, + stride=1, + padding=1): + + super(Conv3DNoTemporal, self).__init__( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=(1, 3, 3), + stride=(1, stride, stride), + padding=(0, padding, padding), + bias=False) + + @staticmethod + def get_downsample_stride(stride): + return 1, stride, stride + + +class BasicBlock(nn.Module): + + expansion = 1 + + def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): + midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) + + super(BasicBlock, self).__init__() + self.conv1 = nn.Sequential( + conv_builder(inplanes, planes, midplanes, stride), + nn.BatchNorm3d(planes), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + conv_builder(planes, planes, midplanes), + nn.BatchNorm3d(planes) + ) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.conv2(out) + if self.downsample is not None: + 
residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): + + super(Bottleneck, self).__init__() + midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) + + # 1x1x1 + self.conv1 = nn.Sequential( + nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), + nn.BatchNorm3d(planes), + nn.ReLU(inplace=True) + ) + # Second kernel + self.conv2 = nn.Sequential( + conv_builder(planes, planes, midplanes, stride), + nn.BatchNorm3d(planes), + nn.ReLU(inplace=True) + ) + + # 1x1x1 + self.conv3 = nn.Sequential( + nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False), + nn.BatchNorm3d(planes * self.expansion) + ) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.conv2(out) + out = self.conv3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class BasicStem(nn.Sequential): + """The default conv-batchnorm-relu stem + """ + def __init__(self): + super(BasicStem, self).__init__( + nn.Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), + padding=(3, 3, 3), bias=False), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True)) + + +class R2Plus1dStem(nn.Sequential): + """R(2+1)D stem is different than the default one as it uses separated 3D convolution + """ + def __init__(self): + super(R2Plus1dStem, self).__init__( + nn.Conv3d(3, 45, kernel_size=(1, 7, 7), + stride=(1, 2, 2), padding=(0, 3, 3), + bias=False), + nn.BatchNorm3d(45), + nn.ReLU(inplace=True), + nn.Conv3d(45, 64, kernel_size=(3, 1, 1), + stride=(1, 1, 1), padding=(1, 0, 0), + bias=False), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True)) + + +class VideoResNet(nn.Module): + + def __init__(self, block, conv_makers, layers, + stem, num_classes=16, + zero_init_residual=False): + """Generic resnet video generator. + Args: + block (nn.Module): resnet building block + conv_makers (list(functions)): generator function for each layer + layers (List[int]): number of blocks per layer + stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None. + num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. + zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. 
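+ Example (illustrative; mirrors the shape check in __main__ at the bottom of this file):
+ >>> net = r3d_18()
+ >>> x = torch.zeros(3, 1, 182, 218, 182)  # (batch, channel, D, H, W)
+ >>> net(x).shape
+ torch.Size([3, 16])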
+ """ + super(VideoResNet, self).__init__() + self.inplanes = 64 + + self.stem = stem() + + self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1) + self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, conv_makers[2], 192, layers[2], stride=2) + self.layer4 = self._make_layer(block, conv_makers[3], 256, layers[3], stride=2) + + self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) + self.fc = nn.Linear(256 * block.expansion, num_classes) + + # init weights + self._initialize_weights() + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + + def forward(self, x): + x = self.stem(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + # Flatten the layer to fc + x = x.flatten(1) + x = self.fc(x) + + return x + + def _make_layer(self, block, conv_builder, planes, blocks, stride=1): + downsample = None + + if stride != 1 or self.inplanes != planes * block.expansion: + ds_stride = conv_builder.get_downsample_stride(stride) + downsample = nn.Sequential( + nn.Conv3d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=ds_stride, bias=False), + nn.BatchNorm3d(planes * block.expansion) + ) + layers = [] + layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) + + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, conv_builder)) + + return nn.Sequential(*layers) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', + nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm3d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + +def _video_resnet(arch, pretrained=False, progress=True, **kwargs): + model = VideoResNet(**kwargs) + + return model + + +def r3d_18(pretrained=False, progress=True, **kwargs): + """Construct 18 layer Resnet3D model as in + https://arxiv.org/abs/1711.11248 + Args: + pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + progress (bool): If True, displays a progress bar of the download to stderr + Returns: + nn.Module: R3D-18 network + """ + + return _video_resnet('r3d_18', + pretrained, progress, + block=BasicBlock, + conv_makers=[Conv3DSimple] * 4, + layers=[2, 2, 2, 2], + stem=BasicStem, **kwargs) + + +def mc3_18(pretrained=False, progress=True, **kwargs): + """Constructor for 18 layer Mixed Convolution network as in + https://arxiv.org/abs/1711.11248 + Args: + pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + progress (bool): If True, displays a progress bar of the download to stderr + Returns: + nn.Module: MC3 Network definition + """ + return _video_resnet('mc3_18', + pretrained, progress, + block=BasicBlock, + conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, + layers=[2, 2, 2, 2], + stem=BasicStem, **kwargs) + + +def r2plus1d_18(pretrained=False, progress=True, **kwargs): + """Constructor for the 18 layer deep R(2+1)D network as in + https://arxiv.org/abs/1711.11248 + Args: + pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + progress (bool): If True, displays a progress bar of the download to stderr + Returns: + nn.Module: R(2+1)D-18 
network + """ + return _video_resnet('r2plus1d_18', + pretrained, progress, + block=BasicBlock, + conv_makers=[Conv2Plus1D] * 4, + layers=[2, 2, 2, 2], + stem=R2Plus1dStem, **kwargs) + + +if __name__ == '__main__': + + import torch + + net = r3d_18().to(0) + x = torch.zeros(3, 1, 182, 218, 182).to(0) + + print(net(x).shape) + print(net) \ No newline at end of file diff --git a/adrd/nn/resnet3d.py b/adrd/nn/resnet3d.py new file mode 100755 index 0000000000000000000000000000000000000000..21371da31270012a804b39df87f0cae8d588c5ba --- /dev/null +++ b/adrd/nn/resnet3d.py @@ -0,0 +1,256 @@ +""" +Simplified from torchvision.models.video.r3d_18. The citation information is +shown below. + +@article{DBLP:journals/corr/abs-1711-11248, + author = {Du Tran and + Heng Wang and + Lorenzo Torresani and + Jamie Ray and + Yann LeCun and + Manohar Paluri}, + title = {A Closer Look at Spatiotemporal Convolutions for Action Recognition}, + journal = {CoRR}, + volume = {abs/1711.11248}, + year = {2017}, + url = {http://arxiv.org/abs/1711.11248}, + archivePrefix = {arXiv}, + eprint = {1711.11248}, + timestamp = {Mon, 13 Aug 2018 16:46:39 +0200}, + biburl = {https://dblp.org/rec/journals/corr/abs-1711-11248.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +""" + +import torch.nn as nn + + +class Conv3DSimple(nn.Conv3d): + def __init__(self, + in_planes, + out_planes, + midplanes=None, + stride=1, + padding=1): + + super().__init__( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=(3, 3, 3), + stride=stride, + padding=padding, + bias=False) + + @staticmethod + def get_downsample_stride(stride): + return stride, stride, stride + + +class BasicBlock(nn.Module): + + expansion = 1 + + def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): + midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) + + super(BasicBlock, self).__init__() + self.conv1 = nn.Sequential( + conv_builder(inplanes, planes, midplanes, stride), + nn.BatchNorm3d(planes), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + conv_builder(planes, planes, midplanes), + nn.BatchNorm3d(planes) + ) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.conv2(out) + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): + + super(Bottleneck, self).__init__() + midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) + + # 1x1x1 + self.conv1 = nn.Sequential( + nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), + nn.BatchNorm3d(planes), + nn.ReLU(inplace=True) + ) + # Second kernel + self.conv2 = nn.Sequential( + conv_builder(planes, planes, midplanes, stride), + nn.BatchNorm3d(planes), + nn.ReLU(inplace=True) + ) + + # 1x1x1 + self.conv3 = nn.Sequential( + nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False), + nn.BatchNorm3d(planes * self.expansion) + ) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.conv2(out) + out = self.conv3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class 
BasicStem(nn.Sequential): + """The default conv-batchnorm-relu stem + """ + def __init__(self): + super(BasicStem, self).__init__( + nn.Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), + padding=(3, 3, 3), bias=False), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True)) + + +class VideoResNet(nn.Module): + + def __init__(self, block, conv_makers, layers, + stem, num_classes=2, + zero_init_residual=False): + """Generic resnet video generator. + Args: + block (nn.Module): resnet building block + conv_makers (list(functions)): generator function for each layer + layers (List[int]): number of blocks per layer + stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None. + num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. + zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. + """ + super(VideoResNet, self).__init__() + self.inplanes = 64 + + self.stem = stem() + + self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1) + self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, conv_makers[2], 192, layers[2], stride=2) + self.layer4 = self._make_layer(block, conv_makers[3], 256, layers[3], stride=2) + + self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) + # self.fc = nn.Linear(256 * block.expansion, num_classes) + + # init weights + self._initialize_weights() + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + + def forward(self, x): + x = self.stem(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + # Flatten the layer to fc + x = x.flatten(1) + # x = self.fc(x) + + return x + + def _make_layer(self, block, conv_builder, planes, blocks, stride=1): + downsample = None + + if stride != 1 or self.inplanes != planes * block.expansion: + ds_stride = conv_builder.get_downsample_stride(stride) + downsample = nn.Sequential( + nn.Conv3d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=ds_stride, bias=False), + nn.BatchNorm3d(planes * block.expansion) + ) + layers = [] + layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) + + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, conv_builder)) + + return nn.Sequential(*layers) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', + nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm3d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + +def _video_resnet(arch, pretrained=False, progress=True, **kwargs): + model = VideoResNet(**kwargs) + + return model + + +def r3d_18(pretrained=False, progress=True, **kwargs): + """Construct 18 layer Resnet3D model as in + https://arxiv.org/abs/1711.11248 + Args: + pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + progress (bool): If True, displays a progress bar of the download to stderr + Returns: + nn.Module: R3D-18 network + """ + + return _video_resnet('r3d_18', + pretrained, progress, + block=BasicBlock, + conv_makers=[Conv3DSimple] * 4, + layers=[2, 2, 2, 2], + stem=BasicStem, **kwargs) + + +if __name__ == 
'__main__': + """ ... """ + import torch + + net = r3d_18().to('cuda:1') + x = torch.zeros(8, 1, 182, 218, 182).to('cuda:1') + + print(net(x).shape) \ No newline at end of file diff --git a/adrd/nn/resnet_img_model.py b/adrd/nn/resnet_img_model.py new file mode 100755 index 0000000000000000000000000000000000000000..b06c57932ad8f4badae6fa7a074cb39666d00ec6 --- /dev/null +++ b/adrd/nn/resnet_img_model.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import sys +from icecream import ic +# sys.path.append('/home/skowshik/ADRD_repo/adrd_tool/adrd/') +from .net_resnet3d import r3d_18 +# from dev.data.dataset_csv import CSVDataset + + +class ResNetModel(nn.Module): + ''' ... ''' + def __init__( + self, + tgt_modalities, + mri_feature = 'img_MRI_T1', + ): + ''' ... ''' + super().__init__() + + self.mri_feature = mri_feature + + self.img_net_ = r3d_18() + + # self.modules_emb_src = nn.Sequential( + # nn.BatchNorm1d(9), + # nn.Linear(9, d_model) + # ) + + # classifiers (binary only) + self.modules_cls = nn.ModuleDict() + for k, info in tgt_modalities.items(): + if info['type'] == 'categorical' and info['num_categories'] == 2: + # categorical + self.modules_cls[k] = nn.Linear(64, 1) + + else: + # unrecognized + raise ValueError + + def forward(self, x): + ''' ... ''' + tgt_iter = self.modules_cls.keys() + + img_x_batch = x[self.mri_feature] + img_out = self.img_net_(img_x_batch) + + # ic(img_out.shape) + + # run linear classifiers + out = [self.modules_cls[k](img_out).squeeze(1) for i, k in enumerate(tgt_iter)] + out = torch.stack(out, dim=1) + + # ic(out.shape) + + # out to dict + out = {k: out[:, i] for i, k in enumerate(tgt_iter)} + + return out + + +if __name__ == '__main__': + ''' for testing purpose only ''' + # import torch + # import numpy as np + + # seed = 0 + # print('Loading training dataset ... ') + # dat_trn = CSVDataset(mode=0, split=[1, 700], seed=seed) + # print(len(dat_trn)) + # tgt_modalities = dat_trn.label_modalities + # net = ResNetModel(tgt_modalities).to('cuda') + # x = dat_trn.features + # x = {k: torch.as_tensor(np.array([x[i][k] for i in range(len(x))])).to('cuda') for k in x[0]} + # ic(x) + + + # # print(net(x).shape) + # print(net(x)) + + + diff --git a/adrd/nn/selfattention.py b/adrd/nn/selfattention.py new file mode 100755 index 0000000000000000000000000000000000000000..a2575e657687db40bace40fde4d6b9f4af72b2f7 --- /dev/null +++ b/adrd/nn/selfattention.py @@ -0,0 +1,62 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from monai.utils import optional_import +import torch +import torch.nn as nn + + +Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") + + +class SABlock(nn.Module): + """ + A self-attention block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + """ + + def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None: + """ + Args: + hidden_size: dimension of hidden layer. 
+ num_heads: number of attention heads. + dropout_rate: faction of the input units to drop. + qkv_bias: bias term for the qkv linear layer. + + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise ValueError("hidden size should be divisible by num_heads.") + + self.num_heads = num_heads + self.out_proj = nn.Linear(hidden_size, hidden_size) + self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) + self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) + self.out_rearrange = Rearrange("b h l d -> b l (h d)") + self.drop_output = nn.Dropout(dropout_rate) + self.drop_weights = nn.Dropout(dropout_rate) + self.head_dim = hidden_size // num_heads + self.scale = self.head_dim**-0.5 + + def forward(self, x): + output = self.input_rearrange(self.qkv(x)) + q, k, v = output[0], output[1], output[2] + att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) + att_mat = self.drop_weights(att_mat) + x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + x = self.out_rearrange(x) + x = self.out_proj(x) + x = self.drop_output(x) + return x, att_mat diff --git a/adrd/nn/transformer.py b/adrd/nn/transformer.py new file mode 100755 index 0000000000000000000000000000000000000000..1ee38fb24a1c8b7a7406600688ac8f08b87116af --- /dev/null +++ b/adrd/nn/transformer.py @@ -0,0 +1,268 @@ +import torch +import numpy as np +from .. import nn +# from ..nn import ImagingModelWrapper +from .net_resnet3d import r3d_18 +from typing import Any, Type +import math +Tensor = Type[torch.Tensor] +from icecream import ic +ic.disable() + +class Transformer(torch.nn.Module): + ''' ... ''' + def __init__(self, + src_modalities: dict[str, dict[str, Any]], + tgt_modalities: dict[str, dict[str, Any]], + d_model: int, + nhead: int, + num_encoder_layers: int = 1, + num_decoder_layers: int = 1, + device: str = 'cpu', + cuda_devices: list = [3], + img_net: str | None = None, + layers: int = 3, + img_size: int | None = 128, + patch_size: int | None = 16, + imgnet_ckpt: str | None = None, + train_imgnet: bool = False, + fusion_stage: str = 'middle', + ) -> None: + ''' ... 
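+        Multimodal transformer backbone for the ADRD task heads.
+
+        Argument summary (inferred from how the constructor uses them below;
+        treat this as a best-effort description rather than a formal spec):
+          src_modalities / tgt_modalities: dicts mapping feature names to spec
+            dicts with a 'type' key ('categorical', 'numerical' or 'imaging')
+            plus 'num_categories' or 'shape' as appropriate.
+          d_model, nhead: shared embedding width and number of attention heads
+            used by every encoder layer.
+          img_net, img_size, patch_size, imgnet_ckpt, train_imgnet, layers:
+            configuration forwarded to nn.ImagingModelWrapper for the imaging
+            branch.
+          fusion_stage: 'middle' routes imaging embeddings through the shared
+            encoder; 'late' keeps them out of the encoder so they can be
+            handled separately in forward().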
''' + super().__init__() + + self.d_model = d_model + self.nhead = nhead + self.num_encoder_layers = num_encoder_layers + self.num_decoder_layers = num_decoder_layers + self.img_net = img_net + self.img_size = img_size + self.patch_size = patch_size + self.imgnet_ckpt = imgnet_ckpt + self.train_imgnet = train_imgnet + self.layers = layers + self.src_modalities = src_modalities + self.tgt_modalities = tgt_modalities + self.device = device + self.fusion_stage = fusion_stage + + # embedding modules for source + + self.modules_emb_src = torch.nn.ModuleDict() + print('Downsample layers: ', self.layers) + self.img_model = nn.ImagingModelWrapper(arch=self.img_net, img_size=self.img_size, patch_size=self.patch_size, ckpt_path=self.imgnet_ckpt, train_backbone=self.train_imgnet, layers=self.layers, out_dim=self.d_model, device=self.device, fusion_stage=self.fusion_stage) + + for k, info in src_modalities.items(): + # ic(k) + # for key, val in info.items(): + # ic(key, val) + if info['type'] == 'categorical': + self.modules_emb_src[k] = torch.nn.Embedding(info['num_categories'], d_model) + elif info['type'] == 'numerical': + self.modules_emb_src[k] = torch.nn.Sequential( + torch.nn.BatchNorm1d(info['shape'][0]), + torch.nn.Linear(info['shape'][0], d_model) + ) + elif info['type'] == 'imaging': + # print(info['shape'], info['img_shape']) + if self.img_net: + self.modules_emb_src[k] = self.img_model + + else: + # unrecognized + raise ValueError('{} is an unrecognized data modality'.format(k)) + + # positional encoding + self.pe = PositionalEncoding(d_model) + + # auxiliary embedding vectors for targets + self.emb_aux = torch.nn.Parameter( + torch.zeros(len(tgt_modalities), 1, d_model), + requires_grad = True, + ) + + # transformer + enc = torch.nn.TransformerEncoderLayer( + self.d_model, self.nhead, + dim_feedforward = self.d_model, + activation = 'gelu', + dropout = 0.3, + ) + self.transformer = torch.nn.TransformerEncoder(enc, self.num_encoder_layers) + + + # classifiers (binary only) + self.modules_cls = torch.nn.ModuleDict() + for k, info in tgt_modalities.items(): + if info['type'] == 'categorical' and info['num_categories'] == 2: + self.modules_cls[k] = torch.nn.Linear(d_model, 1) + else: + # unrecognized + raise ValueError + + # for n,p in self.named_parameters(): + # print(n, p.requires_grad) + + def forward(self, + x: dict[str, Tensor], + mask: dict[str, Tensor], + # x_img: dict[str, Tensor] | Any = None, + skip_embedding: dict[str, bool] | None = None, + return_out_emb: bool = False, + ) -> dict[str, Tensor]: + """ ... """ + + out_emb = self.forward_emb(x, mask, skip_embedding) + if self.fusion_stage == "late": + out_emb = {k: v for k,v in out_emb.items() if "img_MRI" not in k} + img_out_emb = {k: v for k,v in out_emb.items() if "img_MRI" in k} + # for k,v in out_emb.items(): + # print(k, v.size()) + mask_nonimg = {k: v for k,v in mask.items() if "img_MRI" not in k} + out_trf = self.forward_trf(out_emb, mask_nonimg) # (8,128) + (8,50,128) + # print("out_trf: ", out_trf.size()) + out_trf = torch.concatenate() + else: + out_trf = self.forward_trf(out_emb, mask) + + out_cls = self.forward_cls(out_trf) + + if return_out_emb: + return out_emb, out_cls + return out_cls + + def forward_emb(self, + x: dict[str, Tensor], + mask: dict[str, Tensor], + skip_embedding: dict[str, bool] | None = None, + # x_img: dict[str, Tensor] | Any = None, + ) -> dict[str, Tensor]: + """ ... 
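+        Embed each source modality independently and return a dict of
+        per-feature embeddings. Imaging features whose mask is True for the
+        whole batch (i.e. the MRI is missing everywhere) are mapped to a zero
+        embedding instead of calling the imaging backbone; features flagged in
+        skip_embedding are passed through unchanged, on the assumption that
+        they are already precomputed embeddings.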
""" + # print("-------forward_emb--------") + out_emb = dict() + for k in self.modules_emb_src.keys(): + if skip_embedding is None or k not in skip_embedding or not skip_embedding[k]: + if "img_MRI" in k: + # print("img_MRI in ", k) + if torch.all(mask[k]): + if "swinunetr" in self.img_net.lower() and self.fusion_stage == 'late': + out_emb[k] = torch.zeros((1,768,4,4,4)) + else: + if 'cuda' in self.device: + device = x[k].device + # print(device) + else: + device = self.device + out_emb[k] = torch.zeros((mask[k].shape[0], self.d_model)).to(device, non_blocking=True) + # print("mask is True, out_emb[k]: ", out_emb[k].size()) + else: + # print("calling modules_emb_src...") + out_emb[k] = self.modules_emb_src[k](x[k]) + # print("mask is False, out_emb[k]: ", out_emb[k].size()) + + else: + out_emb[k] = self.modules_emb_src[k](x[k]) + + # out_emb[k] = self.modules_emb_src[k](x[k]) + else: + out_emb[k] = x[k] + return out_emb + + def forward_trf(self, + out_emb: dict[str, Tensor], + mask: dict[str, Tensor], + ) -> dict[str, Tensor]: + """ ... """ + # print('-----------forward_trf----------') + N = len(next(iter(out_emb.values()))) # batch size + S = len(self.modules_emb_src) # number of sources + T = len(self.modules_cls) # number of targets + if self.fusion_stage == 'late': + src_iter = [k for k in self.modules_emb_src.keys() if "img_MRI" not in k] + S = len(src_iter) # number of sources + + else: + src_iter = self.modules_emb_src.keys() + tgt_iter = self.modules_cls.keys() + + emb_src = torch.stack([o for o in out_emb.values()], dim=0) + # print('emb_src: ', emb_src.size()) + + self.pe.index = -1 + emb_src = self.pe(emb_src) + # print('emb_src + pe: ', emb_src.size()) + + # target embedding + # print('emb_aux: ', self.emb_aux.size()) + emb_tgt = self.emb_aux.repeat(1, N, 1) + # print('emb_tgt: ', emb_tgt.size()) + + # concatenate source embeddings and target embeddings + emb_all = torch.concatenate((emb_tgt, emb_src), dim=0) + + # combine masks + mask_src = [mask[k] for k in src_iter] + mask_src = torch.stack(mask_src, dim=1) + + # target masks + mask_tgt = torch.zeros((N, T), dtype=torch.bool, device=self.emb_aux.device) + + # concatenate source masks and target masks + mask_all = torch.concatenate((mask_tgt, mask_src), dim=1) + + # repeat mask_all to fit transformer + mask_all = mask_all.unsqueeze(1).expand(-1, S + T, -1).repeat(self.nhead, 1, 1) + + # run transformer + out_trf = self.transformer( + src = emb_all, + mask = mask_all, + )[0] + # print('out_trf: ', out_trf.size()) + # out_trf = {k: out_trf[i] for i, k in enumerate(tgt_iter)} + return out_trf + + def forward_cls(self, + out_trf: dict[str, Tensor], + ) -> dict[str, Tensor]: + """ ... """ + tgt_iter = self.modules_cls.keys() + out_cls = {k: self.modules_cls[k](out_trf).squeeze(1) for k in tgt_iter} + return out_cls + +class PositionalEncoding(torch.nn.Module): + + def __init__(self, + d_model: int, + max_len: int = 512 + ): + """ ... 
""" + super().__init__() + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(max_len, 1, d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + self.index = -1 + + def forward(self, x: Tensor, pe_type: str = 'non_img') -> Tensor: + """ + Arguments: + x: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` + """ + # print('pe: ', self.pe.size()) + # print('x: ', x.size()) + if pe_type == 'img': + self.index += 1 + return x + self.pe[self.index] + else: + self.index += 1 + return x + self.pe[self.index:x.size(0)+self.index] + + +if __name__ == '__main__': + ''' for testing purpose only ''' + pass + + diff --git a/adrd/nn/unet.py b/adrd/nn/unet.py new file mode 100755 index 0000000000000000000000000000000000000000..ba7efe5cb34d38610069e1a5a41e36fa7c02045c --- /dev/null +++ b/adrd/nn/unet.py @@ -0,0 +1,232 @@ +import numpy as np +import torch +import torch.nn as nn +import torchvision +from torchvision import models +from torch.nn import init +import torch.nn.functional as F +from icecream import ic + + +class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm): + def _check_input_dim(self, input): + + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)'.format(input.dim())) + #super(ContBatchNorm3d, self)._check_input_dim(input) + + def forward(self, input): + self._check_input_dim(input) + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + True, self.momentum, self.eps) + + +class LUConv(nn.Module): + def __init__(self, in_chan, out_chan, act): + super(LUConv, self).__init__() + self.conv1 = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1) + self.bn1 = ContBatchNorm3d(out_chan) + + if act == 'relu': + self.activation = nn.ReLU(out_chan) + elif act == 'prelu': + self.activation = nn.PReLU(out_chan) + elif act == 'elu': + self.activation = nn.ELU(inplace=True) + else: + raise + + def forward(self, x): + out = self.activation(self.bn1(self.conv1(x))) + return out + + +def _make_nConv(in_channel, depth, act, double_chnnel=False): + if double_chnnel: + layer1 = LUConv(in_channel, 32 * (2 ** (depth+1)),act) + layer2 = LUConv(32 * (2 ** (depth+1)), 32 * (2 ** (depth+1)),act) + else: + layer1 = LUConv(in_channel, 32*(2**depth),act) + layer2 = LUConv(32*(2**depth), 32*(2**depth)*2,act) + + return nn.Sequential(layer1,layer2) + + +class DownTransition(nn.Module): + def __init__(self, in_channel,depth, act): + super(DownTransition, self).__init__() + self.ops = _make_nConv(in_channel, depth,act) + self.maxpool = nn.MaxPool3d(2) + self.current_depth = depth + + def forward(self, x): + if self.current_depth == 3: + out = self.ops(x) + out_before_pool = out + else: + out_before_pool = self.ops(x) + out = self.maxpool(out_before_pool) + return out, out_before_pool + +class UpTransition(nn.Module): + def __init__(self, inChans, outChans, depth,act): + super(UpTransition, self).__init__() + self.depth = depth + self.up_conv = nn.ConvTranspose3d(inChans, outChans, kernel_size=2, stride=2) + self.ops = _make_nConv(inChans+ outChans//2,depth, act, double_chnnel=True) + + def forward(self, x, skip_x): + out_up_conv = self.up_conv(x) + concat = torch.cat((out_up_conv,skip_x),1) + out = self.ops(concat) + return out + +class OutputTransition(nn.Module): + def __init__(self, inChans, n_labels): + + super(OutputTransition, self).__init__() + self.final_conv = 
nn.Conv3d(inChans, n_labels, kernel_size=1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + out = self.sigmoid(self.final_conv(x)) + return out + +class ConvLayer(nn.Module): + def __init__(self, in_channels, out_channels, drop_rate, kernel, pooling, BN=True, relu_type='leaky'): + super().__init__() + kernel_size, kernel_stride, kernel_padding = kernel + pool_kernel, pool_stride, pool_padding = pooling + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, kernel_stride, kernel_padding, bias=False) + self.pooling = nn.MaxPool3d(pool_kernel, pool_stride, pool_padding) + self.BN = nn.BatchNorm3d(out_channels) + self.relu = nn.LeakyReLU(inplace=False) if relu_type=='leaky' else nn.ReLU(inplace=False) + self.dropout = nn.Dropout(drop_rate, inplace=False) + + def forward(self, x): + x = self.conv(x) + x = self.pooling(x) + x = self.BN(x) + x = self.relu(x) + x = self.dropout(x) + return x + +class AttentionModule(nn.Module): + def __init__(self, in_channels, out_channels, drop_rate=0.1): + super(AttentionModule, self).__init__() + self.conv = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=False) + self.attention = ConvLayer(in_channels, out_channels, drop_rate, (1, 1, 0), (1, 1, 0)) + + def forward(self, x, return_attention=True): + feats = self.conv(x) + att = F.softmax(self.attention(x)) + + out = feats * att + + if return_attention: + return att, out + + return out + +class UNet3D(nn.Module): + # the number of convolutions in each layer corresponds + # to what is in the actual prototxt, not the intent + def __init__(self, n_class=1, act='relu', pretrained=False, input_size=(1,1,182,218,182), attention=False, drop_rate=0.1, blocks=4): + super(UNet3D, self).__init__() + + self.blocks = blocks + self.down_tr64 = DownTransition(1,0,act) + self.down_tr128 = DownTransition(64,1,act) + self.down_tr256 = DownTransition(128,2,act) + self.down_tr512 = DownTransition(256,3,act) + + self.up_tr256 = UpTransition(512, 512,2,act) + self.up_tr128 = UpTransition(256,256, 1,act) + self.up_tr64 = UpTransition(128,128,0,act) + self.out_tr = OutputTransition(64, 1) + + self.pretrained = pretrained + self.attention = attention + if pretrained: + print("Using image pretrained model checkpoint") + weight_dir = '/home/skowshik/ADRD_repo/img_pretrained_ckpt/Genesis_Chest_CT.pt' + checkpoint = torch.load(weight_dir) + state_dict = checkpoint['state_dict'] + unParalled_state_dict = {} + for key in state_dict.keys(): + unParalled_state_dict[key.replace("module.", "")] = state_dict[key] + self.load_state_dict(unParalled_state_dict) + del self.up_tr256 + del self.up_tr128 + del self.up_tr64 + del self.out_tr + + if self.blocks == 5: + self.down_tr1024 = DownTransition(512,4,act) + + + # self.conv1 = nn.Conv3d(512, 256, 1, 1, 0, bias=False) + # self.conv2 = nn.Conv3d(256, 128, 1, 1, 0, bias=False) + # self.conv3 = nn.Conv3d(128, 64, 1, 1, 0, bias=False) + + if attention: + self.attention_module = AttentionModule(1024 if self.blocks==5 else 512, n_class, drop_rate=drop_rate) + # Output. 
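+        # NOTE: the dummy forward pass below probes the backbone's output
+        # feature-map shape: self.in_features is recorded from a randomly
+        # filled tensor of `input_size`, so downstream heads (e.g.
+        # feat_classifier in unet_img_model.py) can size their layers without
+        # hard-coding the MRI volume dimensions.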
+ self.avgpool = nn.AvgPool3d((6,7,6), stride=(6,6,6)) + + dummy_inp = torch.rand(input_size) + dummy_feats = self.forward(dummy_inp, stage='get_features') + dummy_feats = dummy_feats[0] + self.in_features = list(dummy_feats.shape) + ic(self.in_features) + + self._init_weights() + + def _init_weights(self): + if not self.pretrained: + for m in self.modules(): + if isinstance(m, nn.Conv3d): + init.kaiming_normal_(m.weight) + elif isinstance(m, ContBatchNorm3d): + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight) + init.constant_(m.bias, 0) + elif self.attention: + for m in self.attention_module.modules(): + if isinstance(m, nn.Conv3d): + init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm3d): + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + else: + pass + # Zero initialize the last batchnorm in each residual branch. + # for m in self.modules(): + # if isinstance(m, BottleneckBlock): + # init.constant_(m.out_conv.bn.weight, 0) + + def forward(self, x, stage='normal', attention=False): + ic('backbone forward') + self.out64, self.skip_out64 = self.down_tr64(x) + self.out128,self.skip_out128 = self.down_tr128(self.out64) + self.out256,self.skip_out256 = self.down_tr256(self.out128) + self.out512,self.skip_out512 = self.down_tr512(self.out256) + if self.blocks == 5: + self.out1024,self.skip_out1024 = self.down_tr1024(self.out512) + ic(self.out1024.shape) + # self.out = self.conv1(self.out512) + # self.out = self.conv2(self.out) + # self.out = self.conv3(self.out) + # self.out = self.conv(self.out) + ic(hasattr(self, 'attention_module')) + if hasattr(self, 'attention_module'): + att, feats = self.attention_module(self.out1024 if self.blocks==5 else self.out512) + else: + feats = self.out1024 if self.blocks==5 else self.out512 + ic(feats.shape) + if attention: + return att, feats + return feats \ No newline at end of file diff --git a/adrd/nn/unet_3d.py b/adrd/nn/unet_3d.py new file mode 100755 index 0000000000000000000000000000000000000000..19cdfe010f023ef50178306551d81a64b2d545b8 --- /dev/null +++ b/adrd/nn/unet_3d.py @@ -0,0 +1,63 @@ +import sys +sys.path.append('..') +# from feature_extractor.for_image_data.backbone import CNN_GAP, ResNet3D, UNet3D +import torch +import torch.nn as nn +from torchvision import models +import torch.nn.functional as F +# from . 
import UNet3D +from .unet import UNet3D +from icecream import ic + + +class UNet3DBase(nn.Module): + def __init__(self, n_class=1, act='relu', attention=False, pretrained=False, drop_rate=0.1, blocks=4): + super(UNet3DBase, self).__init__() + model = UNet3D(n_class=n_class, attention=attention, pretrained=pretrained, blocks=blocks) + + self.blocks = blocks + + self.down_tr64 = model.down_tr64 + self.down_tr128 = model.down_tr128 + self.down_tr256 = model.down_tr256 + self.down_tr512 = model.down_tr512 + if self.blocks == 5: + self.down_tr1024 = model.down_tr1024 + # self.block_modules = nn.ModuleList([self.down_tr64, self.down_tr128, self.down_tr256, self.down_tr512]) + + self.in_features = model.in_features + # ic(attention) + if attention: + self.attention_module = model.attention_module + # self.attention_module = AttentionModule(512, n_class, drop_rate=drop_rate) + # self.avgpool = nn.AvgPool3d((6,7,6), stride=(6,6,6)) + + def forward(self, x, stage='normal', attention=False): + # ic('UNet3DBase forward') + self.out64, self.skip_out64 = self.down_tr64(x) + # ic(self.out64.shape, self.skip_out64.shape) + self.out128,self.skip_out128 = self.down_tr128(self.out64) + # ic(self.out128.shape, self.skip_out128.shape) + self.out256,self.skip_out256 = self.down_tr256(self.out128) + # ic(self.out256.shape, self.skip_out256.shape) + self.out512,self.skip_out512 = self.down_tr512(self.out256) + # ic(self.out512.shape, self.skip_out512.shape) + if self.blocks == 5: + self.out1024,self.skip_out1024 = self.down_tr1024(self.out512) + # ic(self.out1024.shape, self.skip_out1024.shape) + # ic(hasattr(self, 'attention_module')) + if hasattr(self, 'attention_module'): + att, feats = self.attention_module(self.out1024 if self.blocks == 5 else self.out512) + else: + feats = self.out1024 if self.blocks == 5 else self.out512 + # ic(feats.shape) + if attention: + return att, feats + return feats + + # self.out_up_256 = self.up_tr256(self.out512,self.skip_out256) + # self.out_up_128 = self.up_tr128(self.out_up_256, self.skip_out128) + # self.out_up_64 = self.up_tr64(self.out_up_128, self.skip_out64) + # self.out = self.out_tr(self.out_up_64) + + # return self.out \ No newline at end of file diff --git a/adrd/nn/unet_img_model.py b/adrd/nn/unet_img_model.py new file mode 100755 index 0000000000000000000000000000000000000000..3dc60f7f8b8f2b683d31bda49dc87dabd7327b4b --- /dev/null +++ b/adrd/nn/unet_img_model.py @@ -0,0 +1,211 @@ +from pyexpat import features +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast +import numpy as np +import re +from icecream import ic +import math +import torch.nn.utils.weight_norm as weightNorm + +# from . 
import UNet3DBase +from .unet_3d import UNet3DBase + + +def init_weights(m): + classname = m.__class__.__name__ + if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1: + nn.init.kaiming_uniform_(m.weight) + nn.init.zeros_(m.bias) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight, 1.0, 0.02) + nn.init.zeros_(m.bias) + elif classname.find('Linear') != -1: + nn.init.xavier_normal_(m.weight) + nn.init.zeros_(m.bias) + +class feat_classifier(nn.Module): + def __init__(self, class_num, bottleneck_dim=256, type="linear"): + super(feat_classifier, self).__init__() + self.type = type + # if type in ['conv', 'gap'] and len(bottleneck_dim) > 3: + # bottleneck_dim = bottleneck_dim[-3:] + ic(bottleneck_dim) + if type == 'wn': + self.layer = weightNorm( + nn.Linear(bottleneck_dim[1:], class_num), name="weight") + # self.fc.apply(init_weights) + elif type == 'gap': + if len(bottleneck_dim) > 3: + bottleneck_dim = bottleneck_dim[-3:] + self.layer = nn.AvgPool3d(bottleneck_dim, stride=(1,1,1)) + elif type == 'conv': + if len(bottleneck_dim) > 3: + bottleneck_dim = bottleneck_dim[-4:] + ic(bottleneck_dim) + self.layer = nn.Conv3d(bottleneck_dim[0], class_num, kernel_size=bottleneck_dim[1:]) + ic(self.layer) + else: + print('bottleneck dim: ', bottleneck_dim) + self.layer = nn.Sequential( + torch.nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear(math.prod(bottleneck_dim), class_num) + ) + self.layer.apply(init_weights) + + def forward(self, x): + # print('=> feat_classifier forward') + # ic(x.size()) + x = self.layer(x) + # ic(x.size()) + if self.type in ['gap','conv']: + x = torch.squeeze(x) + if len(x.shape) < 2: + x = torch.unsqueeze(x,0) + # print('returning x: ', x.size()) + return x + +class ImageModel(nn.Module): + """ + Empirical Risk Minimization (ERM) + """ + + def __init__( + self, + counts=None, + classifier='gap', + accum_iter=8, + save_emb=False, + # ssl, + num_classes=1, + load_img_ckpt=False, + ): + super(ImageModel, self).__init__() + if counts is not None: + if isinstance(counts[0], list): + counts = np.stack(counts, axis=0).sum(axis=0) + print('counts: ', counts) + total = np.sum(counts) + print(total/counts) + self.weight = total/torch.FloatTensor(counts) + else: + total = sum(counts) + self.weight = torch.FloatTensor([total/c for c in counts]) + else: + self.weight = None + print('weight: ', self.weight) + # device = torch.device(f'cuda:{args.gpu_id}' if args.gpu_id is not None else 'cpu') + self.criterion = nn.CrossEntropyLoss(weight=self.weight) + # if ssl: + # # add contrastive loss + # # self.ssl_criterion = + # pass + + self.featurizer = UNet3DBase(n_class=num_classes, attention=True, pretrained=load_img_ckpt) + self.classifier = feat_classifier( + num_classes, self.featurizer.in_features, classifier) + + self.network = nn.Sequential( + self.featurizer, self.classifier) + self.accum_iter = accum_iter + self.acc_steps = 0 + self.save_embedding = save_emb + + def update(self, minibatches, opt, sch, scaler): + print('--------------def update----------------') + device = list(self.parameters())[0].device + all_x = torch.cat([data[1].to(device).float() for data in minibatches]) + all_y = torch.cat([data[2].to(device).long() for data in minibatches]) + print('all_x: ', all_x.size()) + # all_p = self.predict(all_x) + # all_probs = + label_list = all_y.tolist() + count = float(len(label_list)) + ic(count) + + uniques = sorted(list(set(label_list))) + ic(uniques) + counts = [float(label_list.count(i)) for i in uniques] + ic(counts) + + weights = 
[count / c for c in counts] + ic(weights) + + with autocast(): + loss = self.criterion(self.predict(all_x), all_y) + self.acc_steps += 1 + print('class: ', loss.item()) + + scaler.scale(loss / self.accum_iter).backward() + + if self.acc_steps == self.accum_iter: + scaler.step(opt) + if sch: + sch.step() + scaler.update() + self.zero_grad() + self.acc_steps = 0 + torch.cuda.empty_cache() + + del all_x + del all_y + return {'class': loss.item()}, sch + + def forward(self, *args, **kwargs): + return self.network(*args, **kwargs) + + def predict(self, x, stage='normal', attention=False): + # print('network device: ', list(self.network.parameters())[0].device) + # print('x device: ', x.device) + if stage == 'get_features' or self.save_embedding: + feats = self.network[0](x, attention=attention) + output = self.network[1](feats[-1] if attention else feats) + return feats, output + else: + return self.network(x) + + def extract_features(self, x, attention=False): + feats = self.network[0](x, attention=attention) + return feats + + def load_checkpoint(self, state_dict): + try: + self.load_checkpoint_helper(state_dict) + except: + featurizer_dict = {} + net_dict = {} + for key,val in state_dict.items(): + if 'featurizer' in key: + featurizer_dict[key] = val + elif 'network' in key: + net_dict[key] = val + self.featurizer.load_state_dict(featurizer_dict) + self.classifier.load_state_dict(net_dict) + + def load_checkpoint_helper(self, state_dict): + try: + self.load_state_dict(state_dict) + print('try: loaded') + except RuntimeError as e: + print('--> except') + if 'Missing key(s) in state_dict:' in str(e): + state_dict = { + key.replace('module.', '', 1): value + for key, value in state_dict.items() + } + state_dict = { + key.replace('featurizer.', '', 1).replace('classifier.','',1): value + for key, value in state_dict.items() + } + state_dict = { + re.sub('network.[0-9].', '', key): value + for key, value in state_dict.items() + } + try: + del state_dict['criterion.weight'] + except: + pass + self.load_state_dict(state_dict) + + print('except: loaded') \ No newline at end of file diff --git a/adrd/nn/vitautoenc.py b/adrd/nn/vitautoenc.py new file mode 100755 index 0000000000000000000000000000000000000000..2cdbef2b0943b06d21dfe14750c2260127f5336c --- /dev/null +++ b/adrd/nn/vitautoenc.py @@ -0,0 +1,163 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
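+
+# Vision Transformer autoencoder adapted from MONAI's ViTAutoEnc. The class
+# below can be used as an imaging feature extractor via
+# forward(..., return_emb=..., return_hiddens=...) and provides a load()
+# helper for pretrained checkpoints.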
+ + +from monai.networks.blocks.patchembedding import PatchEmbeddingBlock +from monai.networks.layers import Conv +from monai.utils import ensure_tuple_rep + +from typing import Sequence, Union +import torch +import torch.nn as nn + +from ..nn.blocks import TransformerBlock +from icecream import ic +ic.disable() + +__all__ = ["ViTAutoEnc"] + + +class ViTAutoEnc(nn.Module): + """ + Vision Transformer (ViT), based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + Modified to also give same dimension outputs as the input size of the image + """ + + def __init__( + self, + in_channels: int, + img_size: Union[Sequence[int], int], + patch_size: Union[Sequence[int], int], + out_channels: int = 1, + deconv_chns: int = 16, + hidden_size: int = 768, + mlp_dim: int = 3072, + num_layers: int = 12, + num_heads: int = 12, + pos_embed: str = "conv", + dropout_rate: float = 0.0, + spatial_dims: int = 3, + ) -> None: + """ + Args: + in_channels: dimension of input channels or the number of channels for input + img_size: dimension of input image. + patch_size: dimension of patch size. + hidden_size: dimension of hidden layer. + out_channels: number of output channels. + deconv_chns: number of channels for the deconvolution layers. + mlp_dim: dimension of feedforward layer. + num_layers: number of transformer blocks. + num_heads: number of attention heads. + pos_embed: position embedding layer type. + dropout_rate: faction of the input units to drop. + spatial_dims: number of spatial dimensions. + + Examples:: + + # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone + # It will provide an output of same size as that of the input + >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), pos_embed='conv') + + # for 3-channel with image size of (128,128,128), output will be same size as of input + >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), pos_embed='conv') + + """ + + super().__init__() + + self.patch_size = ensure_tuple_rep(patch_size, spatial_dims) + self.spatial_dims = spatial_dims + self.hidden_size = hidden_size + + self.patch_embedding = PatchEmbeddingBlock( + in_channels=in_channels, + img_size=img_size, + patch_size=patch_size, + hidden_size=hidden_size, + num_heads=num_heads, + pos_embed=pos_embed, + dropout_rate=dropout_rate, + spatial_dims=self.spatial_dims, + ) + self.blocks = nn.ModuleList( + [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)] + ) + self.norm = nn.LayerNorm(hidden_size) + + new_patch_size = [4] * self.spatial_dims + conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims] + # self.conv3d_transpose* is to be compatible with existing 3d model weights. + self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=new_patch_size, stride=new_patch_size) + self.conv3d_transpose_1 = conv_trans( + in_channels=deconv_chns, out_channels=out_channels, kernel_size=new_patch_size, stride=new_patch_size + ) + + def forward(self, x, return_emb=False, return_hiddens=False): + """ + Args: + x: input tensor must have isotropic spatial dimensions, + such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``. 
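+            return_emb: if True, return the patch-token embeddings right after
+                the transformer blocks and final norm (before the
+                transposed-convolution decoder).
+            return_hiddens: if True, also return the list of per-block hidden
+                states alongside the reconstruction.
+
+        By default the output is a reconstruction with the same spatial size
+        as the input.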
+ """ + spatial_size = x.shape[2:] + x = self.patch_embedding(x) + hidden_states_out = [] + for blk in self.blocks: + x = blk(x) + hidden_states_out.append(x) + x = self.norm(x) + x = x.transpose(1, 2) + if return_emb: + return x + d = [s // p for s, p in zip(spatial_size, self.patch_size)] + x = torch.reshape(x, [x.shape[0], x.shape[1], *d]) + x = self.conv3d_transpose(x) + x = self.conv3d_transpose_1(x) + if return_hiddens: + return x, hidden_states_out + return x + + def get_last_selfattention(self, x): + """ + Args: + x: input tensor must have isotropic spatial dimensions, + such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``. + """ + x = self.patch_embedding(x) + ic(x.size()) + for i, blk in enumerate(self.blocks): + if i < len(self.blocks) - 1: + x = blk(x) + x.size() + else: + return blk(x, return_attention=True) + + def load(self, ckpt_path, map_location='cpu', checkpoint_key='state_dict'): + """ + Args: + ckpt_path: path to the pretrained weights + map_location: device to load the checkpoint on + """ + state_dict = torch.load(ckpt_path, map_location=map_location) + ic(state_dict['epoch'], state_dict['train_loss']) + if checkpoint_key in state_dict: + print(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + msg = self.load_state_dict(state_dict, strict=False) + print('Pretrained weights found at {} and loaded with msg: {}'.format(ckpt_path, msg)) + + diff --git a/adrd/shap/__init__.py b/adrd/shap/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..9c44fce30dd27e85d1241415571aec0e9603d1dc --- /dev/null +++ b/adrd/shap/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseExplainer +from .mc import MCExplainer diff --git a/adrd/shap/base.py b/adrd/shap/base.py new file mode 100755 index 0000000000000000000000000000000000000000..9e9aa7a112f82fdeb31973debadb314919ffa714 --- /dev/null +++ b/adrd/shap/base.py @@ -0,0 +1,68 @@ +from abc import ABC +from abc import abstractmethod +from typing import Any, Type +from functools import wraps +from torch.utils.data import DataLoader +from torch import set_grad_enabled +import torch +Tensor = Type[torch.Tensor] + +from ..utils.misc import convert_args_kwargs_to_kwargs +from ..utils import TransformerTestingDataset +from ..model import ADRDModel + +class BaseExplainer: + """ ... """ + def __init__(self, model: ADRDModel) -> None: + """ ... """ + self.model = model + + def shap_values(self, + x, + is_embedding: dict[str, bool] | None = None, + ): + """ ... 
""" + # result placeholder + phi = [ + { + tgt_k: { + src_k: 0.0 for src_k in self.model.src_modalities + } for tgt_k in self.model.tgt_modalities + } + ] + + # set nn to eval mode + set_grad_enabled(False) + self.model.net_.eval() + + # initialize dataset and dataloader object + dat = TransformerTestingDataset(x, self.model.src_modalities, is_embedding) + ldr = DataLoader( + dataset = dat, + batch_size = 1, + shuffle = False, + drop_last = False, + num_workers = 0, + collate_fn = TransformerTestingDataset.collate_fn, + ) + + # loop through instances and compute shap values + for idx, (smp, mask) in enumerate(ldr): + mask_flat = torch.concatenate(list(mask.values())) + if torch.logical_not(mask_flat).sum().item() == 0: + pass + elif torch.logical_not(mask_flat).sum().item() == 1: + pass + else: + self._shap_values_core(smp, mask, phi[idx], is_embedding) + + return phi + + @abstractmethod + def _shap_values_core(self, + smp: dict[str, Tensor], + mask: dict[str, Tensor], + phi_: dict[str, dict[str, float]], + ): + """ To implement different algorithms. """ + pass \ No newline at end of file diff --git a/adrd/shap/mc.py b/adrd/shap/mc.py new file mode 100755 index 0000000000000000000000000000000000000000..f14794dedc5423df0ecd214ab9c348105c458875 --- /dev/null +++ b/adrd/shap/mc.py @@ -0,0 +1,86 @@ +__all__ = ['MCExplainer'] + +from . import BaseExplainer +from typing import Any, Type +from torch import set_grad_enabled +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +import random +import torch +import numpy as np +from tqdm import tqdm +Tensor = Type[torch.Tensor] + +NUM_PERMUTATIONS = 1024 +BATCH_SIZE = NUM_PERMUTATIONS + +class MCExplainer(BaseExplainer): + + def __init__(self, + model: Any, + ): + """ ... """ + super().__init__(model) + + def _shap_values_core(self, + smp: dict[str, Tensor], + mask: dict[str, Tensor], + phi_: dict[str, dict[str, float]], + is_embedding: dict[str, bool] | None = None, + ): + """ ... 
""" + # get the list of available feature names + avail = [k for k in mask if mask[k].item() == False] + + # repeat feature dict and mount to device + smps = dict() + for k, v in smp.items(): + if len(v.shape) == 1: + smps[k] = smp[k].expand(NUM_PERMUTATIONS) + elif len(v.shape) == 2: + smps[k] = smp[k].expand(NUM_PERMUTATIONS, -1) + elif len(v.shape) == 3: + smps[k] = smp[k].expand(NUM_PERMUTATIONS, -1, -1) + else: + raise ValueError + smps = {k: smps[k].to(self.model.device) for k in self.model.src_modalities} + + # loop through available features + print('{} features to evaluate ...'.format(len(avail))) + for src_k in tqdm(avail): + # get features to uncover + to_uncover = [] + for _ in range(NUM_PERMUTATIONS): + perm = avail.copy() + random.shuffle(perm) + to_uncover.append(perm[:perm.index(src_k)]) + + # construct masks without src_k + masks_wo_src_k = {k: np.ones(NUM_PERMUTATIONS, dtype=np.bool_) for k in self.model.src_modalities} + for i, lst in enumerate(to_uncover): + for k in lst: + masks_wo_src_k[k][i] = False + + # construct masks with src_k + masks_wi_src_k = masks_wo_src_k.copy() + masks_wi_src_k[src_k] = np.zeros(NUM_PERMUTATIONS, dtype=np.bool_) + + # mount inputs to device + masks_wi_src_k = {k: torch.tensor(masks_wi_src_k[k], device=self.model.device) for k in self.model.src_modalities} + masks_wo_src_k = {k: torch.tensor(masks_wo_src_k[k], device=self.model.device) for k in self.model.src_modalities} + + # run model + out_wi_src_k = self.model.net_(smps, masks_wi_src_k, is_embedding) + out_wo_src_k = self.model.net_(smps, masks_wo_src_k, is_embedding) + + # to numpy + out_wi_src_k = {k: out_wi_src_k[k].cpu().numpy() for k in self.model.tgt_modalities} + out_wo_src_k = {k: out_wo_src_k[k].cpu().numpy() for k in self.model.tgt_modalities} + + # replace nan with zeros when all input features are excluded + out_wo_src_k = {k: np.nan_to_num(out_wo_src_k[k]) for k in self.model.tgt_modalities} + + # calculate shap values + mean = {k: (out_wi_src_k[k] - out_wo_src_k[k]).mean() for k in self.model.tgt_modalities} + for tgt_k in self.model.tgt_modalities: + phi_[tgt_k][src_k] = mean[tgt_k] \ No newline at end of file diff --git a/adrd/typing.py b/adrd/typing.py new file mode 100755 index 0000000000000000000000000000000000000000..d3fd327495897106d11b649cdfd46ffeffb35574 --- /dev/null +++ b/adrd/typing.py @@ -0,0 +1,3 @@ +from typing import Any, Type +import torch + diff --git a/adrd/utils/__init__.py b/adrd/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..a343f77ab910f91035ea5570899937221555d25b --- /dev/null +++ b/adrd/utils/__init__.py @@ -0,0 +1,11 @@ +from .formatter import Formatter +from .imputer import ConstantImputer +from .imputer import FrequencyImputer +from .masker import MissingMasker +from .masker import DropoutMasker +from .masker import LabelMasker +from .transformer_dataset import TransformerTrainingDataset +from .transformer_dataset import TransformerValidationDataset +from .transformer_dataset import TransformerTestingDataset +from .transformer_dataset import TransformerBalancedTrainingDataset +from .transformer_dataset import Transformer2ndOrderBalancedTrainingDataset diff --git a/adrd/utils/__pycache__/__init__.cpython-311.pyc b/adrd/utils/__pycache__/__init__.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..4f0c778903427de8f92a976bfacbea77c147dc3d Binary files /dev/null and b/adrd/utils/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/adrd/utils/__pycache__/formatter.cpython-311.pyc b/adrd/utils/__pycache__/formatter.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..89aaf66bb997a7a19d7d4aa3f4685fae03445cc0 Binary files /dev/null and b/adrd/utils/__pycache__/formatter.cpython-311.pyc differ diff --git a/adrd/utils/__pycache__/imputer.cpython-311.pyc b/adrd/utils/__pycache__/imputer.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..300369cbe11fc5664424947edbfcb0fb555195f0 Binary files /dev/null and b/adrd/utils/__pycache__/imputer.cpython-311.pyc differ diff --git a/adrd/utils/__pycache__/masker.cpython-311.pyc b/adrd/utils/__pycache__/masker.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..086758b5383e05f7b60afe5e8373f90c38ee9da4 Binary files /dev/null and b/adrd/utils/__pycache__/masker.cpython-311.pyc differ diff --git a/adrd/utils/__pycache__/misc.cpython-311.pyc b/adrd/utils/__pycache__/misc.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..891b1cfb624c78bb8babe970860768b0722d9e5d Binary files /dev/null and b/adrd/utils/__pycache__/misc.cpython-311.pyc differ diff --git a/adrd/utils/__pycache__/transformer_dataset.cpython-311.pyc b/adrd/utils/__pycache__/transformer_dataset.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..fb5d135661f1465be51500bf2935d01d87097110 Binary files /dev/null and b/adrd/utils/__pycache__/transformer_dataset.cpython-311.pyc differ diff --git a/adrd/utils/formatter.py b/adrd/utils/formatter.py new file mode 100755 index 0000000000000000000000000000000000000000..a9e7d508b904ece6c9d7974dfd6089b5a55dced2 --- /dev/null +++ b/adrd/utils/formatter.py @@ -0,0 +1,74 @@ +from typing import Any +from numpy.typing import NDArray +import numpy as np + +class Formatter: + ''' ... ''' + def __init__(self, + modalities: dict[str, dict[str, Any]], + ) -> None: + ''' ... ''' + self.modalities = modalities + + def __call__(self, + smp: dict[str, Any], + ) -> dict[str, int | NDArray[np.float32] | None]: + ''' ... 
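+        Validate a raw sample and coerce it to canonical types: missing or
+        None entries stay None, categorical values become non-negative ints,
+        and numerical / imaging values become float32 numpy arrays whose shape
+        matches the modality spec (a scalar declared with shape [1] is
+        unsqueezed to a length-1 array). 4-D imaging entries are passed
+        through untouched; any value that fails validation raises ValueError.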
''' + new = dict() + + # loop through all data modalities + for k, info in self.modalities.items(): + # the value is missing or equals None + if k not in smp or smp[k] is None: + new[k] = None + continue + + # get value + v = smp[k] + + # if info['type'] == 'imaging': + # print(k) + # print(v.shape) + if info['type'] == 'imaging' and len(info['shape']) == 4: + new[k] = v + continue + + # validate the value by using numpy's intrinsic machanism + try: + v_np = np.array(v, dtype=np.float32) + except: + raise ValueError('\"{}\" has unexpected value {}'.format(k, v)) + + # additional validation for categorical value + if info['type'] == 'categorical': + # print(k, v_np.shape) + if v_np.shape != (): + raise ValueError('Categorical data \"{}\" has unexpected value {}.'.format(k, v)) + elif int(v) != v: + raise ValueError('Categorical data \"{}\" has unexpected value {}.'.format(k, v)) + elif v < 0: # or v >= info['num_categories']: + raise ValueError('Categorical data \"{}\" has unexpected value {}.'.format(k, v)) + + # additional validation for numerical value + elif info['type'] == 'numerical': + if info['shape'] == [1] and v_np.shape != () and v_np.shape != (1,): + raise ValueError('Numerical data \"{}\" has unexpected shape {}.'.format(k, v_np.shape)) + elif info['shape'] != [1] and tuple(info['shape']) != v_np.shape: + raise ValueError('Numerical data \"{}\" has unexpected shape {}.'.format(k, v_np.shape)) + + + + # format categorical value + if info['type'] == 'categorical': + new[k] = int(v) + + # format numerical value + elif info['type'] == 'numerical' or info['type'] == 'imaging': + if info['shape'] == [1] and v_np.shape == (): + # unsqueeze the data + new[k] = np.array([v], dtype=np.float32) + else: + new[k] = v_np + + + return new \ No newline at end of file diff --git a/adrd/utils/imputer.py b/adrd/utils/imputer.py new file mode 100755 index 0000000000000000000000000000000000000000..eb2df1c43d71c514e627af7cc98becff4e616d2d --- /dev/null +++ b/adrd/utils/imputer.py @@ -0,0 +1,103 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable +from functools import wraps +from typing import Any +from numpy.typing import NDArray +import numpy as np +import torch + +class Imputer(ABC): + ''' ... ''' + def __init__(self, + modalities: dict[str, dict[str, Any]], + is_embedding: dict[str, bool] | None = None + ) -> None: + ''' ... ''' + self.modalities = modalities + self.is_embedding = is_embedding + + @abstractmethod + def __call__(self, + smp: dict[str, int | NDArray[np.float32] | None], + ) -> dict[str, int | NDArray[np.float32]]: + ''' ... ''' + pass + + @staticmethod + def _keyerror_hint(func): + ''' Print hint for resolving KeyError. ''' + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except KeyError as err: + raise ValueError('Format the data using Formatter module.') from err + return wrapper + + +class ConstantImputer(Imputer): + ''' ... ''' + @Imputer._keyerror_hint + def __call__(self, + smp: dict[str, int | NDArray[np.float32] | None], + ) -> dict[str, int | NDArray[np.float32]]: + ''' ... 
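+        Fill missing entries with fixed placeholders: 0 for categorical
+        features, an all-zero float32 array of the declared shape for
+        numerical and imaging features, and a zero vector of length 256 for
+        features flagged in is_embedding (the embedding width is hard-coded
+        here). Observed values are returned unchanged.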
''' + new = dict() + for k, info in self.modalities.items(): + if smp[k] is not None: + new[k] = smp[k] + else: + if self.is_embedding is not None and k in self.is_embedding and self.is_embedding[k]: + new[k] = np.zeros(256, dtype=np.float32) + else: + if info['type'] == 'categorical': + new[k] = 0 + elif info['type'] == 'numerical' or info['type'] == 'imaging': + new[k] = np.zeros(tuple(info['shape']), dtype=np.float32) + else: + raise ValueError + return new + + +class FrequencyImputer(Imputer): + ''' ... ''' + @Imputer._keyerror_hint + def __init__(self, + modalities: dict[str, dict[str, Any]], + dat: list[dict[str, int | NDArray[np.float32] | None]], + ) -> None: + ''' ... ''' + super().__init__(modalities) + + # List[Dict] to Dict[List] + self.pool = {k: [smp[k] for smp in dat] for k in modalities} + + # remove None + self.pool = {k: [v for v in self.pool[k] if v is not None] for k in self.pool} + + + @Imputer._keyerror_hint + def __call__(self, + smp: dict[str, int | NDArray[np.float32] | None], + ) -> dict[str, int | NDArray[np.float32]]: + ''' ... ''' + new = dict() + for k, info in self.modalities.items(): + if smp[k] is not None: + new[k] = smp[k] + else: + # print(k) + if info['type'] == 'categorical': + new[k] = 0 + else: + if info['type'] == 'numerical': + rnd_idx = np.random.randint(0, len(self.pool[k])) + new[k] = np.array(self.pool[k][rnd_idx]) + # print(type(new[k])) + elif info['type'] == 'imaging': + new[k] = np.zeros(tuple(info['shape']), dtype=np.float32) + # print(new[k].shape) + else: + ic(info['shape']) + raise ValueError + return new \ No newline at end of file diff --git a/adrd/utils/masker.py b/adrd/utils/masker.py new file mode 100755 index 0000000000000000000000000000000000000000..a450dd8d6cbd2874c442a17ff9acd826bba02d7b --- /dev/null +++ b/adrd/utils/masker.py @@ -0,0 +1,127 @@ +from abc import ABC, abstractmethod +from functools import wraps +from typing import Any +from numpy.typing import NDArray +import numpy as np +from random import shuffle +from random import choice + + +class Masker(ABC): + ''' ... ''' + def __init__(self, + modalities: dict[str, dict[str, Any]], + ) -> None: + ''' ... ''' + self.modalities = modalities + + @abstractmethod + def __call__(self, + smp: dict[str, int | NDArray[np.float32] | None], + ) -> dict[str, bool]: + ''' ... ''' + pass + + @staticmethod + def _keyerror_hint(func): + ''' Print hint for resolving KeyError. ''' + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except KeyError as err: + raise ValueError('Format the data using Formatter module.') from err + return wrapper + + +class MissingMasker(Masker): + ''' ... ''' + @Masker._keyerror_hint + def __call__(self, + smp: dict[str, int | NDArray[np.float32] | None], + ) -> dict[str, bool]: + ''' ... ''' + return {k: smp[k] is None for k in self.modalities} + + +class DropoutMasker(Masker): + ''' ... ''' + @Masker._keyerror_hint + def __init__(self, + modalities: dict[str, dict[str, Any]], + dat: list[dict[str, int | NDArray[np.float32] | None]], + dropout_rate: float = .5, + dropout_strategy: str = 'permutation', + ) -> None: + ''' ... 
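+        Randomly mask observed features during training. Supported
+        dropout_strategy values:
+          'simple': every feature is dropped with probability dropout_rate,
+            regardless of how often it is already missing in dat;
+          'compensated': the per-feature rate is adjusted for the empirical
+            missing rate m_k as (dropout_rate - m_k) / (1 - m_k), clipped at
+            0, so the overall masked fraction is approximately dropout_rate;
+          'permutation' (the default): a random non-empty subset of the
+            observed features is kept and everything else is masked.
+        In all cases at least one feature is left unmasked per call.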
''' + super().__init__(modalities) + + # allowed strategies for dropout + assert dropout_strategy in ['simple', 'compensated', 'permutation'] + self.dropout_strategy = dropout_strategy + + # calculate missing rates + missing_rates = {k: sum([dat[i][k] is None for i in range(len(dat))]) / len(dat) for k in modalities} + + # calculate dropout rates + if dropout_strategy == 'simple': + dropout_rates = {k: dropout_rate for k in modalities} + + elif dropout_strategy == 'compensated': + dropout_rates = {k: (dropout_rate - missing_rates[k]) / (1 - missing_rates[k] + 1e-16) for k in modalities} + dropout_rates = {k: 0 if dropout_rates[k] < 0 else dropout_rates[k] for k in modalities} + + # useful attributes + if dropout_strategy != 'permutation': + self.missing_rates = missing_rates + self.dropout_rates = dropout_rates + + @Masker._keyerror_hint + def __call__(self, + smp: dict[str, int | NDArray[np.float32] | None], + ) -> dict[str, bool]: + ''' ... ''' + if self.dropout_strategy == 'permutation': + src_keys = [k for k in self.modalities if smp[k] is not None] + shuffle(src_keys) + src_keys = src_keys[:choice(range(1, len(src_keys) + 1))] + mask = {k: True for k in self.modalities} + for k in src_keys: + mask[k] = False + return mask + + else: + # get missing mask first + missing_mask = {k: smp[k] is None for k in self.modalities} + + # vectorize + missing_mask_vec = np.array(list(missing_mask.values())) + dropout_rate_vec = np.array(list(self.dropout_rates.values())) + + # generate dropout mask, at least 1 element shall be kept + while True: + dropout_mask_vec = np.random.rand(len(dropout_rate_vec)) < dropout_rate_vec + dropout_mask_vec = dropout_mask_vec | missing_mask_vec + if not np.all(dropout_mask_vec): break + + return {k: dropout_mask_vec[i] for i, k in enumerate(self.modalities.keys())} + +class LabelMasker(): + ''' ... ''' + def __init__(self, + modalities: dict[str, dict[str, Any]], + ) -> None: + ''' ... ''' + + # useful attributes + self.modalities = modalities + + def __call__(self, + smp: dict[str, int | NDArray[np.float32] | None], + ) -> dict[str, int | NDArray[np.float32]]: + ''' ... ''' + # get missing mask + label_mask = {k: 1 if smp[k] is not None else 0 for k in self.modalities} + # print(label_mask) + + return label_mask diff --git a/adrd/utils/misc.py b/adrd/utils/misc.py new file mode 100755 index 0000000000000000000000000000000000000000..c4a70542d94ba3b108bcdb35fc29fc9c75f4398c --- /dev/null +++ b/adrd/utils/misc.py @@ -0,0 +1,319 @@ +import numpy as np +import sys, tqdm +import torch +import torch.nn.functional as F +from numpy import interp +from collections.abc import Sequence +from collections import defaultdict +from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, auc +from sklearn.metrics import precision_recall_curve, average_precision_score +from sklearn.metrics import balanced_accuracy_score, precision_score +import warnings +import inspect + +_depth = lambda L: isinstance(L, (Sequence, np.ndarray)) and max(map(_depth, L)) + 1 + + +def get_metrics(y_true, y_pred, scores, mask): + ''' ... 
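+    Compute binary-classification metrics for a single task, restricted to
+    entries where mask == 1 (i.e. a ground-truth label exists). Returns a dict
+    with the confusion matrix, accuracy, balanced accuracy, precision,
+    sensitivity/recall, specificity, F1, MCC and the ROC / PR AUCs. If the
+    confusion matrix cannot be formed (e.g. only one class is present) every
+    metric falls back to -1; the two AUCs fall back to 0 when the score-based
+    computation fails.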
''' + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + masked_y_true = y_true[np.where(mask == 1)] + masked_y_pred = y_pred[np.where(mask == 1)] + masked_scores = scores[np.where(mask == 1)] + + # metrics that are based on predictions + try: + cnf = confusion_matrix(masked_y_true, masked_y_pred) + TN, FP, FN, TP = cnf.ravel() + TNR = TN / (TN + FP) + FPR = FP / (FP + TN) + FNR = FN / (FN + TP) + TPR = TP / (TP + FN) + N = TN + TP + FN + FP + S = (TP + FN) / N + P = (TP + FP) / N + acc = (TN + TP) / N + sen = TP / (TP + FN) + spc = TN / (TN + FP) + prc = TP / (TP + FP) + f1s = 2 * (prc * sen) / (prc + sen) + mcc = (TP / N - S * P) / np.sqrt(P * S * (1 - S) * (1 - P)) + + # metrics that are based on scores, + try: + auc_roc = roc_auc_score(masked_y_true, masked_scores) + except: + auc_roc = 0 + try: + auc_pr = average_precision_score(masked_y_true, masked_scores) + except: + auc_pr = 0 + + bal_acc = balanced_accuracy_score(masked_y_true, masked_y_pred) + except: + cnf, acc, bal_acc, prc, sen, spc, f1s, mcc, auc_roc, auc_pr = -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 + + # construct the dictionary of all metrics + met = {} + met['Confusion Matrix'] = cnf + met['Accuracy'] = acc + met['Balanced Accuracy'] = bal_acc + met['Precision'] = prc + met['Sensitivity/Recall'] = sen + met['Specificity'] = spc + met['F1 score'] = f1s + met['MCC'] = mcc + met['AUC (ROC)'] = auc_roc + met['AUC (PR)'] = auc_pr + + return met + + +def get_metrics_multitask(y_true, y_pred, scores, mask): + ''' ... ''' + if type(y_true) is dict: + met: dict[str, dict[str, float]] = dict() + for k in y_true.keys(): + met[k] = get_metrics(y_true[k], y_pred[k], scores[k], mask[k]) + else: + met = [] + for i in range(len(y_true[0])): + met.append(get_metrics(y_true[:, i], y_pred[:, i], scores[:, i], mask[:, i])) + return met + + +def print_metrics(met): + ''' ... ''' + for k, v in met.items(): + if k not in ['Confusion Matrix']: + print('{}:\t{:.4f}'.format(k, v).expandtabs(20)) + + +def print_metrics_multitask(met): + ''' ... 
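+    Print one row per metric with one column per task. Accepts either a dict
+    keyed by label name or a list indexed by task position, and renders NaN
+    values as '------'.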
''' + if type(met) is dict: + lbl_ks = list(met.keys()) + met_ks = met[lbl_ks[0]].keys() + for met_k in met_ks: + if met_k not in ['Confusion Matrix']: + msg = '{}:\t' + '{:.4f} ' * len(met) + val = [met[lbl_k][met_k] for lbl_k in lbl_ks] + msg = msg.format(met_k, *val) + msg = msg.replace('nan', '------') + print(msg.expandtabs(20)) + else: + for k in met[0]: + if k not in ['Confusion Matrix']: + msg = '{}:\t' + '{:.4f} ' * len(met) + val = [met[i][k] for i in range(len(met))] + msg = msg.format(k, *val) + msg = msg.replace('nan', '------') + print(msg.expandtabs(20)) + + +def pr_interp(rc_, rc, pr): + + pr_ = np.zeros_like(rc_) + locs = np.searchsorted(rc, rc_) + + for idx, loc in enumerate(locs): + l = loc - 1 + r = loc + r1 = rc[l] if l > -1 else 0 + r2 = rc[r] if r < len(rc) else 1 + p1 = pr[l] if l > -1 else 1 + p2 = pr[r] if r < len(rc) else 0 + + t1 = (1 - p2) * r2 / p2 / (r2 - r1) if p2 * (r2 - r1) > 1e-16 else (1 - p2) * r2 / 1e-16 + t2 = (1 - p1) * r1 / p1 / (r2 - r1) if p1 * (r2 - r1) > 1e-16 else (1 - p1) * r1 / 1e-16 + t3 = (1 - p1) * r1 / p1 if p1 > 1e-16 else (1 - p1) * r1 / 1e-16 + + a = 1 + t1 - t2 + b = t3 - t1 * r1 + t2 * r1 + pr_[idx] = rc_[idx] / (a * rc_[idx] + b) + + return pr_ + + +def get_roc_info(y_true_all, scores_all): + + fpr_pt = np.linspace(0, 1, 1001) + tprs, aucs = [], [] + + for i in range(len(y_true_all)): + y_true = y_true_all[i] + scores = scores_all[i] + fpr, tpr, _ = roc_curve(y_true=y_true, y_score=scores, drop_intermediate=True) + tprs.append(interp(fpr_pt, fpr, tpr)) + tprs[-1][0] = 0.0 + aucs.append(auc(fpr, tpr)) + + tprs_mean = np.mean(tprs, axis=0) + tprs_std = np.std(tprs, axis=0) + tprs_upper = np.minimum(tprs_mean + tprs_std, 1) + tprs_lower = np.maximum(tprs_mean - tprs_std, 0) + auc_mean = auc(fpr_pt, tprs_mean) + auc_std = np.std(aucs) + auc_std = 1 - auc_mean if auc_mean + auc_std > 1 else auc_std + + rslt = { + 'xs': fpr_pt, + 'ys_mean': tprs_mean, + 'ys_upper': tprs_upper, + 'ys_lower': tprs_lower, + 'auc_mean': auc_mean, + 'auc_std': auc_std + } + + return rslt + + +def get_pr_info(y_true_all, scores_all): + + rc_pt = np.linspace(0, 1, 1001) + rc_pt[0] = 1e-16 + prs = [] + aps = [] + + for i in range(len(y_true_all)): + y_true = y_true_all[i] + scores = scores_all[i] + pr, rc, _ = precision_recall_curve(y_true=y_true, probas_pred=scores) + aps.append(average_precision_score(y_true=y_true, y_score=scores)) + pr, rc = pr[::-1], rc[::-1] + prs.append(pr_interp(rc_pt, rc, pr)) + + prs_mean = np.mean(prs, axis=0) + prs_std = np.std(prs, axis=0) + prs_upper = np.minimum(prs_mean + prs_std, 1) + prs_lower = np.maximum(prs_mean - prs_std, 0) + aps_mean = np.mean(aps) + aps_std = np.std(aps) + aps_std = 1 - aps_mean if aps_mean + aps_std > 1 else aps_std + + rslt = { + 'xs': rc_pt, + 'ys_mean': prs_mean, + 'ys_upper': prs_upper, + 'ys_lower': prs_lower, + 'auc_mean': aps_mean, + 'auc_std': aps_std + } + + return rslt + + +def get_and_print_metrics(mdl, dat): + ''' ... ''' + y_pred = mdl.predict(dat.x) + y_prob = mdl.predict_proba(dat.x) + met_all = get_metrics(dat.y, y_pred, y_prob) + + for k, v in met_all.items(): + if k not in ['Confusion Matrix']: + print('{}:\t{:.4f}'.format(k, v).expandtabs(20)) + + +def get_and_print_metrics_multitask(mdl, dat): + ''' ... 
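+    Convenience wrapper: run mdl.predict / mdl.predict_proba on dat.x, compute
+    the multitask metrics against dat.y and print them. (Note that
+    get_metrics_multitask also expects a mask argument, which this wrapper
+    does not supply.)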
+    '''
+    y_pred = mdl.predict(dat.x)
+    y_prob = mdl.predict_proba(dat.x)
+    # no per-sample mask is available here, so treat every label as observed
+    if isinstance(dat.y, dict):
+        mask = {k: np.ones_like(v) for k, v in dat.y.items()}
+    else:
+        mask = np.ones_like(dat.y)
+    met = get_metrics_multitask(dat.y, y_pred, y_prob, mask)
+    print_metrics_multitask(met)
+
+
+def split_dataset(dat, ratio=.8, seed=0):
+    ''' Randomly split a dataset into training and validation subsets. '''
+    len_trn = int(np.round(len(dat) * ratio))
+    len_vld = len(dat) - len_trn
+    dat_trn, dat_vld = torch.utils.data.random_split(
+        dat, (len_trn, len_vld),
+        generator=torch.Generator().manual_seed(seed)
+    )
+    return dat_trn, dat_vld
+
+
+def l1_regularizer(model, lambda_l1=0.01):
+    ''' LASSO penalty on all weight matrices of the model. '''
+    lossl1 = 0
+    for model_param_name, model_param_value in model.named_parameters():
+        if model_param_name.endswith('weight'):
+            lossl1 += lambda_l1 * model_param_value.abs().sum()
+    return lossl1
+
+
+class ProgressBar(tqdm.tqdm):
+    ''' tqdm progress bar that shows running losses/metrics as a postfix. '''
+    def __init__(self, total, desc, file=sys.stdout):
+        super().__init__(total=total, desc=desc, ascii=True, bar_format='{l_bar}{r_bar}', file=file)
+
+    def update(self, batch_size, to_disp):
+        postfix = {}
+        for k, v in to_disp.items():
+            if k == 'cnf':
+                postfix[k] = v.__repr__().replace('\n', '')
+            else:
+                postfix[k] = '{:.6f}'.format(v.cpu().numpy())
+        self.set_postfix(postfix)
+        super().update(batch_size)
+
+
+def _get_gt_mask(logits, target):
+    target = target.reshape(-1)
+    mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
+    return mask
+
+
+def _get_other_mask(logits, target):
+    target = target.reshape(-1)
+    mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
+    return mask
+
+
+def cat_mask(t, mask1, mask2):
+    t1 = (t * mask1).sum(dim=1, keepdims=True)
+    t2 = (t * mask2).sum(dim=1, keepdims=True)
+    rt = torch.cat([t1, t2], dim=1)
+    return rt
+
+
+def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):
+    ''' Decoupled knowledge distillation loss: target-class (TCKD) plus non-target-class (NCKD) terms. '''
+    gt_mask = _get_gt_mask(logits_student, target)
+    other_mask = _get_other_mask(logits_student, target)
+    pred_student = F.softmax(logits_student / temperature, dim=1)
+    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
+    pred_student = cat_mask(pred_student, gt_mask, other_mask)
+    pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
+    log_pred_student = torch.log(pred_student)
+    tckd_loss = (
+        F.kl_div(log_pred_student, pred_teacher, size_average=False)
+        * (temperature**2)
+        / target.shape[0]
+    )
+    pred_teacher_part2 = F.softmax(
+        logits_teacher / temperature - 1000.0 * gt_mask, dim=1
+    )
+    log_pred_student_part2 = F.log_softmax(
+        logits_student / temperature - 1000.0 * gt_mask, dim=1
+    )
+    nckd_loss = (
+        F.kl_div(log_pred_student_part2, pred_teacher_part2, size_average=False)
+        * (temperature**2)
+        / target.shape[0]
+    )
+    return alpha * tckd_loss + beta * nckd_loss
+
+
+def convert_args_kwargs_to_kwargs(func, args, kwargs):
+    """ Bind the positional and keyword arguments of func into a single kwargs dict. """
+    signature = inspect.signature(func)
+    bound_args = signature.bind(*args, **kwargs)
+    bound_args.apply_defaults()
+    return bound_args.arguments
\ No newline at end of file
diff --git a/adrd/utils/transformer_dataset.py b/adrd/utils/transformer_dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..152e3b237cf98d55b7ab970d907565ef15de52ee
--- /dev/null
+++ b/adrd/utils/transformer_dataset.py
@@ -0,0 +1,385 @@
+import torch
+from torch.utils.data import Dataset
+import numpy as np
+from functools import cached_property
+from typing import Any, Type
+from numpy.typing import NDArray
+import random
+from monai.utils.type_conversion import convert_to_tensor
+import time
+
+Tensor = Type[torch.Tensor]
+
+from .masker import Masker
+from . import DropoutMasker
+from . import MissingMasker
+from . import LabelMasker
+
+from .imputer import Imputer
+from . import FrequencyImputer
+from . import ConstantImputer
+from . import Formatter
+import os
+
+
+class TransformerDataset(torch.utils.data.Dataset):
+    ''' Base dataset that formats, imputes, and masks tabular and imaging features. '''
+    def __init__(self,
+        src: list[dict[str, Any]],
+        tgt: list[dict[str, Any]] | None,
+        src_modalities: dict[str, dict[str, Any]],
+        tgt_modalities: dict[str, dict[str, Any]] | None,
+        img_transform: Any | None = None,
+        is_embedding: dict[str, bool] | None = None
+    ) -> None:
+        ''' ... '''
+        # boolean dict to indicate which features are embeddings
+        self.is_embedding = is_embedding
+
+        # keep the image transform; set it before the early return below so that
+        # label-free (testing) datasets can still load and transform images
+        self.img_transform = img_transform
+
+        # format source
+        self.fmt_src = Formatter(src_modalities)
+        self.src = [self.fmt_src(smp) for smp in src]
+        self.src_modalities = src_modalities
+
+        # format target (skipped when no labels are provided)
+        if tgt is None: return
+        self.fmt_tgt = Formatter(tgt_modalities)
+        self.tgt = [self.fmt_tgt(smp) for smp in tgt]
+        self.tgt_modalities = tgt_modalities
+
+    def __len__(self) -> int:
+        ''' ... '''
+        return len(self.src)
+
+    def img_input_trans(self, k, x):
+        ''' Load and transform one image; return None if the result is invalid. '''
+        if self.img_transform is not None:
+            try:
+                mri = self.img_transform({"image": x})["image"]
+                if torch.isnan(mri).any() or mri.size(0) != 1:
+                    return None
+                return mri
+            except Exception:
+                return None
+        else:
+            return x
+
+    def __getitem__(self,
+        idx: int
+    ) -> tuple[
+        dict[str, int | NDArray[np.float32]],
+        dict[str, int | NDArray[np.float32]],
+        dict[str, bool],
+        dict[str, int | NDArray[np.float32]],
+    ]:
+        ''' ... '''
+        # load and transform image features given as file paths; the result is
+        # cached back into self.src so each file is only read once
+        for k, v in self.src[idx].items():
+            if isinstance(v, str):
+                assert os.path.exists(v)
+                self.src[idx][k] = self.img_input_trans(k, v)
+
+        # impute x and y
+        x_imp = self.imputer_src(self.src[idx])
+        mask_x = self.masker_src(self.src[idx])
+        y_imp = self.imputer_tgt(self.tgt[idx]) if hasattr(self, 'tgt') else None
+        mask_y = self.masker_tgt(self.tgt[idx]) if hasattr(self, 'tgt') else None
+
+        # replace mmap object by the loaded one
+        for k, v in x_imp.items():
+            if isinstance(v, np.memmap):
+                x_imp[k] = np.load(v.filename)
+                x_imp[k] = np.reshape(x_imp[k], v.shape)
+
+        return x_imp, y_imp, mask_x, mask_y
+
+    @cached_property
+    def imputer_src(self) -> Imputer:
+        ''' imputer object '''
+        raise NotImplementedError
+
+    @cached_property
+    def imputer_tgt(self) -> Imputer:
+        ''' imputer object '''
+        pass
+
+    @cached_property
+    def masker_src(self) -> Masker:
+        ''' mask generator object '''
+        raise NotImplementedError
+
+    @cached_property
+    def masker_tgt(self) -> LabelMasker:
+        ''' mask generator object '''
+        pass
+
+    @staticmethod
+    def collate_fn(
+        batch: list[
+            tuple[
+                dict[str, int | NDArray[np.float32]],
+                dict[str, int | NDArray[np.float32]],
+                dict[str, bool],
+                dict[str, int | NDArray[np.float32]],
+            ]
+        ]
+    ) -> tuple[
+        dict[str, Tensor],
+        dict[str, Tensor],
+        dict[str, Tensor],
+        dict[str, Tensor],
+    ]:
+        '''
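        Collate a list of (x, y, mask_x, mask_y) samples into four dicts of batched tensors keyed by feature/target name.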
''' + # start_time = time.time() + # seperate entries + _x = [smp[0] for smp in batch] + y = [smp[1] for smp in batch] + m = [smp[2] for smp in batch] + m_y = [smp[3] for smp in batch] + + + y = [{k: v if v is not None else 0 for k, v in y[i].items()} for i in range(len(y))] + + x = {k: torch.stack([convert_to_tensor(_x[i][k]) for i in range(len(_x))]) for k in _x[0]} + y = {k: torch.as_tensor(np.array([y[i][k] for i in range(len(y))])) for k in y[0]} + m = {k: torch.as_tensor(np.array([m[i][k] for i in range(len(m))])) for k in m[0]} + m_y = {k: torch.as_tensor(np.array([m_y[i][k] for i in range(len(m_y))])) for k in m_y[0]} + + return x, y, m, m_y + + +class TransformerTrainingDataset(TransformerDataset): + ''' ... ''' + def __init__(self, + src: list[dict[str, Any]], + tgt: list[dict[str, Any]], + src_modalities: dict[str, dict[str, Any]], + tgt_modalities: dict[str, dict[str, Any]], + dropout_rate: float = .5, + dropout_strategy: str = 'permutation', + img_transform: Any | None = None, + ) -> None: + ''' ... ''' + # call the constructor of parent class + super().__init__(src, tgt, src_modalities, tgt_modalities, img_transform=img_transform) + + self.dropout_rate = dropout_rate + self.dropout_strategy = dropout_strategy + + print(img_transform) + + @cached_property + def imputer_src(self) -> FrequencyImputer: + ''' imputer object ''' + return FrequencyImputer(self.src_modalities, self.src) + + @cached_property + def imputer_tgt(self) -> ConstantImputer: + ''' imputer object ''' + return ConstantImputer(self.tgt_modalities) + + @cached_property + def masker_src(self) -> DropoutMasker: + ''' mask generator object ''' + return DropoutMasker( + self.src_modalities, self.src, + dropout_rate = self.dropout_rate, + dropout_strategy = self.dropout_strategy, + ) + + @cached_property + def masker_tgt(self) -> LabelMasker: + ''' mask generator object ''' + return LabelMasker(self.tgt_modalities) + +class TransformerValidationDataset(TransformerDataset): + def __init__(self, + src: list[dict[str, Any]], + tgt: list[dict[str, Any]], + src_modalities: dict[str, dict[str, Any]], + tgt_modalities: dict[str, dict[str, Any]], + img_transform: Any | None = None, + is_embedding: dict[str, bool] | None = None + ) -> None: + ''' ... ''' + # call the constructor of parent class + super().__init__(src, tgt, src_modalities, tgt_modalities, img_transform=img_transform, is_embedding=is_embedding) + + @cached_property + def imputer_src(self) -> ConstantImputer: + ''' imputer object ''' + return ConstantImputer(self.src_modalities, self.is_embedding) + + @cached_property + def imputer_tgt(self) -> ConstantImputer: + ''' imputer object ''' + return ConstantImputer(self.tgt_modalities) + + @cached_property + def masker_src(self) -> MissingMasker: + ''' mask generator object ''' + return MissingMasker(self.src_modalities) + + @cached_property + def masker_tgt(self) -> LabelMasker: + ''' mask generator object ''' + return LabelMasker(self.tgt_modalities) + + +class TransformerTestingDataset(TransformerValidationDataset): + + def __init__(self, + src: list[dict[str, Any]], + src_modalities: dict[str, dict[str, Any]], + img_transform: Any | None = None, + is_embedding: dict[str, bool] | None = None + ) -> None: + ''' ... ''' + # call the constructor of parent class + super().__init__(src, None, src_modalities, None, img_transform=img_transform, is_embedding=is_embedding) + + def __getitem__(self, + idx: int + ) -> tuple[ + dict[str, int | NDArray[np.float32]], + dict[str, bool], + ]: + ''' ... 
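        Same as the parent class, but only the imputed features and the missingness mask are returned; no labels are available at test time.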
''' + x_imp, _, mask_x, _ = super().__getitem__(idx) + return x_imp, mask_x + + @staticmethod + def collate_fn( + batch: list[ + tuple[ + dict[str, int | NDArray[np.float32]], + dict[str, bool], + ] + ] + ) -> tuple[ + dict[str, Tensor], + dict[str, Tensor], + ]: + ''' ... ''' + # seperate entries + x = [smp[0] for smp in batch] + m = [smp[1] for smp in batch] + + # stack and convert to tensor + x = {k: torch.as_tensor(np.array([x[i][k] for i in range(len(x))])) for k in x[0]} + m = {k: torch.as_tensor(np.array([m[i][k] for i in range(len(m))])) for k in m[0]} + + return x, m + + +class TransformerBalancedTrainingDataset(TransformerTrainingDataset): + + def __init__(self, + src: list[dict[str, Any]], + tgt: list[dict[str, Any]], + src_modalities: dict[str, dict[str, Any]], + tgt_modalities: dict[str, dict[str, Any]], + dropout_rate: float = .5, + dropout_strategy: str = 'permutation', + img_transform: Any | None = None, + ) -> None: + ''' ... ''' + # call the constructor of parent class + super().__init__( + src, tgt, src_modalities, tgt_modalities, + dropout_rate, dropout_strategy, img_transform, + ) + + # for each target/label, collect the indices of available cases + self.tgt_indices: dict[str, dict[int, list[int]]] = dict() + for tgt_k in self.tgt_modalities: + tmp = [self.tgt[i][tgt_k] for i in range(len(self.tgt))] + self.tgt_indices[tgt_k] = dict() + self.tgt_indices[tgt_k][0] = [i for i in range(len(self.tgt)) if tmp[i] == 0] + self.tgt_indices[tgt_k][1] = [i for i in range(len(self.tgt)) if tmp[i] == 1] + + def __getitem__(self, + idx: int + ) -> tuple[ + dict[str, int | NDArray[np.float32]], + dict[str, int | NDArray[np.float32]], + dict[str, bool], + dict[str, bool], + ]: + # select random target, class and index + tgt_k = random.choice(list(self.tgt_modalities.keys())) + cls = random.choice([0, 1]) + idx = random.choice(self.tgt_indices[tgt_k][cls]) + + # call __getitem__ of super class + x_imp, y_imp, mask_x, mask_y = super().__getitem__(idx) + + # modify mask_y, all targets are masked except tgt_k + mask_y = {k: mask_y[k] if k == tgt_k else 0 for k in self.tgt_modalities} + # mask_y[tgt_k] = mask_y[k] + + return x_imp, y_imp, mask_x, mask_y + + +class Transformer2ndOrderBalancedTrainingDataset(TransformerTrainingDataset): + + def __init__(self, + src: list[dict[str, Any]], + tgt: list[dict[str, Any]], + src_modalities: dict[str, dict[str, Any]], + tgt_modalities: dict[str, dict[str, Any]], + dropout_rate: float = .5, + dropout_strategy: str = 'permutation', + img_transform: Any | None = None, + ) -> None: + """ ... 
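        For every ordered pair of targets (k0, k1), record the indices of samples where k0 is negative and k1 is positive, so that __getitem__ can draw class-balanced pairs.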
""" + # call the constructor of parent class + super().__init__( + src, tgt, src_modalities, tgt_modalities, + dropout_rate, dropout_strategy, img_transform, + ) + + # construct dictionary of paired tasks + self.tasks: dict[tuple[str, str], list[int]] = {} + tgt_keys = list(self.tgt_modalities.keys()) + for tgt_k_0 in tgt_keys: + for tgt_k_1 in tgt_keys: + self.tasks[(tgt_k_0, tgt_k_1)] = [] + + for i, smp in enumerate(tgt): + for tgt_k_0 in tgt_keys: + for tgt_k_1 in tgt_keys: + if smp[tgt_k_0] == 0 and smp[tgt_k_1] == 1: + self.tasks[(tgt_k_0, tgt_k_1)].append(i) + + def __getitem__(self, + idx: int + ) -> tuple[ + dict[str, int | NDArray[np.float32]], + dict[str, int | NDArray[np.float32]], + dict[str, bool], + dict[str, bool], + ]: + # select random task + while True: + tgt_k_0 = random.choice(list(self.tgt_modalities.keys())) + tgt_k_1 = random.choice(list(self.tgt_modalities.keys())) + if len(self.tasks[(tgt_k_0, tgt_k_1)]) != 0: + idx = random.choice(self.tasks[(tgt_k_0, tgt_k_1)]) + break + + # call __getitem__ of super class + x_imp, y_imp, mask_x, mask_y = super().__getitem__(idx) + + # modify mask_y, all targets are masked except tgt_k + mask_y = {k: mask_y[k] if k in [tgt_k_0, tgt_k_1] else 0 for k in self.tgt_modalities} + + return x_imp, y_imp, mask_x, mask_y + diff --git a/app.py b/app.py index e413f59bf3895be0488a7fb3d5ec9c9c59d273d2..d72b35ccf8254d26f7ebfe7dc2033e97f044a68f 100644 --- a/app.py +++ b/app.py @@ -1,15 +1,9 @@ import streamlit as st import json -import torch -@st.cache(allow_output_mutation=True) -def load_model(): - # Load the model using torch.hub.load or torch.load, depending on how it's stored - model = torch.hub.load('vkola-lab/ADRD-V050324', 'adrd_model', source='hf') - model.eval() - return model - -mdl = load_model() +import adrd +ckpt_path = './ckpt_densenet_emb_encoder_2_AUPR.pt' +mdl = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu') # Create a form for user input with st.form("json_input_form"): diff --git a/requirements.txt b/requirements.txt old mode 100644 new mode 100755 index 12c6d5d5eac2aa9e97516ec03233adc7e98b9801..9525f7e6670a8f98febd61e11f6a1c00fab6361f --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,9 @@ +icecream +monai +numpy +scikit_learn +scipy torch +torchvision +tqdm +wandb