# nmed2024/adrd/model/cnn_resnet3d_with_linear_classifier.py
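""" Scikit-learn style wrapper around a 3D ResNet CNN with per-target linear classifiers. """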
__all__ = ['CNNResNet3DWithLinearClassifier']
import torch
from torch.utils.data import DataLoader
import numpy as np
import tqdm
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_is_fitted
from sklearn.model_selection import train_test_split
from scipy.special import expit
from copy import deepcopy
from contextlib import suppress
from typing import Any, Self
from functools import wraps
Tensor = torch.Tensor
Module = torch.nn.Module
from ..utils.misc import ProgressBar
from ..utils.misc import get_metrics_multitask, print_metrics_multitask
from .. import nn
from ..utils import TransformerTrainingDataset
from ..utils import Transformer2ndOrderBalancedTrainingDataset
from ..utils import TransformerValidationDataset
from ..utils import TransformerTestingDataset
from ..utils.misc import convert_args_kwargs_to_kwargs
def _manage_ctx_fit(func):
    ''' Decorator for `fit`: when `_device_ids` is set, temporarily point the
    primary device at the first listed device for DataParallel training, then
    unwrap the wrapped module and restore the original device. '''
@wraps(func)
def wrapper(*args, **kwargs):
# format arguments
kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs)
if kwargs['self']._device_ids is None:
return func(**kwargs)
else:
# change primary device
default_device = kwargs['self'].device
kwargs['self'].device = kwargs['self']._device_ids[0]
rtn = func(**kwargs)
# the actual module is wrapped
kwargs['self'].net_ = kwargs['self'].net_.module
kwargs['self'].to(default_device)
return rtn
return wrapper
class CNNResNet3DWithLinearClassifier(BaseEstimator):
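    '''
    Scikit-learn style estimator that trains a 3D ResNet CNN with one linear
    classification head per target, for multi-task binary prediction. Supports
    mixed-precision training, gradient accumulation via `batch_size_multiplier`,
    multi-GPU training through DataParallel (`_device_ids`), and selection of
    the best checkpoint according to a validation `criterion`.
    '''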
def __init__(self,
src_modalities: dict[str, dict[str, Any]],
tgt_modalities: dict[str, dict[str, Any]],
num_epochs: int = 32,
batch_size: int = 8,
batch_size_multiplier: int = 1,
lr: float = 1e-2,
weight_decay: float = 0.0,
beta: float = 0.9999,
gamma: float = 2.0,
scale: float = 1.0,
criterion: str | None = None,
device: str = 'cpu',
verbose: int = 0,
_device_ids: list | None = None,
_dataloader_num_workers: int = 0,
_amp_enabled: bool = False,
_tmp_ckpt_filepath: str | None = None,
) -> None:
        ''' Store the modality specifications and training hyperparameters. '''
# for multiprocessing
self._rank = 0
self._lock = None
# positional parameters
self.src_modalities = src_modalities
self.tgt_modalities = tgt_modalities
# training parameters
self.num_epochs = num_epochs
self.batch_size = batch_size
self.batch_size_multiplier = batch_size_multiplier
self.lr = lr
self.weight_decay = weight_decay
self.beta = beta
self.gamma = gamma
self.scale = scale
self.criterion = criterion
self.device = device
self.verbose = verbose
self._device_ids = _device_ids
self._dataloader_num_workers = _dataloader_num_workers
self._amp_enabled = _amp_enabled
self._tmp_ckpt_filepath = _tmp_ckpt_filepath
@_manage_ctx_fit
def fit(self, x, y) -> Self:
        ''' Train the network on (x, y) using an internal train/validation split
        and, if `criterion` is set, keep the best-performing checkpoint. '''
# for PyTorch computational efficiency
torch.set_num_threads(1)
# initialize neural network
self.net_ = self._init_net()
# initialize dataloaders
ldr_trn, ldr_vld = self._init_dataloader(x, y)
# initialize optimizer and scheduler
optimizer = self._init_optimizer()
scheduler = self._init_scheduler(optimizer)
# gradient scaler for AMP
if self._amp_enabled: scaler = torch.cuda.amp.GradScaler()
        # initialize loss functions (per-target sigmoid focal loss); pass the
        # number of negative and positive samples per target for class balancing
loss_func = self._init_loss_func({
k: (
sum([_[k] == 0 for _ in ldr_trn.dataset.tgt]),
sum([_[k] == 1 for _ in ldr_trn.dataset.tgt]),
) for k in self.tgt_modalities
})
# to record the best validation performance criterion
if self.criterion is not None: best_crit = None
# progress bar for epoch loops
if self.verbose == 1:
with self._lock if self._lock is not None else suppress():
pbr_epoch = tqdm.tqdm(
desc = 'Rank {:02d}'.format(self._rank),
total = self.num_epochs,
position = self._rank,
ascii = True,
leave = False,
bar_format='{l_bar}{r_bar}'
)
# training loop
for epoch in range(self.num_epochs):
# progress bar for batch loops
if self.verbose > 1:
pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch))
# set model to train mode
torch.set_grad_enabled(True)
self.net_.train()
scores_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
y_true_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
losses_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
for n_iter, (x_batch, y_batch, _, mask_y) in enumerate(ldr_trn):
# mount data to the proper device
x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities}
# mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities}
# forward
with torch.autocast(
device_type = 'cpu' if self.device == 'cpu' else 'cuda',
dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
enabled = self._amp_enabled,
):
outputs = self.net_(x_batch)
# calculate multitask loss
loss = 0
for i, tgt_k in enumerate(self.tgt_modalities):
loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k])
loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze()))
loss += loss_k.mean()
losses_trn[tgt_k] += loss_k.detach().cpu().numpy().tolist()
# backward
if self._amp_enabled:
scaler.scale(loss).backward()
else:
loss.backward()
                # update parameters every `batch_size_multiplier` iterations (gradient accumulation)
if n_iter != 0 and n_iter % self.batch_size_multiplier == 0:
if self._amp_enabled:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
else:
optimizer.step()
optimizer.zero_grad()
# save outputs to evaluate performance later
for tgt_k in self.tgt_modalities:
tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
scores_trn[tgt_k] += tmp.detach().cpu().numpy().tolist()
tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
y_true_trn[tgt_k] += tmp.cpu().numpy().tolist()
# update progress bar
if self.verbose > 1:
batch_size = len(next(iter(x_batch.values())))
pbr_batch.update(batch_size, {})
pbr_batch.refresh()
# for better tqdm progress bar display
if self.verbose > 1:
pbr_batch.close()
# set scheduler
scheduler.step()
# calculate and print training performance metrics
y_pred_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
y_prob_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
for tgt_k in self.tgt_modalities:
for i in range(len(scores_trn[tgt_k])):
y_pred_trn[tgt_k].append(1 if scores_trn[tgt_k][i] > 0 else 0)
y_prob_trn[tgt_k].append(expit(scores_trn[tgt_k][i]))
met_trn = get_metrics_multitask(y_true_trn, y_pred_trn, y_prob_trn)
# add loss to metrics
for tgt_k in self.tgt_modalities:
met_trn[tgt_k]['Loss'] = np.mean(losses_trn[tgt_k])
if self.verbose > 2:
print_metrics_multitask(met_trn)
# progress bar for validation
if self.verbose > 1:
pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch))
# set model to validation mode
torch.set_grad_enabled(False)
self.net_.eval()
scores_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
y_true_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
losses_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
for x_batch, y_batch, _, mask_y in ldr_vld:
# mount data to the proper device
x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities}
# mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities}
# forward
with torch.autocast(
device_type = 'cpu' if self.device == 'cpu' else 'cuda',
dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
enabled = self._amp_enabled
):
outputs = self.net_(x_batch)
# calculate multitask loss
for i, tgt_k in enumerate(self.tgt_modalities):
loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k])
loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze()))
losses_vld[tgt_k] += loss_k.detach().cpu().numpy().tolist()
# save outputs to evaluate performance later
for tgt_k in self.tgt_modalities:
tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
scores_vld[tgt_k] += tmp.detach().cpu().numpy().tolist()
tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
y_true_vld[tgt_k] += tmp.cpu().numpy().tolist()
# update progress bar
if self.verbose > 1:
batch_size = len(next(iter(x_batch.values())))
pbr_batch.update(batch_size, {})
pbr_batch.refresh()
# for better tqdm progress bar display
if self.verbose > 1:
pbr_batch.close()
# calculate and print validation performance metrics
y_pred_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
y_prob_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
for tgt_k in self.tgt_modalities:
for i in range(len(scores_vld[tgt_k])):
y_pred_vld[tgt_k].append(1 if scores_vld[tgt_k][i] > 0 else 0)
y_prob_vld[tgt_k].append(expit(scores_vld[tgt_k][i]))
met_vld = get_metrics_multitask(y_true_vld, y_pred_vld, y_prob_vld)
# add loss to metrics
for tgt_k in self.tgt_modalities:
met_vld[tgt_k]['Loss'] = np.mean(losses_vld[tgt_k])
if self.verbose > 2:
print_metrics_multitask(met_vld)
            # save the model if it has the best validation criterion so far
if self.criterion is None: continue
# is current criterion better than previous best?
curr_crit = np.mean([met_vld[k][self.criterion] for k in self.tgt_modalities])
if best_crit is None or np.isnan(best_crit):
is_better = True
elif self.criterion == 'Loss' and best_crit >= curr_crit:
is_better = True
elif self.criterion != 'Loss' and best_crit <= curr_crit:
is_better = True
else:
is_better = False
# update best criterion
if is_better:
best_crit = curr_crit
best_state_dict = deepcopy(self.net_.state_dict())
if self._tmp_ckpt_filepath is not None:
self.save(self._tmp_ckpt_filepath)
if self.verbose > 2:
print('Best {}: {}'.format(self.criterion, best_crit))
if self.verbose == 1:
with self._lock if self._lock is not None else suppress():
pbr_epoch.update(1)
pbr_epoch.refresh()
if self.verbose == 1:
with self._lock if self._lock is not None else suppress():
pbr_epoch.close()
        # restore the model with the best validation performance across all epochs
if ldr_vld is not None and self.criterion is not None:
self.net_.load_state_dict(best_state_dict)
return self
def predict_logits(self,
x: list[dict[str, Any]],
_batch_size: int | None = None,
) -> list[dict[str, float]]:
"""
The input x can be a single sample or a list of samples.
"""
# input validation
check_is_fitted(self)
# for PyTorch computational efficiency
torch.set_num_threads(1)
# set model to eval mode
torch.set_grad_enabled(False)
self.net_.eval()
        # initialize dataset and dataloader objects
dat = TransformerTestingDataset(x, self.src_modalities)
ldr = DataLoader(
dataset = dat,
batch_size = _batch_size if _batch_size is not None else len(x),
shuffle = False,
drop_last = False,
num_workers = 0,
collate_fn = TransformerTestingDataset.collate_fn,
)
# run model and collect results
logits: list[dict[str, float]] = []
for x_batch, _ in ldr:
# mount data to the proper device
x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
# forward
output: dict[str, Tensor] = self.net_(x_batch)
            # convert output from a dict of tensors to a list of per-sample dicts, then append
tmp = {k: output[k].tolist() for k in self.tgt_modalities}
tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))]
logits += tmp
return logits
def predict_proba(self,
x: list[dict[str, Any]],
temperature: float = 1.0,
_batch_size: int | None = None,
) -> list[dict[str, float]]:
        ''' Return per-target probabilities by applying a temperature-scaled
        sigmoid to the predicted logits. '''
logits = self.predict_logits(x, _batch_size)
return [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]
def predict(self,
x: list[dict[str, Any]],
_batch_size: int | None = None,
) -> list[dict[str, int]]:
        ''' Return binary per-target predictions by thresholding the logits at 0. '''
logits = self.predict_logits(x, _batch_size)
return [{k: int(smp[k] > 0.0) for k in self.tgt_modalities} for smp in logits]
def save(self, filepath: str) -> None:
        ''' Save the fitted network weights and modality specifications to filepath. '''
check_is_fitted(self)
state_dict = self.net_.state_dict()
# attach model hyper parameters
state_dict['src_modalities'] = self.src_modalities
state_dict['tgt_modalities'] = self.tgt_modalities
print('Saving model checkpoint to {} ... '.format(filepath), end='')
torch.save(state_dict, filepath)
print('Done.')
def load(self, filepath: str) -> None:
        ''' Load a checkpoint saved by `save`, rebuild the network, and mount it
        on the current device. '''
# load state_dict
state_dict = torch.load(filepath, map_location='cpu')
# load essential parameters
self.src_modalities: dict[str, dict[str, Any]] = state_dict.pop('src_modalities')
self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities')
# initialize model
self.net_ = nn.CNNResNet3DWithLinearClassifier(
self.src_modalities,
self.tgt_modalities,
)
# load model parameters
self.net_.load_state_dict(state_dict)
self.to(self.device)
def to(self, device: str) -> Self:
''' Mount model to the given device. '''
self.device = device
if hasattr(self, 'net_'): self.net_ = self.net_.to(device)
return self
@classmethod
def from_ckpt(cls, filepath: str) -> Self:
        ''' Construct an instance directly from a checkpoint file. '''
obj = cls(None, None)
obj.load(filepath)
return obj
def _init_net(self):
""" ... """
net = nn.CNNResNet3DWithLinearClassifier(
self.src_modalities,
self.tgt_modalities,
).to(self.device)
# train on multiple GPUs using torch.nn.DataParallel
if self._device_ids is not None:
net = torch.nn.DataParallel(net, device_ids=self._device_ids)
        # initialize model parameters using xavier_uniform (weight tensors only;
        # 1-d parameters such as biases keep their defaults)
for p in net.parameters():
if p.dim() > 1:
torch.nn.init.xavier_uniform_(p)
return net
def _init_dataloader(self, x, y):
""" ... """
# split dataset
x_trn, x_vld, y_trn, y_vld = train_test_split(
x, y, test_size = 0.2, random_state = 0,
)
# initialize dataset and dataloader
# dat_trn = TransformerTrainingDataset(
dat_trn = Transformer2ndOrderBalancedTrainingDataset(
x_trn, y_trn,
self.src_modalities,
self.tgt_modalities,
dropout_rate = .5,
# dropout_strategy = 'compensated',
dropout_strategy = 'permutation',
)
dat_vld = TransformerValidationDataset(
x_vld, y_vld,
self.src_modalities,
self.tgt_modalities,
)
ldr_trn = DataLoader(
dataset = dat_trn,
batch_size = self.batch_size,
shuffle = True,
drop_last = False,
num_workers = self._dataloader_num_workers,
collate_fn = TransformerTrainingDataset.collate_fn,
# pin_memory = True
)
ldr_vld = DataLoader(
dataset = dat_vld,
batch_size = self.batch_size,
shuffle = False,
drop_last = False,
num_workers = self._dataloader_num_workers,
collate_fn = TransformerValidationDataset.collate_fn,
# pin_memory = True
)
return ldr_trn, ldr_vld
def _init_optimizer(self):
""" ... """
return torch.optim.AdamW(
self.net_.parameters(),
lr = self.lr,
betas = (0.9, 0.98),
weight_decay = self.weight_decay
)
def _init_scheduler(self, optimizer):
""" ... """
return torch.optim.lr_scheduler.OneCycleLR(
optimizer = optimizer,
max_lr = self.lr,
total_steps = self.num_epochs,
verbose = (self.verbose > 2)
)
def _init_loss_func(self,
num_per_cls: dict[str, tuple[int, int]],
) -> dict[str, Module]:
""" ... """
return {k: nn.SigmoidFocalLoss(
beta = self.beta,
gamma = self.gamma,
scale = self.scale,
num_per_cls = num_per_cls[k],
reduction = 'none',
) for k in self.tgt_modalities}
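
# A minimal usage sketch (illustrative only). The exact schema of the modality
# dictionaries and of the per-sample dicts in `x` / `y` is defined elsewhere in
# the `adrd` package and is assumed here; names such as `x_new` are placeholders.
#
#     model = CNNResNet3DWithLinearClassifier(
#         src_modalities, tgt_modalities,
#         num_epochs = 64, batch_size = 4, lr = 1e-3,
#         criterion = 'Loss', device = 'cuda',
#     )
#     model.fit(x, y)
#     proba = model.predict_proba(x_new)
#     model.save('ckpt.pt')
#     restored = CNNResNet3DWithLinearClassifier.from_ckpt('ckpt.pt')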