"""
train_brain2vec.py

Trains a 3D VAE-based Brain2Vec model using MONAI. This script implements
autoencoder training with adversarial loss (via a patch discriminator),
perceptual loss, and KL-divergence regularization for robust latent
representations.

Example usage:

    python train_brain2vec.py \
        --dataset_csv inputs.csv \
        --cache_dir ./ae_cache \
        --output_dir ./ae_output \
        --n_epochs 10
"""

import os

# Intended to opt out of torch.load's `weights_only` default; the
# `add_safe_globals` registrations below allow-list the classes found in
# MONAI's PersistentDataset cache for the `weights_only=True` path.
os.environ["PYTORCH_WEIGHTS_ONLY"] = "False"

import argparse
import warnings
from typing import Optional, Union

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.serialization
from numpy import ndarray, dtype
from numpy.core.multiarray import _reconstruct
from torch import Tensor
from torch.amp import GradScaler, autocast
from torch.nn import L1Loss
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from generative.losses import PatchAdversarialLoss, PerceptualLoss
from generative.networks.nets import AutoencoderKL, PatchDiscriminator
from monai import transforms
from monai.data import Dataset, PersistentDataset
from monai.data.meta_tensor import MetaTensor
from monai.transforms.transform import Transform
from monai.utils import set_determinism

torch.serialization.add_safe_globals([_reconstruct, MetaTensor, ndarray, dtype])


# Isotropic voxel spacing (in mm) that volumes are resampled to before training.
RESOLUTION = 2

# Reference volume shapes at 1 mm and 1.5 mm spacing (MNI152 template dimensions).
INPUT_SHAPE_1mm = (182, 218, 182)
INPUT_SHAPE_1p5mm = (122, 146, 122)

# Volume shape fed to the autoencoder after resampling and pad/crop.
INPUT_SHAPE_AE = (80, 96, 80)

# Latent shape: 1 channel, with each spatial dimension reduced 8x by the
# encoder's three downsampling stages (80/8, 96/8, 80/8).
LATENT_SHAPE_AE = (1, 10, 12, 10)


def load_if(checkpoints_path: Optional[str], network: nn.Module) -> nn.Module:
    """
    Load pretrained weights if available.

    Args:
        checkpoints_path (Optional[str]): path of the checkpoints
        network (nn.Module): the neural network to initialize

    Returns:
        nn.Module: the initialized neural network
    """
    if checkpoints_path is not None:
        assert os.path.exists(checkpoints_path), 'Invalid path'
        network.load_state_dict(torch.load(checkpoints_path))
    return network


def init_autoencoder(checkpoints_path: Optional[str] = None) -> nn.Module:
    """
    Load the KL autoencoder (pretrained if `checkpoints_path` points to previous params).

    Args:
        checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.

    Returns:
        nn.Module: the KL autoencoder
    """
    autoencoder = AutoencoderKL(spatial_dims=3,
                                in_channels=1,
                                out_channels=1,
                                latent_channels=1,
                                num_channels=(64, 128, 256, 512),
                                num_res_blocks=2,
                                norm_num_groups=32,
                                norm_eps=1e-06,
                                attention_levels=(False, False, False, False),
                                with_decoder_nonlocal_attn=False,
                                with_encoder_nonlocal_attn=False)
    return load_if(checkpoints_path, autoencoder)
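
# Shape sanity check (a sketch; `encode` is the AutoencoderKL method that
# returns the posterior mean and standard deviation):
#
#     ae = init_autoencoder()
#     x = torch.rand(1, 1, *INPUT_SHAPE_AE)            # one (80, 96, 80) volume
#     z_mu, z_sigma = ae.encode(x)
#     assert tuple(z_mu.shape[1:]) == LATENT_SHAPE_AE  # (1, 10, 12, 10)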


def init_patch_discriminator(checkpoints_path: Optional[str] = None) -> nn.Module:
    """
    Load the patch discriminator (pretrained if `checkpoints_path` points to previous params).

    Args:
        checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.

    Returns:
        nn.Module: the patch discriminator
    """
    patch_discriminator = PatchDiscriminator(spatial_dims=3,
                                             num_layers_d=3,
                                             num_channels=32,
                                             in_channels=1,
                                             out_channels=1)
    return load_if(checkpoints_path, patch_discriminator)
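
# The patch discriminator returns a list of intermediate feature maps whose
# last element holds the patch-level logits, hence the `[-1]` indexing in the
# training loop below. A minimal sketch:
#
#     d = init_patch_discriminator()
#     logits = d(torch.rand(1, 1, *INPUT_SHAPE_AE))[-1]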


class KLDivergenceLoss:
    """
    A class for computing the Kullback-Leibler divergence loss.
    """

    def __call__(self, z_mu: Tensor, z_sigma: Tensor) -> Tensor:
        """
        Computes the KL divergence loss for the given parameters.

        Args:
            z_mu (Tensor): The mean of the distribution.
            z_sigma (Tensor): The standard deviation of the distribution.

        Returns:
            Tensor: The computed KL divergence loss, averaged over the batch size.
        """
        kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3, 4])
        return torch.sum(kl_loss) / kl_loss.shape[0]
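
# Example: the KL term vanishes when the posterior matches the N(0, I) prior.
#
#     kl = KLDivergenceLoss()
#     z_mu = torch.zeros(2, *LATENT_SHAPE_AE)     # posterior mean 0
#     z_sigma = torch.ones(2, *LATENT_SHAPE_AE)   # posterior std 1
#     kl(z_mu, z_sigma)                           # tensor(0.)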


class GradientAccumulation:
    """
    Implements gradient accumulation to facilitate training with larger
    effective batch sizes than what can be physically accommodated in memory.
    """

    def __init__(self,
                 actual_batch_size: int,
                 expect_batch_size: int,
                 loader_len: int,
                 optimizer: Optimizer,
                 grad_scaler: Optional[GradScaler] = None) -> None:
        """
        Initializes the GradientAccumulation instance with the necessary parameters for
        managing gradient accumulation.

        Args:
            actual_batch_size (int): The size of the mini-batches actually used in training.
            expect_batch_size (int): The desired (effective) batch size to simulate through gradient accumulation.
            loader_len (int): The length of the data loader, representing the total number of mini-batches.
            optimizer (Optimizer): The optimizer used for performing optimization steps.
            grad_scaler (Optional[GradScaler], optional): A GradScaler for mixed precision training. Defaults to None.

        Raises:
            AssertionError: If `expect_batch_size` is not divisible by `actual_batch_size`.
        """
        assert expect_batch_size % actual_batch_size == 0, \
            'expect_batch_size must be divisible by actual_batch_size'
        self.actual_batch_size = actual_batch_size
        self.expect_batch_size = expect_batch_size
        self.loader_len = loader_len
        self.optimizer = optimizer
        self.grad_scaler = grad_scaler
        # Number of mini-batches to accumulate before each optimizer step
        # (integer division; divisibility is asserted above).
        self.steps_until_update = expect_batch_size // actual_batch_size

    def step(self, loss: Tensor, step: int) -> None:
        """
        Performs a backward pass for the given loss and potentially executes an optimization
        step if the conditions for gradient accumulation are met. The optimization step is taken
        only after a specified number of steps (defined by the expected batch size) or at the end
        of the dataset.

        Args:
            loss (Tensor): The loss value for the current forward pass.
            step (int): The current step (mini-batch index) within the epoch.
        """
        # Scale the loss down so gradients can be accumulated over several
        # backward passes without growing unbounded.
        loss = loss / self.expect_batch_size

        if self.grad_scaler is not None:
            self.grad_scaler.scale(loss).backward()
        else:
            loss.backward()
        if (step + 1) % self.steps_until_update == 0 or (step + 1) == self.loader_len:
            if self.grad_scaler is not None:
                self.grad_scaler.step(self.optimizer)
                self.grad_scaler.update()
            else:
                self.optimizer.step()
            self.optimizer.zero_grad(set_to_none=True)
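
# Usage sketch: with actual_batch_size=2 and expect_batch_size=16, gradients
# accumulate over 8 mini-batches per optimizer step (`compute_loss` is a
# hypothetical stand-in for the loss computation):
#
#     gradacc = GradientAccumulation(2, 16, len(loader), optimizer, GradScaler())
#     for step, batch in enumerate(loader):
#         loss = compute_loss(batch)
#         gradacc.step(loss, step)   # backward on each call; update every 8th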


class AverageLoss:
    """
    Utility class to track losses and metrics during training.
    """

    def __init__(self):
        self.losses_accumulator = {}

    def put(self, loss_key: str, loss_value: Union[int, float]) -> None:
        """
        Store value

        Args:
            loss_key (str): Metric name
            loss_value (int | float): Metric value to store
        """
        if loss_key not in self.losses_accumulator:
            self.losses_accumulator[loss_key] = []
        self.losses_accumulator[loss_key].append(loss_value)

    def pop_avg(self, loss_key: str) -> Optional[float]:
        """
        Average the stored values of a given metric and reset its accumulator.

        Args:
            loss_key (str): Metric name

        Returns:
            Optional[float]: average of the stored values, or None if the metric is unknown
        """
        if loss_key not in self.losses_accumulator:
            return None
        losses = self.losses_accumulator[loss_key]
        self.losses_accumulator[loss_key] = []
        return sum(losses) / len(losses)

    def to_tensorboard(self, writer: SummaryWriter, step: int):
        """
        Logs the average value of all the metrics stored into Tensorboard.

        Args:
            writer (SummaryWriter): Tensorboard writer
            step (int): Tensorboard logging global step
        """
        for metric_key in self.losses_accumulator.keys():
            writer.add_scalar(metric_key, self.pop_avg(metric_key), step)
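
# Example:
#
#     avg = AverageLoss()
#     avg.put('loss', 1.0)
#     avg.put('loss', 3.0)
#     avg.pop_avg('loss')   # 2.0 (and resets the accumulator for 'loss')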


def get_dataset_from_pd(df: pd.DataFrame, transforms_fn: Transform, cache_dir: Optional[str]) -> Union[Dataset, PersistentDataset]:
    """
    If `cache_dir` is defined, returns a `monai.data.PersistentDataset`.
    Otherwise, returns a simple `monai.data.Dataset`.

    Args:
        df (pd.DataFrame): Dataframe describing each image in the longitudinal dataset.
        transforms_fn (Transform): Set of transformations
        cache_dir (Optional[str]): Cache directory (ensure enough storage is available)

    Returns:
        Dataset|PersistentDataset: The dataset
    """
    assert cache_dir is None or os.path.exists(cache_dir), 'Invalid cache directory path'
    data = df.to_dict(orient='records')
    return Dataset(data=data, transform=transforms_fn) if cache_dir is None \
        else PersistentDataset(data=data, transform=transforms_fn, cache_dir=cache_dir)
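
# The dataframe must provide the columns consumed downstream: an 'image_path'
# per scan (loaded by the transform pipeline in `train`) and a 'split' label
# used to select training rows. A minimal CSV might look like this (paths and
# the 'valid' label are illustrative):
#
#     image_path,split
#     /data/sub-001_T1w.nii.gz,train
#     /data/sub-002_T1w.nii.gz,valid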


def tb_display_reconstruction(writer, step, image, recon):
    """
    Display reconstruction in TensorBoard during AE training.
    """
    plt.style.use('dark_background')
    _, ax = plt.subplots(ncols=3, nrows=2, figsize=(7, 5))
    for _ax in ax.flatten():
        _ax.set_axis_off()

    # Drop the channel dimension so the volumes can be sliced for display.
    if len(image.shape) == 4:
        image = image.squeeze(0)
    if len(recon.shape) == 4:
        recon = recon.squeeze(0)

    # Mid-volume slices along the three axes, original on top.
    ax[0, 0].set_title('original image', color='cyan')
    ax[0, 0].imshow(image[image.shape[0] // 2, :, :], cmap='gray')
    ax[0, 1].imshow(image[:, image.shape[1] // 2, :], cmap='gray')
    ax[0, 2].imshow(image[:, :, image.shape[2] // 2], cmap='gray')

    ax[1, 0].set_title('reconstructed image', color='magenta')
    ax[1, 0].imshow(recon[recon.shape[0] // 2, :, :], cmap='gray')
    ax[1, 1].imshow(recon[:, recon.shape[1] // 2, :], cmap='gray')
    ax[1, 2].imshow(recon[:, :, recon.shape[2] // 2], cmap='gray')

    plt.tight_layout()
    writer.add_figure('Reconstruction', plt.gcf(), global_step=step)


def set_environment(seed: int = 0) -> None:
    """
    Set deterministic behavior for reproducibility.

    Args:
        seed (int, optional): Seed value. Defaults to 0.
    """
    set_determinism(seed)


def train(
    dataset_csv: str,
    cache_dir: str,
    output_dir: str,
    aekl_ckpt: Optional[str] = None,
    disc_ckpt: Optional[str] = None,
    num_workers: int = 8,
    n_epochs: int = 5,
    max_batch_size: int = 2,
    batch_size: int = 16,
    lr: float = 1e-4,
    aug_p: float = 0.8,
    device: str = ('cuda' if torch.cuda.is_available() else 'cpu'),
) -> None:
    """
    Train the autoencoder and discriminator models.

    Args:
        dataset_csv (str): Path to the dataset CSV file.
        cache_dir (str): Directory for caching data.
        output_dir (str): Directory to save model checkpoints.
        aekl_ckpt (Optional[str], optional): Path to the autoencoder checkpoint. Defaults to None.
        disc_ckpt (Optional[str], optional): Path to the discriminator checkpoint. Defaults to None.
        num_workers (int, optional): Number of data loader workers. Defaults to 8.
        n_epochs (int, optional): Number of training epochs. Defaults to 5.
        max_batch_size (int, optional): Actual batch size per iteration. Defaults to 2.
        batch_size (int, optional): Expected (effective) batch size. Defaults to 16.
        lr (float, optional): Learning rate. Defaults to 1e-4.
        aug_p (float, optional): Augmentation probability. Currently unused: no
            augmentation transform is applied in the pipeline below. Defaults to 0.8.
        device (str, optional): Device to run the training on. Defaults to 'cuda' if available.
    """
    set_environment(0)

    # Preprocessing: load each volume, resample it to `RESOLUTION` mm isotropic
    # spacing, pad/crop it to the autoencoder input shape, and rescale
    # intensities to [0, 1].
    transforms_fn = transforms.Compose([
        transforms.CopyItemsD(keys=['image_path'], names=['image']),
        transforms.LoadImageD(image_only=True, keys=['image']),
        transforms.EnsureChannelFirstD(keys=['image']),
        transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
        transforms.ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
        transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image'])
    ])

    dataset_df = pd.read_csv(dataset_csv)
    train_df = dataset_df[dataset_df.split == 'train']
    trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)

    train_loader = DataLoader(
        dataset=trainset,
        num_workers=num_workers,
        batch_size=max_batch_size,
        shuffle=True,
        persistent_workers=True,
        pin_memory=True,
    )

    print(f'Device is {device}')
    autoencoder = init_autoencoder(aekl_ckpt).to(device)
    discriminator = init_patch_discriminator(disc_ckpt).to(device)

    # Relative weights of the adversarial, perceptual, and KL terms in the
    # generator loss (the L1 reconstruction term keeps weight 1).
    adv_weight = 0.025
    perceptual_weight = 0.001
    kl_weight = 1e-7

    l1_loss_fn = L1Loss()
    kl_loss_fn = KLDivergenceLoss()
    adv_loss_fn = PatchAdversarialLoss(criterion="least_squares")

    # Perceptual loss on 2D slices sampled from the 3D volumes
    # (`is_fake_3d=True`), using a pretrained SqueezeNet backbone.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        perc_loss_fn = PerceptualLoss(
            spatial_dims=3,
            network_type="squeeze",
            is_fake_3d=True,
            fake_3d_ratio=0.2
        ).to(device)

    optimizer_g = torch.optim.Adam(autoencoder.parameters(), lr=lr)
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=lr)

    gradacc_g = GradientAccumulation(
        actual_batch_size=max_batch_size,
        expect_batch_size=batch_size,
        loader_len=len(train_loader),
        optimizer=optimizer_g,
        grad_scaler=GradScaler()
    )

    gradacc_d = GradientAccumulation(
        actual_batch_size=max_batch_size,
        expect_batch_size=batch_size,
        loader_len=len(train_loader),
        optimizer=optimizer_d,
        grad_scaler=GradScaler()
    )

    avgloss = AverageLoss()
    writer = SummaryWriter()
    total_counter = 0

    for epoch in range(n_epochs):
        print(f"[DEBUG] Starting epoch {epoch}/{n_epochs - 1}")
        autoencoder.train()
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
        progress_bar.set_description(f'Epoch {epoch}')

        for step, batch in progress_bar:

            # --- Generator (autoencoder) pass ---
            with autocast(device, enabled=True):
                images = batch["image"].to(device)
                reconstruction, z_mu, z_sigma = autoencoder(images)

                logits_fake = discriminator(reconstruction.contiguous().float())[-1]

                rec_loss = l1_loss_fn(reconstruction.float(), images.float())
                kl_loss = kl_weight * kl_loss_fn(z_mu, z_sigma)
                per_loss = perceptual_weight * perc_loss_fn(reconstruction.float(), images.float())
                gen_loss = adv_weight * adv_loss_fn(logits_fake, target_is_real=True, for_discriminator=False)

                loss_g = rec_loss + kl_loss + per_loss + gen_loss

            gradacc_g.step(loss_g, step)

            # --- Discriminator pass (inputs are detached so only the
            # discriminator receives gradients here) ---
            with autocast(device, enabled=True):
                logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
                d_loss_fake = adv_loss_fn(logits_fake, target_is_real=False, for_discriminator=True)
                logits_real = discriminator(images.contiguous().detach())[-1]
                d_loss_real = adv_loss_fn(logits_real, target_is_real=True, for_discriminator=True)
                discriminator_loss = (d_loss_fake + d_loss_real) * 0.5
                loss_d = adv_weight * discriminator_loss

            gradacc_d.step(loss_d, step)

            # --- Logging ---
            avgloss.put('Generator/reconstruction_loss', rec_loss.item())
            avgloss.put('Generator/perceptual_loss', per_loss.item())
            avgloss.put('Generator/adversarial_loss', gen_loss.item())
            avgloss.put('Generator/kl_regularization', kl_loss.item())
            avgloss.put('Discriminator/adversarial_loss', loss_d.item())

            if total_counter % 10 == 0:
                step_log = total_counter // 10
                avgloss.to_tensorboard(writer, step_log)
                tb_display_reconstruction(
                    writer,
                    step_log,
                    images[0].detach().cpu(),
                    reconstruction[0].detach().cpu()
                )

            total_counter += 1

        # Save checkpoints at the end of every epoch.
        os.makedirs(output_dir, exist_ok=True)
        torch.save(discriminator.state_dict(), os.path.join(output_dir, f'discriminator-ep-{epoch}.pth'))
        torch.save(autoencoder.state_dict(), os.path.join(output_dir, f'autoencoder-ep-{epoch}.pth'))

    writer.close()
    print("Training completed and models saved.")


def main():
    """
    Main function to parse command-line arguments and run train().
    """
    parser = argparse.ArgumentParser(description="brain2vec Training Script")

    parser.add_argument('--dataset_csv', type=str, required=True, help='Path to the dataset CSV file.')
    parser.add_argument('--cache_dir', type=str, required=True, help='Directory for caching data.')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save model checkpoints.')
    parser.add_argument('--aekl_ckpt', type=str, default=None, help='Path to the autoencoder checkpoint.')
    parser.add_argument('--disc_ckpt', type=str, default=None, help='Path to the discriminator checkpoint.')
    parser.add_argument('--num_workers', type=int, default=8, help='Number of data loader workers.')
    parser.add_argument('--n_epochs', type=int, default=5, help='Number of training epochs.')
    parser.add_argument('--max_batch_size', type=int, default=2, help='Actual batch size per iteration.')
    parser.add_argument('--batch_size', type=int, default=16, help='Expected (effective) batch size.')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
    parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability (currently unused).')

    args = parser.parse_args()

    train(
        dataset_csv=args.dataset_csv,
        cache_dir=args.cache_dir,
        output_dir=args.output_dir,
        aekl_ckpt=args.aekl_ckpt,
        disc_ckpt=args.disc_ckpt,
        num_workers=args.num_workers,
        n_epochs=args.n_epochs,
        max_batch_size=args.max_batch_size,
        batch_size=args.batch_size,
        lr=args.lr,
        aug_p=args.aug_p,
    )


if __name__ == '__main__':
    main()