Add Brain2Vec-v2 files and model card
- .DS_Store +0 -0
- README.md +119 -3
- autoencoder_final.pth +3 -0
- create_csv.py +39 -0
- discriminator_final.pth +3 -0
- inference_brain2vec.py +240 -0
- inputs_example.csv +6 -0
- requirements.txt +21 -0
- train_brain2vec.py +526 -0
.DS_Store
ADDED
Binary file (6.15 kB).
README.md
CHANGED
@@ -1,3 +1,119 @@
---
base_model: radiata-ai/brain2vec
license: apache-2.0
language:
- en
task_categories:
- image-classification
tags:
- medical
- brain-data
- mri
pretty_name: 3D Brain Structure MRI Autoencoder
---

## 🧠 Model Summary
# brain2vec
Version 2 of an autoencoder model for brain structure T1 MRIs (forked from [Brain Latent Progression](https://github.com/LemuelPuglisi/BrLP/tree/main)). The autoencoder takes a 3D MRI NIfTI file as input and compresses it to 1200 latent dimensions before reconstructing the image. The loss functions used to train the autoencoder are:
- [L1Loss](https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html)
- [KLDivergenceLoss](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html)
- [PatchAdversarialLoss](https://docs.monai.io/en/stable/losses.html#patchadversarialloss)
- [PerceptualLoss](https://docs.monai.io/en/stable/losses.html#perceptualloss)


# Training data
[Radiata brain-structure](https://huggingface.co/datasets/radiata-ai/brain-structure): 3,066 scans from 2,085 individuals in the 'train' split. Mean age = 45.1 ± 24.5 years, including 2,847 scans from cognitively normal subjects and 219 scans from individuals with an Alzheimer's disease clinical diagnosis.


# Example usage
```
# get the brain2vec model repository
git clone https://huggingface.co/radiata-ai/brain2vec
cd brain2vec

# pull pre-trained model weights
sudo apt-get update
sudo apt install git-lfs
git lfs install
git lfs pull

# set up a virtual environment
python3 -m venv venv_brain2vec
source venv_brain2vec/bin/activate

# install Python libraries
pip install -r requirements.txt

# create the csv file inputs.csv listing the scan paths and other info;
# this script loads the radiata-ai/brain-structure dataset from Hugging Face
python create_csv.py

mkdir ae_cache
mkdir ae_output

# train the model
nohup python train_brain2vec.py \
    --dataset_csv inputs.csv \
    --cache_dir ./ae_cache \
    --output_dir ./ae_output \
    --n_epochs 10 \
    > train_log.txt 2>&1 &

# model inference
# for a set of scans listed in inputs.csv
python inference_brain2vec.py \
    --checkpoint_path /path/to/model.pth \
    --csv_input inputs.csv \
    --output_dir ./ae_output \
    --embeddings_filename ae_embeddings_all.npy

# or for individual scans
python inference_brain2vec.py \
    --checkpoint_path /path/to/model.pth \
    --input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
    --output_dir ./ae_output \
    --embeddings_filename ae_embeddings_2.npy
```
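
Beyond the command-line calls above, the weights can also be used programmatically. A minimal sketch, assuming the repository root is the working directory (so `inference_brain2vec.py` is importable) and the LFS weights have been pulled; it relies on the `Brain2vec.from_pretrained` and `preprocess_mri` helpers defined in `inference_brain2vec.py` below:

```
# Minimal programmatic sketch (assumes this repo directory is on the Python path
# and autoencoder_final.pth has been fetched with `git lfs pull`).
import torch
from inference_brain2vec import Brain2vec, preprocess_mri

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Brain2vec.from_pretrained("autoencoder_final.pth", device=device)

img = preprocess_mri("/path/to/img1.nii.gz", device=device)
with torch.no_grad():
    reconstruction, z_mu, z_sigma = model(img)  # z_mu is the latent embedding
```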

# Methods
Input scan dimensions are 113x137x113 voxels at 1.5 mm^3 resolution, aligned to MNI152 space (see [radiata-ai/brain-structure](https://huggingface.co/datasets/radiata-ai/brain-structure)).

The image transform crops to 80x96x80 at 2 mm^3 resolution and scales image intensity to the range [0, 1].

The model was trained for 10 epochs with an effective batch size of 16 and a learning rate of 1e-4 (see references 1 and 2).


# References
1. Puglisi L, Alexander DC, Ravì D. Enhancing Spatiotemporal Disease Progression Models via Latent Diffusion and Prior Knowledge [Internet]. arXiv; 2024. Available from: http://arxiv.org/abs/2405.03328
2. Pinaya WHL, Tudosiu PD, Dafflon J, Costa PF da, Fernandez V, Nachev P, et al. Brain Imaging Generation with Latent Diffusion Models [Internet]. arXiv; 2022. Available from: http://arxiv.org/abs/2209.07162


# Citation
```
@misc{Radiata-Brain2vec,
  author    = {Jesse Brown and Clayton Young},
  title     = {Brain2vec: An Autoencoder Model for Brain Structure T1 MRIs},
  year      = {2025},
  url       = {https://huggingface.co/radiata-ai/brain2vec},
  note      = {Version 1.0},
  publisher = {Hugging Face}
}
```


# License
### Apache License 2.0

Copyright 2025 Jesse Brown

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at:

[http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0)

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
autoencoder_final.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f66d42dd3bd58c39a110497ea463c35e52dfed097274338a37cb2efbfc4bf11c
size 339644650
create_csv.py
ADDED
@@ -0,0 +1,39 @@
#!/usr/bin/env python3
import os
import pandas as pd
from datasets import load_dataset


def row_to_dict(row, split_name):
    return {
        "image_uid": row["id"],
        "age": int(row["metadata"]["age"]),
        "sex": 1 if row["metadata"]["sex"].lower() == "male" else 2,
        "image_path": os.path.abspath(row["nii_filepath"]),
        "split": split_name
    }


def main():
    # Load the datasets
    ds_train = load_dataset("radiata-ai/brain-structure", split="train", trust_remote_code=True)
    ds_val = load_dataset("radiata-ai/brain-structure", split="validation", trust_remote_code=True)
    ds_test = load_dataset("radiata-ai/brain-structure", split="test", trust_remote_code=True)

    rows = []

    # Process each split
    for data_row in ds_train:
        rows.append(row_to_dict(data_row, "train"))
    for data_row in ds_val:
        rows.append(row_to_dict(data_row, "validation"))
    for data_row in ds_test:
        rows.append(row_to_dict(data_row, "test"))

    # Create a DataFrame and write it to CSV
    df = pd.DataFrame(rows)
    output_csv = "inputs.csv"
    df.to_csv(output_csv, index=False)
    print(f"CSV file created: {output_csv}")


if __name__ == "__main__":
    main()
discriminator_final.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4ac993e13040de22843bb077dbb77cf0904e0aafced8429ba3ec3adfb47b3d02
size 11099084
inference_brain2vec.py
ADDED
@@ -0,0 +1,240 @@
#!/usr/bin/env python3

"""
inference_brain2vec.py

Loads a pretrained Brain2vec VAE (AutoencoderKL) model and performs inference
on one or more MRI images, generating reconstructions and latent parameters
(z_mu, z_sigma).

Example usage:

    # 1) Multiple file paths
    python inference_brain2vec.py \
        --checkpoint_path /path/to/autoencoder_checkpoint.pth \
        --input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
        --output_dir ./vae_inference_outputs \
        --embeddings_filename embeddings.npy

    # 2) Use a CSV containing image paths
    python inference_brain2vec.py \
        --checkpoint_path /path/to/autoencoder_checkpoint.pth \
        --csv_input /path/to/images.csv \
        --output_dir ./vae_inference_outputs \
        --embeddings_filename embeddings.npy
"""

import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from typing import Optional
from monai.transforms import (
    Compose,
    CopyItemsD,
    LoadImageD,
    EnsureChannelFirstD,
    SpacingD,
    ResizeWithPadOrCropD,
    ScaleIntensityD,
)
from generative.networks.nets import AutoencoderKL
import pandas as pd


RESOLUTION = 2
INPUT_SHAPE_AE = (80, 96, 80)

transforms_fn = Compose([
    CopyItemsD(keys={'image_path'}, names=['image']),
    LoadImageD(image_only=True, keys=['image']),
    EnsureChannelFirstD(keys=['image']),
    SpacingD(pixdim=RESOLUTION, keys=['image']),
    ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
    ScaleIntensityD(minv=0, maxv=1, keys=['image']),
])


def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
    """
    Preprocess an MRI using MONAI transforms to produce
    a 5D tensor (batch=1, channel=1, D, H, W) for inference.

    Args:
        image_path (str): Path to the MRI (e.g. .nii.gz).
        device (str): Device to place the tensor on.

    Returns:
        torch.Tensor: Shape (1, 1, D, H, W).
    """
    data_dict = {"image_path": image_path}
    output_dict = transforms_fn(data_dict)
    image_tensor = output_dict["image"]       # shape: (1, D, H, W)
    image_tensor = image_tensor.unsqueeze(0)  # => (1, 1, D, H, W)
    return image_tensor.to(device)


class Brain2vec(AutoencoderKL):
    """
    Subclass of MONAI's AutoencoderKL that includes:
      - a from_pretrained(...) method for loading a .pth checkpoint
      - the inherited forward(...) that returns (reconstruction, z_mu, z_sigma)

    Usage:
        >>> model = Brain2vec.from_pretrained("my_checkpoint.pth", device="cuda")
        >>> image_tensor = preprocess_mri("/path/to/mri.nii.gz", device="cuda")
        >>> reconstruction, z_mu, z_sigma = model.forward(image_tensor)
    """

    @staticmethod
    def from_pretrained(
        checkpoint_path: Optional[str] = None,
        device: str = "cpu"
    ) -> nn.Module:
        """
        Load a pretrained Brain2vec (AutoencoderKL) if a checkpoint_path is provided.
        Otherwise, return an uninitialized model.

        Args:
            checkpoint_path (Optional[str]): Path to a .pth checkpoint file.
            device (str): "cpu", "cuda", "mps", etc.

        Returns:
            nn.Module: The loaded Brain2vec model on the chosen device.
        """
        model = Brain2vec(
            spatial_dims=3,
            in_channels=1,
            out_channels=1,
            latent_channels=1,
            num_channels=(64, 128, 256, 512),
            num_res_blocks=2,
            norm_num_groups=32,
            norm_eps=1e-06,
            attention_levels=(False, False, False, False),
            with_decoder_nonlocal_attn=False,
            with_encoder_nonlocal_attn=False,
        )

        if checkpoint_path is not None:
            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
            state_dict = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(state_dict)

        model.to(device)
        model.eval()
        return model


def main() -> None:
    """
    Main function to parse command-line arguments and run inference
    with a pretrained Brain2vec model.
    """
    parser = argparse.ArgumentParser(
        description="Inference script for a Brain2vec (VAE) model."
    )
    parser.add_argument(
        "--checkpoint_path", type=str, required=True,
        help="Path to the .pth checkpoint of the pretrained Brain2vec model."
    )
    parser.add_argument(
        "--output_dir", type=str, default="./vae_inference_outputs",
        help="Directory to save reconstructions and latent parameters."
    )
    # Two ways to supply images: multiple file paths or a CSV
    parser.add_argument(
        "--input_images", type=str, nargs="*",
        help="One or more MRI file paths (e.g. .nii.gz)."
    )
    parser.add_argument(
        "--csv_input", type=str,
        help="Path to a CSV file with an 'image_path' column."
    )
    parser.add_argument(
        "--embeddings_filename",
        type=str,
        required=True,
        help="Filename (in output_dir) to save the stacked z_mu embeddings (e.g. 'all_z_mu.npy')."
    )
    parser.add_argument(
        "--save_recons",
        action="store_true",
        help="If set, saves each reconstruction as .npy. Default is not to save."
    )

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Select the device and load the pretrained model onto it
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Brain2vec.from_pretrained(
        checkpoint_path=args.checkpoint_path,
        device=device
    )

    # Gather image paths
    if args.csv_input:
        df = pd.read_csv(args.csv_input)
        if "image_path" not in df.columns:
            raise ValueError("CSV must contain a column named 'image_path'.")
        image_paths = df["image_path"].tolist()
    else:
        if not args.input_images:
            raise ValueError("Must provide either --csv_input or --input_images.")
        image_paths = args.input_images

    # Lists for stacking latent parameters later
    all_z_mu = []
    all_z_sigma = []

    # Inference on each image
    for i, img_path in enumerate(image_paths):
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        print(f"[INFO] Processing image {i}: {img_path}")
        img_tensor = preprocess_mri(img_path, device=device)

        with torch.no_grad():
            recon, z_mu, z_sigma = model.forward(img_tensor)

        # Convert to NumPy
        recon_np = recon.detach().cpu().numpy()      # shape: (1, 1, D, H, W)
        z_mu_np = z_mu.detach().cpu().numpy()        # shape: (1, latent_channels, ...)
        z_sigma_np = z_sigma.detach().cpu().numpy()

        # Save each reconstruction (per image) as .npy
        if args.save_recons:
            recon_path = os.path.join(args.output_dir, f"reconstruction_{i}.npy")
            np.save(recon_path, recon_np)
            print(f"[INFO] Saved reconstruction to {recon_path}")

        # Store latent parameters for combined saving
        all_z_mu.append(z_mu_np)
        all_z_sigma.append(z_sigma_np)

    # Combine latent parameters from all images and save
    stacked_mu = np.concatenate(all_z_mu, axis=0)        # e.g., shape (N, latent_channels, ...)
    stacked_sigma = np.concatenate(all_z_sigma, axis=0)  # e.g., shape (N, latent_channels, ...)

    mu_filename = args.embeddings_filename
    if not mu_filename.lower().endswith(".npy"):
        mu_filename += ".npy"

    mu_path = os.path.join(args.output_dir, mu_filename)
    sigma_path = os.path.join(args.output_dir, "all_z_sigma.npy")

    np.save(mu_path, stacked_mu)
    np.save(sigma_path, stacked_sigma)

    print(f"[INFO] Saved z_mu of shape {stacked_mu.shape} to {mu_path}")
    print(f"[INFO] Saved z_sigma of shape {stacked_sigma.shape} to {sigma_path}")


if __name__ == "__main__":
    main()
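
The stacked `z_mu` array written by this script is what the model card calls the 1200-dimensional embedding: with the default 80x96x80 input and `latent_channels=1`, each scan's latent map is 1x10x12x10 (matching `LATENT_SHAPE_AE` in `train_brain2vec.py`). A minimal sketch, assuming the README's `./ae_output/ae_embeddings_all.npy` output, of flattening it into per-scan feature vectors:

```
# Minimal sketch: flatten saved z_mu latent maps into per-scan feature vectors.
# Assumes inference was run with --output_dir ./ae_output and
# --embeddings_filename ae_embeddings_all.npy, as in the README example.
import numpy as np

z_mu = np.load("ae_output/ae_embeddings_all.npy")  # expected shape: (N, 1, 10, 12, 10)
features = z_mu.reshape(z_mu.shape[0], -1)         # (N, 1200) feature matrix
print(features.shape)
```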
inputs_example.csv
ADDED
@@ -0,0 +1,6 @@
image_uid,age,sex,image_path,split
0,81,2,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20133/ses-03/anat/msub-OASIS20133_ses-03_T1w_brain_affine_mni.nii.gz,train
1,78,2,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20133/ses-01/anat/msub-OASIS20133_ses-01_T1w_brain_affine_mni.nii.gz,train
2,87,1,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20105/ses-02/anat/msub-OASIS20105_ses-02_T1w_brain_affine_mni.nii.gz,train
3,86,1,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20105/ses-01/anat/msub-OASIS20105_ses-01_T1w_brain_affine_mni.nii.gz,train
4,84,1,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20102/ses-02/anat/msub-OASIS20102_ses-02_T1w_brain_affine_mni.nii.gz,train
requirements.txt
ADDED
@@ -0,0 +1,21 @@
# PyTorch (CUDA or CPU version).
torch>=1.12

# Install MONAI Generative first
monai-generative

# Install the latest MONAI directly from GitHub (development version)
git+https://github.com/Project-MONAI/MONAI.git#egg=monai

# For perceptual losses in MONAI's generative module.
lpips

# Common Python libraries
pandas
numpy
nibabel
tqdm
tensorboard
matplotlib
datasets
train_brain2vec.py
ADDED
@@ -0,0 +1,526 @@
#!/usr/bin/env python3

"""
train_brain2vec.py

Trains a 3D VAE-based Brain2Vec model using MONAI. This script implements
autoencoder training with adversarial loss (via a patch discriminator),
a perceptual loss, and KL divergence regularization for robust latent
representations.

Example usage:
    python train_brain2vec.py \
        --dataset_csv inputs.csv \
        --cache_dir ./ae_cache \
        --output_dir ./ae_output \
        --n_epochs 10
"""

import os
os.environ["PYTORCH_WEIGHTS_ONLY"] = "False"
from typing import Optional, Union
import pandas as pd
import argparse
import numpy as np
import warnings
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim.optimizer import Optimizer
from torch.nn import L1Loss
from torch.utils.data import DataLoader
from torch.amp import autocast
from torch.amp import GradScaler
from generative.networks.nets import (
    AutoencoderKL,
    PatchDiscriminator,
)
from generative.losses import PerceptualLoss, PatchAdversarialLoss
from monai.data import Dataset, PersistentDataset
from monai.transforms.transform import Transform
from monai import transforms
from monai.utils import set_determinism
from monai.data.meta_tensor import MetaTensor
import torch.serialization
from numpy.core.multiarray import _reconstruct
from numpy import ndarray, dtype

# Allow MONAI/NumPy objects cached by PersistentDataset to be deserialized
# when torch.load defaults to weights-only loading.
torch.serialization.add_safe_globals([_reconstruct])
torch.serialization.add_safe_globals([MetaTensor])
torch.serialization.add_safe_globals([ndarray])
torch.serialization.add_safe_globals([dtype])
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

# voxel resolution
RESOLUTION = 2

# shape of the MNI152 (1mm^3) template
INPUT_SHAPE_1mm = (182, 218, 182)

# resampling the MNI152 to (1.5mm^3)
INPUT_SHAPE_1p5mm = (122, 146, 122)

# Adjusting the dimensions to be divisible by 8 (2^3, where 3 is the number of downsampling layers of the AE)
# INPUT_SHAPE_AE = (120, 144, 120)
INPUT_SHAPE_AE = (80, 96, 80)

# Latent shape of the autoencoder
LATENT_SHAPE_AE = (1, 10, 12, 10)


def load_if(checkpoints_path: Optional[str], network: nn.Module) -> nn.Module:
    """
    Load pretrained weights if available.

    Args:
        checkpoints_path (Optional[str]): path of the checkpoints
        network (nn.Module): the neural network to initialize

    Returns:
        nn.Module: the initialized neural network
    """
    if checkpoints_path is not None:
        assert os.path.exists(checkpoints_path), 'Invalid path'
        network.load_state_dict(torch.load(checkpoints_path))
    return network


def init_autoencoder(checkpoints_path: Optional[str] = None) -> nn.Module:
    """
    Load the KL autoencoder (pretrained if `checkpoints_path` points to previous params).

    Args:
        checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.

    Returns:
        nn.Module: the KL autoencoder
    """
    autoencoder = AutoencoderKL(spatial_dims=3,
                                in_channels=1,
                                out_channels=1,
                                latent_channels=1,
                                num_channels=(64, 128, 256, 512),
                                num_res_blocks=2,
                                norm_num_groups=32,
                                norm_eps=1e-06,
                                attention_levels=(False, False, False, False),
                                with_decoder_nonlocal_attn=False,
                                with_encoder_nonlocal_attn=False)
    return load_if(checkpoints_path, autoencoder)


def init_patch_discriminator(checkpoints_path: Optional[str] = None) -> nn.Module:
    """
    Load the patch discriminator (pretrained if `checkpoints_path` points to previous params).

    Args:
        checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.

    Returns:
        nn.Module: the patch discriminator
    """
    patch_discriminator = PatchDiscriminator(spatial_dims=3,
                                             num_layers_d=3,
                                             num_channels=32,
                                             in_channels=1,
                                             out_channels=1)
    return load_if(checkpoints_path, patch_discriminator)


class KLDivergenceLoss:
    """
    A class for computing the Kullback-Leibler divergence loss.
    """

    def __call__(self, z_mu: Tensor, z_sigma: Tensor) -> Tensor:
        """
        Computes the KL divergence loss for the given parameters.

        Args:
            z_mu (Tensor): The mean of the distribution.
            z_sigma (Tensor): The standard deviation of the distribution.

        Returns:
            Tensor: The computed KL divergence loss, averaged over the batch size.
        """
        kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3, 4])
        return torch.sum(kl_loss) / kl_loss.shape[0]


class GradientAccumulation:
    """
    Implements gradient accumulation to facilitate training with larger
    effective batch sizes than what can be physically accommodated in memory.
    """

    def __init__(self,
                 actual_batch_size: int,
                 expect_batch_size: int,
                 loader_len: int,
                 optimizer: Optimizer,
                 grad_scaler: Optional[GradScaler] = None) -> None:
        """
        Initializes the GradientAccumulation instance with the necessary parameters for
        managing gradient accumulation.

        Args:
            actual_batch_size (int): The size of the mini-batches actually used in training.
            expect_batch_size (int): The desired (effective) batch size to simulate through gradient accumulation.
            loader_len (int): The length of the data loader, representing the total number of mini-batches.
            optimizer (Optimizer): The optimizer used for performing optimization steps.
            grad_scaler (Optional[GradScaler], optional): A GradScaler for mixed precision training. Defaults to None.

        Raises:
            AssertionError: If `expect_batch_size` is not divisible by `actual_batch_size`.
        """
        assert expect_batch_size % actual_batch_size == 0, \
            'expect_batch_size must be divisible by actual_batch_size'
        self.actual_batch_size = actual_batch_size
        self.expect_batch_size = expect_batch_size
        self.loader_len = loader_len
        self.optimizer = optimizer
        self.grad_scaler = grad_scaler

        # if the expected batch size is N=KM, and the actual batch size
        # is M, then we need to accumulate gradients over N / M = K optimization steps.
        self.steps_until_update = expect_batch_size / actual_batch_size

    def step(self, loss: Tensor, step: int) -> None:
        """
        Performs a backward pass for the given loss and potentially executes an optimization
        step if the conditions for gradient accumulation are met. The optimization step is taken
        only after a specified number of steps (defined by the expected batch size) or at the end
        of the dataset.

        Args:
            loss (Tensor): The loss value for the current forward pass.
            step (int): The current step (mini-batch index) within the epoch.
        """
        loss = loss / self.expect_batch_size

        if self.grad_scaler is not None:
            self.grad_scaler.scale(loss).backward()
        else:
            loss.backward()
        if (step + 1) % self.steps_until_update == 0 or (step + 1) == self.loader_len:
            if self.grad_scaler is not None:
                self.grad_scaler.step(self.optimizer)
                self.grad_scaler.update()
            else:
                self.optimizer.step()
            self.optimizer.zero_grad(set_to_none=True)


class AverageLoss:
    """
    Utility class to track losses
    and metrics during training.
    """

    def __init__(self):
        self.losses_accumulator = {}

    def put(self, loss_key: str, loss_value: Union[int, float]) -> None:
        """
        Store value

        Args:
            loss_key (str): Metric name
            loss_value (int | float): Metric value to store
        """
        if loss_key not in self.losses_accumulator:
            self.losses_accumulator[loss_key] = []
        self.losses_accumulator[loss_key].append(loss_value)

    def pop_avg(self, loss_key: str) -> float:
        """
        Average the stored values of a given metric

        Args:
            loss_key (str): Metric name

        Returns:
            float: average of the stored values
        """
        if loss_key not in self.losses_accumulator:
            return None
        losses = self.losses_accumulator[loss_key]
        self.losses_accumulator[loss_key] = []
        return sum(losses) / len(losses)

    def to_tensorboard(self, writer: SummaryWriter, step: int):
        """
        Logs the average value of all the metrics stored
        into Tensorboard.

        Args:
            writer (SummaryWriter): Tensorboard writer
            step (int): Tensorboard logging global step
        """
        for metric_key in self.losses_accumulator.keys():
            writer.add_scalar(metric_key, self.pop_avg(metric_key), step)


def get_dataset_from_pd(df: pd.DataFrame, transforms_fn: Transform, cache_dir: Optional[str]) -> Union[Dataset, PersistentDataset]:
    """
    If `cache_dir` is defined, returns a `monai.data.PersistentDataset`.
    Otherwise, returns a simple `monai.data.Dataset`.

    Args:
        df (pd.DataFrame): Dataframe describing each image in the longitudinal dataset.
        transforms_fn (Transform): Set of transformations
        cache_dir (Optional[str]): Cache directory (ensure enough storage is available)

    Returns:
        Dataset|PersistentDataset: The dataset
    """
    assert cache_dir is None or os.path.exists(cache_dir), 'Invalid cache directory path'
    data = df.to_dict(orient='records')
    return Dataset(data=data, transform=transforms_fn) if cache_dir is None \
        else PersistentDataset(data=data, transform=transforms_fn, cache_dir=cache_dir)


def tb_display_reconstruction(writer, step, image, recon):
    """
    Display reconstruction in TensorBoard during AE training.
    """
    plt.style.use('dark_background')
    _, ax = plt.subplots(ncols=3, nrows=2, figsize=(7, 5))
    for _ax in ax.flatten():
        _ax.set_axis_off()

    if len(image.shape) == 4:
        image = image.squeeze(0)
    if len(recon.shape) == 4:
        recon = recon.squeeze(0)

    ax[0, 0].set_title('original image', color='cyan')
    ax[0, 0].imshow(image[image.shape[0] // 2, :, :], cmap='gray')
    ax[0, 1].imshow(image[:, image.shape[1] // 2, :], cmap='gray')
    ax[0, 2].imshow(image[:, :, image.shape[2] // 2], cmap='gray')

    ax[1, 0].set_title('reconstructed image', color='magenta')
    ax[1, 0].imshow(recon[recon.shape[0] // 2, :, :], cmap='gray')
    ax[1, 1].imshow(recon[:, recon.shape[1] // 2, :], cmap='gray')
    ax[1, 2].imshow(recon[:, :, recon.shape[2] // 2], cmap='gray')

    plt.tight_layout()
    writer.add_figure('Reconstruction', plt.gcf(), global_step=step)


def set_environment(seed: int = 0) -> None:
    """
    Set deterministic behavior for reproducibility.

    Args:
        seed (int, optional): Seed value. Defaults to 0.
    """
    set_determinism(seed)


def train(
    dataset_csv: str,
    cache_dir: str,
    output_dir: str,
    aekl_ckpt: Optional[str] = None,
    disc_ckpt: Optional[str] = None,
    num_workers: int = 8,
    n_epochs: int = 5,
    max_batch_size: int = 2,
    batch_size: int = 16,
    lr: float = 1e-4,
    aug_p: float = 0.8,
    device: str = ('cuda' if torch.cuda.is_available() else 'cpu'),
) -> None:
    """
    Train the autoencoder and discriminator models.

    Args:
        dataset_csv (str): Path to the dataset CSV file.
        cache_dir (str): Directory for caching data.
        output_dir (str): Directory to save model checkpoints.
        aekl_ckpt (Optional[str], optional): Path to the autoencoder checkpoint. Defaults to None.
        disc_ckpt (Optional[str], optional): Path to the discriminator checkpoint. Defaults to None.
        num_workers (int, optional): Number of data loader workers. Defaults to 8.
        n_epochs (int, optional): Number of training epochs. Defaults to 5.
        max_batch_size (int, optional): Actual batch size per iteration. Defaults to 2.
        batch_size (int, optional): Expected (effective) batch size. Defaults to 16.
        lr (float, optional): Learning rate. Defaults to 1e-4.
        aug_p (float, optional): Augmentation probability. Defaults to 0.8.
        device (str, optional): Device to run the training on. Defaults to 'cuda' if available.
    """
    set_environment(0)

    transforms_fn = transforms.Compose([
        transforms.CopyItemsD(keys={'image_path'}, names=['image']),
        transforms.LoadImageD(image_only=True, keys=['image']),
        transforms.EnsureChannelFirstD(keys=['image']),
        transforms.SpacingD(pixdim=2, keys=['image']),
        transforms.ResizeWithPadOrCropD(spatial_size=(80, 96, 80), mode='minimum', keys=['image']),
        transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image'])
    ])

    dataset_df = pd.read_csv(dataset_csv)
    train_df = dataset_df[dataset_df.split == 'train']
    trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)

    train_loader = DataLoader(
        dataset=trainset,
        num_workers=num_workers,
        batch_size=max_batch_size,
        shuffle=True,
        persistent_workers=True,
        pin_memory=True,
    )

    print('Device is %s' % (device))
    autoencoder = init_autoencoder(aekl_ckpt).to(device)
    discriminator = init_patch_discriminator(disc_ckpt).to(device)

    # Loss weights
    adv_weight = 0.025
    perceptual_weight = 0.001
    kl_weight = 1e-7

    # Loss functions
    l1_loss_fn = L1Loss()
    kl_loss_fn = KLDivergenceLoss()
    adv_loss_fn = PatchAdversarialLoss(criterion="least_squares")

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        perc_loss_fn = PerceptualLoss(
            spatial_dims=3,
            network_type="squeeze",
            is_fake_3d=True,
            fake_3d_ratio=0.2
        ).to(device)

    # Optimizers
    optimizer_g = torch.optim.Adam(autoencoder.parameters(), lr=lr)
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=lr)

    # Gradient accumulation
    gradacc_g = GradientAccumulation(
        actual_batch_size=max_batch_size,
        expect_batch_size=batch_size,
        loader_len=len(train_loader),
        optimizer=optimizer_g,
        grad_scaler=GradScaler()
    )

    gradacc_d = GradientAccumulation(
        actual_batch_size=max_batch_size,
        expect_batch_size=batch_size,
        loader_len=len(train_loader),
        optimizer=optimizer_d,
        grad_scaler=GradScaler()
    )

    # Logging
    avgloss = AverageLoss()
    writer = SummaryWriter()
    total_counter = 0

    for epoch in range(n_epochs):
        print(f"[DEBUG] Starting epoch {epoch}/{n_epochs-1}")
        autoencoder.train()
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
        progress_bar.set_description(f'Epoch {epoch}')

        for step, batch in progress_bar:
            # Generator training
            with autocast(device, enabled=True):
                images = batch["image"].to(device)
                reconstruction, z_mu, z_sigma = autoencoder(images)

                logits_fake = discriminator(reconstruction.contiguous().float())[-1]

                rec_loss = l1_loss_fn(reconstruction.float(), images.float())
                kl_loss = kl_weight * kl_loss_fn(z_mu, z_sigma)
                per_loss = perceptual_weight * perc_loss_fn(reconstruction.float(), images.float())
                gen_loss = adv_weight * adv_loss_fn(logits_fake, target_is_real=True, for_discriminator=False)

                loss_g = rec_loss + kl_loss + per_loss + gen_loss

            gradacc_g.step(loss_g, step)

            # Discriminator training
            with autocast(device, enabled=True):
                logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
                d_loss_fake = adv_loss_fn(logits_fake, target_is_real=False, for_discriminator=True)
                logits_real = discriminator(images.contiguous().detach())[-1]
                d_loss_real = adv_loss_fn(logits_real, target_is_real=True, for_discriminator=True)
                discriminator_loss = (d_loss_fake + d_loss_real) * 0.5
                loss_d = adv_weight * discriminator_loss

            gradacc_d.step(loss_d, step)

            # Logging
            avgloss.put('Generator/reconstruction_loss', rec_loss.item())
            avgloss.put('Generator/perceptual_loss', per_loss.item())
            avgloss.put('Generator/adversarial_loss', gen_loss.item())
            avgloss.put('Generator/kl_regularization', kl_loss.item())
            avgloss.put('Discriminator/adversarial_loss', loss_d.item())

            if total_counter % 10 == 0:
                step_log = total_counter // 10
                avgloss.to_tensorboard(writer, step_log)
                tb_display_reconstruction(
                    writer,
                    step_log,
                    images[0].detach().cpu(),
                    reconstruction[0].detach().cpu()
                )

            total_counter += 1

        # Save the models after each epoch.
        os.makedirs(output_dir, exist_ok=True)
        torch.save(discriminator.state_dict(), os.path.join(output_dir, f'discriminator-ep-{epoch}.pth'))
        torch.save(autoencoder.state_dict(), os.path.join(output_dir, f'autoencoder-ep-{epoch}.pth'))

    writer.close()
    print("Training completed and models saved.")


def main():
    """
    Main function to parse command-line arguments and run train().
    """
    parser = argparse.ArgumentParser(description="brain2vec Training Script")

    parser.add_argument('--dataset_csv', type=str, required=True, help='Path to the dataset CSV file.')
    parser.add_argument('--cache_dir', type=str, required=True, help='Directory for caching data.')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save model checkpoints.')
    parser.add_argument('--aekl_ckpt', type=str, default=None, help='Path to the autoencoder checkpoint.')
    parser.add_argument('--disc_ckpt', type=str, default=None, help='Path to the discriminator checkpoint.')
    parser.add_argument('--num_workers', type=int, default=8, help='Number of data loader workers.')
    parser.add_argument('--n_epochs', type=int, default=5, help='Number of training epochs.')
    parser.add_argument('--max_batch_size', type=int, default=2, help='Actual batch size per iteration.')
    parser.add_argument('--batch_size', type=int, default=16, help='Expected (effective) batch size.')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
    parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')

    args = parser.parse_args()

    train(
        dataset_csv=args.dataset_csv,
        cache_dir=args.cache_dir,
        output_dir=args.output_dir,
        aekl_ckpt=args.aekl_ckpt,
        disc_ckpt=args.disc_ckpt,
        num_workers=args.num_workers,
        n_epochs=args.n_epochs,
        max_batch_size=args.max_batch_size,
        batch_size=args.batch_size,
        lr=args.lr,
        aug_p=args.aug_p,
    )


if __name__ == '__main__':
    main()
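
For completeness, a minimal sketch of calling `train()` directly from Python rather than via the CLI; argument names and defaults follow the signature above, and the checkpoint paths shown in the comments are only illustrative:

```
# Minimal sketch: invoke training programmatically (assumes inputs.csv, ./ae_cache,
# and ./ae_output already exist, as in the README walkthrough).
from train_brain2vec import train

train(
    dataset_csv="inputs.csv",
    cache_dir="./ae_cache",
    output_dir="./ae_output",
    aekl_ckpt=None,     # or e.g. "autoencoder_final.pth" to warm-start from the released weights
    disc_ckpt=None,     # or e.g. "discriminator_final.pth"
    n_epochs=10,
    max_batch_size=2,   # per-iteration batch size
    batch_size=16,      # effective batch size via gradient accumulation (16 / 2 = 8 steps)
)
```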