jesseab committed
Commit 0556c0e · 1 Parent(s): a4d6fb3

Add Brain2Vec-v2 files and model card

.DS_Store ADDED
Binary file (6.15 kB).
 
README.md CHANGED
@@ -1,3 +1,119 @@
- ---
- license: apache-2.0
- ---
---
base_model: radiata-ai/brain2vec
license: apache-2.0
language:
- en
task_categories:
- image-classification
tags:
- medical
- brain-data
- mri
pretty_name: 3D Brain Structure MRI Autoencoder
---

## 🧠 Model Summary
# brain2vec
Version 2 of an autoencoder model for brain structure T1 MRIs (forked from [Brain Latent Progression](https://github.com/LemuelPuglisi/BrLP/tree/main)). The autoencoder takes a 3D T1 MRI NIfTI file as input, compresses it to 1200 latent dimensions, and reconstructs the image (a minimal usage sketch follows the loss list below). The loss functions used to train the autoencoder are:
- [L1Loss](https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html)
- [KLDivergenceLoss](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html)
- [PatchAdversarialLoss](https://docs.monai.io/en/stable/losses.html#patchadversarialloss)
- [PerceptualLoss](https://docs.monai.io/en/stable/losses.html#perceptualloss)

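A minimal Python sketch of that encode/reconstruct step, using the `Brain2vec` and `preprocess_mri` helpers defined in `inference_brain2vec.py` in this repository (the scan path is a placeholder):

```
import torch
from inference_brain2vec import Brain2vec, preprocess_mri

# load the pretrained weights shipped with this repository
model = Brain2vec.from_pretrained("autoencoder_final.pth", device="cpu")

# resample to 2mm, crop to 80x96x80, scale intensity to [0, 1] -> (1, 1, 80, 96, 80)
image = preprocess_mri("/path/to/scan.nii.gz")

with torch.no_grad():
    reconstruction, z_mu, z_sigma = model(image)

# z_mu has shape (1, 1, 10, 12, 10); flattening it gives the 1200-dimensional embedding
embedding = z_mu.flatten()
```
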
# Training data
[Radiata brain-structure](https://huggingface.co/datasets/radiata-ai/brain-structure): 3066 scans from 2085 individuals in the 'train' split. Mean age = 45.1 ± 24.5 years, including 2847 scans from cognitively normal subjects and 219 scans from individuals with a clinical diagnosis of Alzheimer's disease.

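The 'train' split can be loaded directly with the Hugging Face `datasets` library (the same call used by `create_csv.py` below):

```
from datasets import load_dataset

ds_train = load_dataset("radiata-ai/brain-structure", split="train", trust_remote_code=True)
print(len(ds_train))                  # 3066 scans
print(ds_train[0]["nii_filepath"])    # local path to the preprocessed NIfTI file
```
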
# Example usage
```
# get the brain2vec model repository
git clone https://huggingface.co/radiata-ai/brain2vec
cd brain2vec

# pull the pre-trained model weights
sudo apt-get update
sudo apt install git-lfs
git lfs install
git lfs pull

# set up a virtual environment
python3 -m venv venv_brain2vec
source venv_brain2vec/bin/activate

# install Python libraries
pip install -r requirements.txt

# create the csv file inputs.csv listing the scan paths and other info;
# this script loads the radiata-ai/brain-structure dataset from Hugging Face
python create_csv.py

mkdir ae_cache
mkdir ae_output

# train the model
nohup python train_brain2vec.py \
    --dataset_csv inputs.csv \
    --cache_dir ./ae_cache \
    --output_dir ./ae_output \
    --n_epochs 10 \
    > train_log.txt 2>&1 &

# model inference
# for a set of scans in inputs.csv
python inference_brain2vec.py \
    --checkpoint_path /path/to/model.pth \
    --csv_input inputs.csv \
    --output_dir ./ae_output \
    --embeddings_filename ae_embeddings_all.npy

# or for individual scans
python inference_brain2vec.py \
    --checkpoint_path /path/to/model.pth \
    --input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
    --output_dir ./ae_output \
    --embeddings_filename ae_embeddings_2.npy
```

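After inference, the stacked latent means can be read back with NumPy; assuming the default latent shape of (1, 10, 12, 10), each scan flattens to a 1200-dimensional embedding:

```
import numpy as np

z_mu = np.load("ae_output/ae_embeddings_all.npy")    # shape: (N, 1, 10, 12, 10)
embeddings = z_mu.reshape(z_mu.shape[0], -1)         # shape: (N, 1200)
```
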
# Methods
Input scan image dimensions are 113x137x113 at 1.5mm^3 resolution, aligned to MNI152 space (see [radiata-ai/brain-structure](https://huggingface.co/datasets/radiata-ai/brain-structure)).

The image transform crops to 80 x 96 x 80 at 2mm^3 resolution and scales image intensity to the range [0, 1] (see the MONAI pipeline sketched below).

The model was trained with an effective batch size of 16, 10 epochs, and a learning rate of 1e-4 (see references 1 and 2).

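The transform is implemented with MONAI dictionary transforms in `train_brain2vec.py` and `inference_brain2vec.py`:

```
from monai import transforms

transforms_fn = transforms.Compose([
    transforms.CopyItemsD(keys={'image_path'}, names=['image']),
    transforms.LoadImageD(image_only=True, keys=['image']),
    transforms.EnsureChannelFirstD(keys=['image']),
    transforms.SpacingD(pixdim=2, keys=['image']),                 # resample to 2mm^3
    transforms.ResizeWithPadOrCropD(spatial_size=(80, 96, 80), mode='minimum', keys=['image']),
    transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image']),    # intensity to [0, 1]
])
```
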
# References
1. Puglisi L, Alexander DC, Ravì D. Enhancing Spatiotemporal Disease Progression Models via Latent Diffusion and Prior Knowledge [Internet]. arXiv; 2024. Available from: http://arxiv.org/abs/2405.03328
2. Pinaya WHL, Tudosiu PD, Dafflon J, Costa PF da, Fernandez V, Nachev P, et al. Brain Imaging Generation with Latent Diffusion Models [Internet]. arXiv; 2022. Available from: http://arxiv.org/abs/2209.07162

# Citation
```
@misc{Radiata-Brain2vec,
  author = {Jesse Brown and Clayton Young},
  title = {Brain2vec: An Autoencoder Model for Brain Structure T1 MRIs},
  year = {2025},
  url = {https://huggingface.co/radiata-ai/brain2vec},
  note = {Version 1.0},
  publisher = {Hugging Face}
}
```

# License
### Apache License 2.0

Copyright 2025 Jesse Brown

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at:

[http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0)

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
autoencoder_final.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f66d42dd3bd58c39a110497ea463c35e52dfed097274338a37cb2efbfc4bf11c
3
+ size 339644650
create_csv.py ADDED
@@ -0,0 +1,39 @@
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import pandas as pd
4
+ from datasets import load_dataset
5
+
6
+ def row_to_dict(row, split_name):
7
+ return {
8
+ "image_uid": row["id"],
9
+ "age": int(row["metadata"]["age"]),
10
+ "sex": 1 if row["metadata"]["sex"].lower() == "male" else 2,
11
+ "image_path": os.path.abspath(row["nii_filepath"]),
12
+ "split": split_name
13
+ }
14
+
15
+ def main():
16
+ # Load the datasets
17
+ ds_train = load_dataset("radiata-ai/brain-structure", split="train", trust_remote_code=True)
18
+ ds_val = load_dataset("radiata-ai/brain-structure", split="validation", trust_remote_code=True)
19
+ ds_test = load_dataset("radiata-ai/brain-structure", split="test", trust_remote_code=True)
20
+
21
+ rows = []
22
+
23
+ # Process each split
24
+ for data_row in ds_train:
25
+ rows.append(row_to_dict(data_row, "train"))
26
+ for data_row in ds_val:
27
+ rows.append(row_to_dict(data_row, "validation"))
28
+ for data_row in ds_test:
29
+ rows.append(row_to_dict(data_row, "test"))
30
+
31
+ # Create a DataFrame and write it to CSV
32
+ df = pd.DataFrame(rows)
33
+ output_csv = "inputs.csv"
34
+ df.to_csv(output_csv, index=False)
35
+ print(f"CSV file created: {output_csv}")
36
+
37
+ if __name__ == "__main__":
38
+ main()
39
+
discriminator_final.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ac993e13040de22843bb077dbb77cf0904e0aafced8429ba3ec3adfb47b3d02
3
+ size 11099084
inference_brain2vec.py ADDED
@@ -0,0 +1,240 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ inference_brain2vec.py
5
+
6
+ Loads a pretrained Brain2vec VAE (AutoencoderKL) model and performs inference
7
+ on one or more MRI images, generating reconstructions and latent parameters
8
+ (z_mu, z_sigma).
9
+
10
+ Example usage:
11
+
12
+ # 1) Multiple file paths
13
+ python inference_brain2vec.py \
14
+ --checkpoint_path /path/to/autoencoder_checkpoint.pth \
15
+ --input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
16
+ --output_dir ./vae_inference_outputs \
+ --embeddings_filename ae_embeddings_2.npy
18
+
19
+ # 2) Use a CSV containing image paths
20
+ python inference_brain2vec.py \
21
+ --checkpoint_path /path/to/autoencoder_checkpoint.pth \
22
+ --csv_input /path/to/images.csv \
23
+ --output_dir ./vae_inference_outputs \
+ --embeddings_filename ae_embeddings_all.npy
24
+ """
25
+
26
+ import os
27
+ import argparse
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn as nn
31
+ from typing import Optional
32
+ from monai.transforms import (
33
+ Compose,
34
+ CopyItemsD,
35
+ LoadImageD,
36
+ EnsureChannelFirstD,
37
+ SpacingD,
38
+ ResizeWithPadOrCropD,
39
+ ScaleIntensityD,
40
+ )
41
+ from generative.networks.nets import AutoencoderKL
42
+ import pandas as pd
43
+
44
+
45
+ RESOLUTION = 2
46
+ INPUT_SHAPE_AE = (80, 96, 80)
47
+
48
+ transforms_fn = Compose([
49
+ CopyItemsD(keys={'image_path'}, names=['image']),
50
+ LoadImageD(image_only=True, keys=['image']),
51
+ EnsureChannelFirstD(keys=['image']),
52
+ SpacingD(pixdim=RESOLUTION, keys=['image']),
53
+ ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
54
+ ScaleIntensityD(minv=0, maxv=1, keys=['image']),
55
+ ])
56
+
57
+
58
+ def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
59
+ """
60
+ Preprocess an MRI using MONAI transforms to produce
61
+ a 5D tensor (batch=1, channel=1, D, H, W) for inference.
62
+
63
+ Args:
64
+ image_path (str): Path to the MRI (e.g. .nii.gz).
65
+ device (str): Device to place the tensor on.
66
+
67
+ Returns:
68
+ torch.Tensor: Shape (1, 1, D, H, W).
69
+ """
70
+ data_dict = {"image_path": image_path}
71
+ output_dict = transforms_fn(data_dict)
72
+ image_tensor = output_dict["image"] # shape: (1, D, H, W)
73
+ image_tensor = image_tensor.unsqueeze(0) # => (1, 1, D, H, W)
74
+ return image_tensor.to(device)
75
+
76
+
77
+ class Brain2vec(AutoencoderKL):
78
+ """
79
+ Subclass of MONAI's AutoencoderKL that includes:
80
+ - a from_pretrained(...) for loading a .pth checkpoint
81
+ - uses the existing forward(...) that returns (reconstruction, z_mu, z_sigma)
82
+
83
+ Usage:
84
+ >>> model = Brain2vec.from_pretrained("my_checkpoint.pth", device="cuda")
85
+ >>> image_tensor = preprocess_mri("/path/to/mri.nii.gz", device="cuda")
86
+ >>> reconstruction, z_mu, z_sigma = model.forward(image_tensor)
87
+ """
88
+
89
+ @staticmethod
90
+ def from_pretrained(
91
+ checkpoint_path: Optional[str] = None,
92
+ device: str = "cpu"
93
+ ) -> nn.Module:
94
+ """
95
+ Load a pretrained Brain2vec (AutoencoderKL) if a checkpoint_path is provided.
96
+ Otherwise, return an uninitialized model.
97
+
98
+ Args:
99
+ checkpoint_path (Optional[str]): Path to a .pth checkpoint file.
100
+ device (str): "cpu", "cuda", "mps", etc.
101
+
102
+ Returns:
103
+ nn.Module: The loaded Brain2vec model on the chosen device.
104
+ """
105
+ model = Brain2vec(
106
+ spatial_dims=3,
107
+ in_channels=1,
108
+ out_channels=1,
109
+ latent_channels=1,
110
+ num_channels=(64, 128, 256, 512),
111
+ num_res_blocks=2,
112
+ norm_num_groups=32,
113
+ norm_eps=1e-06,
114
+ attention_levels=(False, False, False, False),
115
+ with_decoder_nonlocal_attn=False,
116
+ with_encoder_nonlocal_attn=False,
117
+ )
118
+
119
+ if checkpoint_path is not None:
120
+ if not os.path.exists(checkpoint_path):
121
+ raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
122
+ state_dict = torch.load(checkpoint_path, map_location=device)
123
+ model.load_state_dict(state_dict)
124
+
125
+ model.to(device)
126
+ model.eval()
127
+ return model
128
+
129
+
130
+ def main() -> None:
131
+ """
132
+ Main function to parse command-line arguments and run inference
133
+ with a pretrained Brain2vec model.
134
+ """
135
+ parser = argparse.ArgumentParser(
136
+ description="Inference script for a Brain2vec (VAE) model."
137
+ )
138
+ parser.add_argument(
139
+ "--checkpoint_path", type=str, required=True,
140
+ help="Path to the .pth checkpoint of the pretrained Brain2vec model."
141
+ )
142
+ parser.add_argument(
143
+ "--output_dir", type=str, default="./vae_inference_outputs",
144
+ help="Directory to save reconstructions and latent parameters."
145
+ )
146
+ # Two ways to supply images: multiple file paths or a CSV
147
+ parser.add_argument(
148
+ "--input_images", type=str, nargs="*",
149
+ help="One or more MRI file paths (e.g. .nii.gz)."
150
+ )
151
+ parser.add_argument(
152
+ "--csv_input", type=str,
153
+ help="Path to a CSV file with an 'image_path' column."
154
+ )
155
+ parser.add_argument(
156
+ "--embeddings_filename",
157
+ type=str,
158
+ required=True,
159
+ help="Filename (in output_dir) to save the stacked z_mu embeddings (e.g. 'all_z_mu.npy')."
160
+ )
161
+ parser.add_argument(
162
+ "--save_recons",
163
+ action="store_true",
164
+ help="If set, saves each reconstruction as .npy. Default is not to save."
165
+ )
166
+
167
+ args = parser.parse_args()
168
+
169
+ os.makedirs(args.output_dir, exist_ok=True)
170
+
171
+ # select the inference device (GPU if available)
172
+ device = "cuda" if torch.cuda.is_available() else "cpu"
173
+
174
+ # load the pretrained model on that device
175
+ model = Brain2vec.from_pretrained(
176
+ checkpoint_path=args.checkpoint_path,
177
+ device=device
178
+ )
179
+
180
+ # Gather image paths
181
+ if args.csv_input:
182
+ df = pd.read_csv(args.csv_input)
183
+ if "image_path" not in df.columns:
184
+ raise ValueError("CSV must contain a column named 'image_path'.")
185
+ image_paths = df["image_path"].tolist()
186
+ else:
187
+ if not args.input_images:
188
+ raise ValueError("Must provide either --csv_input or --input_images.")
189
+ image_paths = args.input_images
190
+
191
+ # Lists for stacking latent parameters later
192
+ all_z_mu = []
193
+ all_z_sigma = []
194
+
195
+ # Inference on each image
196
+ for i, img_path in enumerate(image_paths):
197
+ if not os.path.exists(img_path):
198
+ raise FileNotFoundError(f"Image not found: {img_path}")
199
+
200
+ print(f"[INFO] Processing image {i}: {img_path}")
201
+ img_tensor = preprocess_mri(img_path, device=device)
202
+
203
+ with torch.no_grad():
204
+ recon, z_mu, z_sigma = model.forward(img_tensor)
205
+
206
+ # Convert to NumPy
207
+ recon_np = recon.detach().cpu().numpy() # shape: (1, 1, D, H, W)
208
+ z_mu_np = z_mu.detach().cpu().numpy() # shape: (1, latent_channels, ...)
209
+ z_sigma_np = z_sigma.detach().cpu().numpy()
210
+
211
+ # Save each reconstruction (per image) as .npy
212
+ if args.save_recons:
213
+ recon_path = os.path.join(args.output_dir, f"reconstruction_{i}.npy")
214
+ np.save(recon_path, recon_np)
215
+ print(f"[INFO] Saved reconstruction to {recon_path}")
216
+
217
+ # Store latent parameters for optional combined saving
218
+ all_z_mu.append(z_mu_np)
219
+ all_z_sigma.append(z_sigma_np)
220
+
221
+ # Combine latent parameters from all images and save
222
+ stacked_mu = np.concatenate(all_z_mu, axis=0) # e.g., shape (N, latent_channels, ...)
223
+ stacked_sigma = np.concatenate(all_z_sigma, axis=0) # e.g., shape (N, latent_channels, ...)
224
+
225
+ mu_filename = args.embeddings_filename
226
+ if not mu_filename.lower().endswith(".npy"):
227
+ mu_filename += ".npy"
228
+
229
+ mu_path = os.path.join(args.output_dir, mu_filename)
230
+ sigma_path = os.path.join(args.output_dir, "all_z_sigma.npy")
231
+
232
+ np.save(mu_path, stacked_mu)
233
+ np.save(sigma_path, stacked_sigma)
234
+
235
+ print(f"[INFO] Saved z_mu of shape {stacked_mu.shape} to {mu_path}")
236
+ print(f"[INFO] Saved z_sigma of shape {stacked_sigma.shape} to {sigma_path}")
237
+
238
+
239
+ if __name__ == "__main__":
240
+ main()
inputs_example.csv ADDED
@@ -0,0 +1,6 @@
1
+ image_uid,age,sex,image_path,split
2
+ 0,81,2,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20133/ses-03/anat/msub-OASIS20133_ses-03_T1w_brain_affine_mni.nii.gz,train
3
+ 1,78,2,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20133/ses-01/anat/msub-OASIS20133_ses-01_T1w_brain_affine_mni.nii.gz,train
4
+ 2,87,1,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20105/ses-02/anat/msub-OASIS20105_ses-02_T1w_brain_affine_mni.nii.gz,train
5
+ 3,86,1,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20105/ses-01/anat/msub-OASIS20105_ses-01_T1w_brain_affine_mni.nii.gz,train
6
+ 4,84,1,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20102/ses-02/anat/msub-OASIS20102_ses-02_T1w_brain_affine_mni.nii.gz,train
requirements.txt ADDED
@@ -0,0 +1,21 @@
1
+ # PyTorch (CUDA or CPU version).
2
+ torch>=1.12
3
+
4
+ # Install MONAI Generative first
5
+ monai-generative
6
+
7
+ # Install the latest MONAI directly from GitHub (development version)
8
+ git+https://github.com/Project-MONAI/MONAI.git#egg=monai
9
+
10
+ # For perceptual losses in MONAI's generative module.
11
+ lpips
12
+
13
+ # Common Python libraries
14
+ pandas
15
+ numpy
16
+ nibabel
17
+ tqdm
18
+ tensorboard
19
+ matplotlib
20
+ datasets
21
+
train_brain2vec.py ADDED
@@ -0,0 +1,526 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ train_brain2vec.py
5
+
6
+ Trains a 3D VAE-based Brain2Vec model using MONAI. This script implements
7
+ autoencoder training with adversarial loss (via a patch discriminator),
8
+ a perceptual loss, and KL divergence regularization for robust latent
9
+ representations.
10
+
11
+ Example usage:
12
+ python train_brain2vec.py \
13
+ --dataset_csv inputs.csv \
14
+ --cache_dir ./ae_cache \
15
+ --output_dir ./ae_output \
16
+ --n_epochs 10
17
+ """
18
+
19
+ import os
20
+ os.environ["PYTORCH_WEIGHTS_ONLY"] = "False"
21
+ from typing import Optional, Union
22
+ import pandas as pd
23
+ import argparse
24
+ import numpy as np
25
+ import warnings
26
+ import torch
27
+ import torch.nn as nn
28
+ from torch import Tensor
29
+ from torch.optim.optimizer import Optimizer
30
+ from torch.nn import L1Loss
31
+ from torch.utils.data import DataLoader
32
+ from torch.amp import autocast
33
+ from torch.amp import GradScaler
34
+ from generative.networks.nets import (
35
+ AutoencoderKL,
36
+ PatchDiscriminator,
37
+ )
38
+ from generative.losses import PerceptualLoss, PatchAdversarialLoss
39
+ from monai.data import Dataset, PersistentDataset
40
+ from monai.transforms.transform import Transform
41
+ from monai import transforms
42
+ from monai.utils import set_determinism
43
+ from monai.data.meta_tensor import MetaTensor
44
+ import torch.serialization
45
+ from numpy.core.multiarray import _reconstruct
46
+ from numpy import ndarray, dtype
47
+ torch.serialization.add_safe_globals([_reconstruct])
48
+ torch.serialization.add_safe_globals([MetaTensor])
49
+ torch.serialization.add_safe_globals([ndarray])
50
+ torch.serialization.add_safe_globals([dtype])
51
+ from tqdm import tqdm
52
+ import matplotlib.pyplot as plt
53
+ from torch.utils.tensorboard import SummaryWriter
54
+
55
+ # voxel resolution
56
+ RESOLUTION = 2
57
+
58
+ # shape of the MNI152 (1mm^3) template
59
+ INPUT_SHAPE_1mm = (182, 218, 182)
60
+
61
+ # resampling the MNI152 to (1.5mm^3)
62
+ INPUT_SHAPE_1p5mm = (122, 146, 122)
63
+
64
+ # Adjusting the dimensions to be divisible by 8 (2^3 where 3 are the downsampling layers of the AE)
65
+ #INPUT_SHAPE_AE = (120, 144, 120)
66
+ INPUT_SHAPE_AE = (80, 96, 80)
67
+
68
+ # Latent shape of the autoencoder
69
+ LATENT_SHAPE_AE = (1, 10, 12, 10)
70
+
71
+
72
+ def load_if(checkpoints_path: Optional[str], network: nn.Module) -> nn.Module:
73
+ """
74
+ Load pretrained weights if available.
75
+
76
+ Args:
77
+ checkpoints_path (Optional[str]): path of the checkpoints
78
+ network (nn.Module): the neural network to initialize
79
+
80
+ Returns:
81
+ nn.Module: the initialized neural network
82
+ """
83
+ if checkpoints_path is not None:
84
+ assert os.path.exists(checkpoints_path), 'Invalid path'
85
+ network.load_state_dict(torch.load(checkpoints_path))
86
+ return network
87
+
88
+
89
+ def init_autoencoder(checkpoints_path: Optional[str] = None) -> nn.Module:
90
+ """
91
+ Load the KL autoencoder (pretrained if `checkpoints_path` points to previous params).
92
+
93
+ Args:
94
+ checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.
95
+
96
+ Returns:
97
+ nn.Module: the KL autoencoder
98
+ """
99
+ autoencoder = AutoencoderKL(spatial_dims=3,
100
+ in_channels=1,
101
+ out_channels=1,
102
+ latent_channels=1, #3,
103
+ num_channels=(64, 128, 256, 512),
104
+ num_res_blocks=2,
105
+ norm_num_groups=32,
106
+ norm_eps=1e-06,
107
+ attention_levels=(False, False, False, False),
108
+ with_decoder_nonlocal_attn=False,
109
+ with_encoder_nonlocal_attn=False)
110
+ return load_if(checkpoints_path, autoencoder)
111
+
112
+
113
+ def init_patch_discriminator(checkpoints_path: Optional[str] = None) -> nn.Module:
114
+ """
115
+ Load the patch discriminator (pretrained if `checkpoints_path` points to previous params).
116
+
117
+ Args:
118
+ checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.
119
+
120
+ Returns:
121
+ nn.Module: the patch discriminator
122
+ """
123
+ patch_discriminator = PatchDiscriminator(spatial_dims=3,
124
+ num_layers_d=3,
125
+ num_channels=32,
126
+ in_channels=1,
127
+ out_channels=1)
128
+ return load_if(checkpoints_path, patch_discriminator)
129
+
130
+
131
+ class KLDivergenceLoss:
132
+ """
133
+ A class for computing the Kullback-Leibler divergence loss.
134
+ """
135
+
136
+ def __call__(self, z_mu: Tensor, z_sigma: Tensor) -> Tensor:
137
+ """
138
+ Computes the KL divergence loss for the given parameters.
139
+
140
+ Args:
141
+ z_mu (Tensor): The mean of the distribution.
142
+ z_sigma (Tensor): The standard deviation of the distribution.
143
+
144
+ Returns:
145
+ Tensor: The computed KL divergence loss, averaged over the batch size.
146
+ """
147
+
148
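+ # closed-form KL divergence between N(z_mu, z_sigma^2) and N(0, I), summed over the latent dimensions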
+ kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3, 4])
149
+ return torch.sum(kl_loss) / kl_loss.shape[0]
150
+
151
+
152
+ class GradientAccumulation:
153
+ """
154
+ Implements gradient accumulation to facilitate training with larger
155
+ effective batch sizes than what can be physically accommodated in memory.
156
+ """
157
+
158
+ def __init__(self,
159
+ actual_batch_size: int,
160
+ expect_batch_size: int,
161
+ loader_len: int,
162
+ optimizer: Optimizer,
163
+ grad_scaler: Optional[GradScaler] = None) -> None:
164
+ """
165
+ Initializes the GradientAccumulation instance with the necessary parameters for
166
+ managing gradient accumulation.
167
+
168
+ Args:
169
+ actual_batch_size (int): The size of the mini-batches actually used in training.
170
+ expect_batch_size (int): The desired (effective) batch size to simulate through gradient accumulation.
171
+ loader_len (int): The length of the data loader, representing the total number of mini-batches.
172
+ optimizer (Optimizer): The optimizer used for performing optimization steps.
173
+ grad_scaler (Optional[GradScaler], optional): A GradScaler for mixed precision training. Defaults to None.
174
+
175
+ Raises:
176
+ AssertionError: If `expect_batch_size` is not divisible by `actual_batch_size`.
177
+ """
178
+
179
+ assert expect_batch_size % actual_batch_size == 0, \
180
+ 'expect_batch_size must be divisible by actual_batch_size'
181
+ self.actual_batch_size = actual_batch_size
182
+ self.expect_batch_size = expect_batch_size
183
+ self.loader_len = loader_len
184
+ self.optimizer = optimizer
185
+ self.grad_scaler = grad_scaler
186
+
187
+ # if the expected batch size is N=KM, and the actual batch size
188
+ # is M, then we need to accumulate gradient from N / M = K optimization steps.
189
+ self.steps_until_update = expect_batch_size / actual_batch_size
190
+
191
+ def step(self, loss: Tensor, step: int) -> None:
192
+ """
193
+ Performs a backward pass for the given loss and potentially executes an optimization
194
+ step if the conditions for gradient accumulation are met. The optimization step is taken
195
+ only after a specified number of steps (defined by the expected batch size) or at the end
196
+ of the dataset.
197
+
198
+ Args:
199
+ loss (Tensor): The loss value for the current forward pass.
200
+ step (int): The current step (mini-batch index) within the epoch.
201
+ """
202
+ loss = loss / self.expect_batch_size
203
+
204
+ if self.grad_scaler is not None:
205
+ self.grad_scaler.scale(loss).backward()
206
+ else:
207
+ loss.backward()
208
+ if (step + 1) % self.steps_until_update == 0 or (step + 1) == self.loader_len:
209
+ if self.grad_scaler is not None:
210
+ self.grad_scaler.step(self.optimizer)
211
+ self.grad_scaler.update()
212
+ else:
213
+ self.optimizer.step()
214
+ self.optimizer.zero_grad(set_to_none=True)
215
+
216
+
217
+ class AverageLoss:
218
+ """
219
+ Utility class to track losses
220
+ and metrics during training.
221
+ """
222
+
223
+ def __init__(self):
224
+ self.losses_accumulator = {}
225
+
226
+ def put(self, loss_key:str, loss_value:Union[int,float]) -> None:
227
+ """
228
+ Store value
229
+
230
+ Args:
231
+ loss_key (str): Metric name
232
+ loss_value (int | float): Metric value to store
233
+ """
234
+ if loss_key not in self.losses_accumulator:
235
+ self.losses_accumulator[loss_key] = []
236
+ self.losses_accumulator[loss_key].append(loss_value)
237
+
238
+ def pop_avg(self, loss_key:str) -> float:
239
+ """
240
+ Average the stored values of a given metric
241
+
242
+ Args:
243
+ loss_key (str): Metric name
244
+
245
+ Returns:
246
+ float: average of the stored values
247
+ """
248
+ if loss_key not in self.losses_accumulator:
249
+ return None
250
+ losses = self.losses_accumulator[loss_key]
251
+ self.losses_accumulator[loss_key] = []
252
+ return sum(losses) / len(losses)
253
+
254
+ def to_tensorboard(self, writer: SummaryWriter, step: int):
255
+ """
256
+ Logs the average value of all the metrics stored
257
+ into Tensorboard.
258
+
259
+ Args:
260
+ writer (SummaryWriter): Tensorboard writer
261
+ step (int): Tensorboard logging global step
262
+ """
263
+ for metric_key in self.losses_accumulator.keys():
264
+ writer.add_scalar(metric_key, self.pop_avg(metric_key), step)
265
+
266
+
267
+ def get_dataset_from_pd(df: pd.DataFrame, transforms_fn: Transform, cache_dir: Optional[str]) -> Union[Dataset,PersistentDataset]:
268
+ """
269
+ If `cache_dir` is defined, returns a `monai.data.PersistentDataset`.
270
+ Otherwise, returns a simple `monai.data.Dataset`.
271
+
272
+ Args:
273
+ df (pd.DataFrame): Dataframe describing each image in the longitudinal dataset.
274
+ transforms_fn (Transform): Set of transformations
275
+ cache_dir (Optional[str]): Cache directory (ensure enough storage is available)
276
+
277
+ Returns:
278
+ Dataset|PersistentDataset: The dataset
279
+ """
280
+ assert cache_dir is None or os.path.exists(cache_dir), 'Invalid cache directory path'
281
+ data = df.to_dict(orient='records')
282
+ return Dataset(data=data, transform=transforms_fn) if cache_dir is None \
283
+ else PersistentDataset(data=data, transform=transforms_fn, cache_dir=cache_dir)
284
+
285
+
286
+ def tb_display_reconstruction(writer, step, image, recon):
287
+ """
288
+ Display reconstruction in TensorBoard during AE training.
289
+ """
290
+ plt.style.use('dark_background')
291
+ _, ax = plt.subplots(ncols=3, nrows=2, figsize=(7, 5))
292
+ for _ax in ax.flatten(): _ax.set_axis_off()
293
+
294
+ if len(image.shape) == 4: image = image.squeeze(0)
295
+ if len(recon.shape) == 4: recon = recon.squeeze(0)
296
+
297
+ ax[0, 0].set_title('original image', color='cyan')
298
+ ax[0, 0].imshow(image[image.shape[0] // 2, :, :], cmap='gray')
299
+ ax[0, 1].imshow(image[:, image.shape[1] // 2, :], cmap='gray')
300
+ ax[0, 2].imshow(image[:, :, image.shape[2] // 2], cmap='gray')
301
+
302
+ ax[1, 0].set_title('reconstructed image', color='magenta')
303
+ ax[1, 0].imshow(recon[recon.shape[0] // 2, :, :], cmap='gray')
304
+ ax[1, 1].imshow(recon[:, recon.shape[1] // 2, :], cmap='gray')
305
+ ax[1, 2].imshow(recon[:, :, recon.shape[2] // 2], cmap='gray')
306
+
307
+ plt.tight_layout()
308
+ writer.add_figure('Reconstruction', plt.gcf(), global_step=step)
309
+
310
+
311
+ def set_environment(seed: int = 0) -> None:
312
+ """
313
+ Set deterministic behavior for reproducibility.
314
+
315
+ Args:
316
+ seed (int, optional): Seed value. Defaults to 0.
317
+ """
318
+ set_determinism(seed)
319
+
320
+
321
+ def train(
322
+ dataset_csv: str,
323
+ cache_dir: str,
324
+ output_dir: str,
325
+ aekl_ckpt: Optional[str] = None,
326
+ disc_ckpt: Optional[str] = None,
327
+ num_workers: int = 8,
328
+ n_epochs: int = 5,
329
+ max_batch_size: int = 2,
330
+ batch_size: int = 16,
331
+ lr: float = 1e-4,
332
+ aug_p: float = 0.8,
333
+ device: str = ('cuda' if torch.cuda.is_available() else
334
+ 'cpu'),
335
+ ) -> None:
336
+ """
337
+ Train the autoencoder and discriminator models.
338
+
339
+ Args:
340
+ dataset_csv (str): Path to the dataset CSV file.
341
+ cache_dir (str): Directory for caching data.
342
+ output_dir (str): Directory to save model checkpoints.
343
+ aekl_ckpt (Optional[str], optional): Path to the autoencoder checkpoint. Defaults to None.
344
+ disc_ckpt (Optional[str], optional): Path to the discriminator checkpoint. Defaults to None.
345
+ num_workers (int, optional): Number of data loader workers. Defaults to 8.
346
+ n_epochs (int, optional): Number of training epochs. Defaults to 5.
347
+ max_batch_size (int, optional): Actual batch size per iteration. Defaults to 2.
348
+ batch_size (int, optional): Expected (effective) batch size. Defaults to 16.
349
+ lr (float, optional): Learning rate. Defaults to 1e-4.
350
+ aug_p (float, optional): Augmentation probability. Defaults to 0.8.
351
+ device (str, optional): Device to run the training on. Defaults to 'cuda' if available.
352
+ """
353
+ set_environment(0)
354
+
355
+ transforms_fn = transforms.Compose([
356
+ transforms.CopyItemsD(keys={'image_path'}, names=['image']),
357
+ transforms.LoadImageD(image_only=True, keys=['image']),
358
+ transforms.EnsureChannelFirstD(keys=['image']),
359
+ transforms.SpacingD(pixdim=2, keys=['image']),
360
+ transforms.ResizeWithPadOrCropD(spatial_size=(80, 96, 80), mode='minimum', keys=['image']),
361
+ transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image'])
362
+ ])
363
+
364
+ dataset_df = pd.read_csv(dataset_csv)
365
+ train_df = dataset_df[dataset_df.split == 'train']
366
+ trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)
367
+
368
+ train_loader = DataLoader(
369
+ dataset=trainset,
370
+ num_workers=num_workers,
371
+ batch_size=max_batch_size,
372
+ shuffle=True,
373
+ persistent_workers=True,
374
+ pin_memory=True,
375
+ )
376
+
377
+ print('Device is %s' %(device))
378
+ autoencoder = init_autoencoder(aekl_ckpt).to(device)
379
+ discriminator = init_patch_discriminator(disc_ckpt).to(device)
380
+
381
+ # Loss Weights
382
+ adv_weight = 0.025
383
+ perceptual_weight = 0.001
384
+ kl_weight = 1e-7
385
+
386
+ # Loss Functions
387
+ l1_loss_fn = L1Loss()
388
+ kl_loss_fn = KLDivergenceLoss()
389
+ adv_loss_fn = PatchAdversarialLoss(criterion="least_squares")
390
+
391
+ with warnings.catch_warnings():
392
+ warnings.simplefilter("ignore")
393
+ perc_loss_fn = PerceptualLoss(
394
+ spatial_dims=3,
395
+ network_type="squeeze",
396
+ is_fake_3d=True,
397
+ fake_3d_ratio=0.2
398
+ ).to(device)
399
+
400
+ # Optimizers
401
+ optimizer_g = torch.optim.Adam(autoencoder.parameters(), lr=lr)
402
+ optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=lr)
403
+
404
+ # Gradient Accumulation
405
+ gradacc_g = GradientAccumulation(
406
+ actual_batch_size=max_batch_size,
407
+ expect_batch_size=batch_size,
408
+ loader_len=len(train_loader),
409
+ optimizer=optimizer_g,
410
+ grad_scaler=GradScaler()
411
+ )
412
+
413
+ gradacc_d = GradientAccumulation(
414
+ actual_batch_size=max_batch_size,
415
+ expect_batch_size=batch_size,
416
+ loader_len=len(train_loader),
417
+ optimizer=optimizer_d,
418
+ grad_scaler=GradScaler()
419
+ )
420
+
421
+ # Logging
422
+ avgloss = AverageLoss()
423
+ writer = SummaryWriter()
424
+ total_counter = 0
425
+
426
+ for epoch in range(n_epochs):
427
+ print(f"[DEBUG] Starting epoch {epoch}/{n_epochs-1}")
428
+ autoencoder.train()
429
+ progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
430
+ progress_bar.set_description(f'Epoch {epoch}')
431
+
432
+ for step, batch in progress_bar:
433
+ # Generator Training
434
+ with autocast(device, enabled=True):
435
+ images = batch["image"].to(device)
436
+ reconstruction, z_mu, z_sigma = autoencoder(images)
437
+
438
+ logits_fake = discriminator(reconstruction.contiguous().float())[-1]
439
+
440
+ rec_loss = l1_loss_fn(reconstruction.float(), images.float())
441
+ kl_loss = kl_weight * kl_loss_fn(z_mu, z_sigma)
442
+ per_loss = perceptual_weight * perc_loss_fn(reconstruction.float(), images.float())
443
+ gen_loss = adv_weight * adv_loss_fn(logits_fake, target_is_real=True, for_discriminator=False)
444
+
445
+ loss_g = rec_loss + kl_loss + per_loss + gen_loss
446
+
447
+ gradacc_g.step(loss_g, step)
448
+
449
+ # Discriminator Training
450
+ with autocast(device, enabled=True):
451
+ logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
452
+ d_loss_fake = adv_loss_fn(logits_fake, target_is_real=False, for_discriminator=True)
453
+ logits_real = discriminator(images.contiguous().detach())[-1]
454
+ d_loss_real = adv_loss_fn(logits_real, target_is_real=True, for_discriminator=True)
455
+ discriminator_loss = (d_loss_fake + d_loss_real) * 0.5
456
+ loss_d = adv_weight * discriminator_loss
457
+
458
+ gradacc_d.step(loss_d, step)
459
+
460
+ # Logging
461
+ avgloss.put('Generator/reconstruction_loss', rec_loss.item())
462
+ avgloss.put('Generator/perceptual_loss', per_loss.item())
463
+ avgloss.put('Generator/adversarial_loss', gen_loss.item())
464
+ avgloss.put('Generator/kl_regularization', kl_loss.item())
465
+ avgloss.put('Discriminator/adversarial_loss', loss_d.item())
466
+
467
+ if total_counter % 10 == 0:
468
+ step_log = total_counter // 10
469
+ avgloss.to_tensorboard(writer, step_log)
470
+ tb_display_reconstruction(
471
+ writer,
472
+ step_log,
473
+ images[0].detach().cpu(),
474
+ reconstruction[0].detach().cpu()
475
+ )
476
+
477
+ total_counter += 1
478
+
479
+ # Save the model after each epoch.
480
+ os.makedirs(output_dir, exist_ok=True)
481
+ torch.save(discriminator.state_dict(), os.path.join(output_dir, f'discriminator-ep-{epoch}.pth'))
482
+ torch.save(autoencoder.state_dict(), os.path.join(output_dir, f'autoencoder-ep-{epoch}.pth'))
483
+
484
+ writer.close()
485
+ print("Training completed and models saved.")
486
+
487
+
488
+ def main():
489
+ """
490
+ Main function to parse command-line arguments and run train().
491
+ """
492
+ import argparse
493
+
494
+ parser = argparse.ArgumentParser(description="brain2vec Training Script")
495
+
496
+ parser.add_argument('--dataset_csv', type=str, required=True, help='Path to the dataset CSV file.')
497
+ parser.add_argument('--cache_dir', type=str, required=True, help='Directory for caching data.')
498
+ parser.add_argument('--output_dir', type=str, required=True, help='Directory to save model checkpoints.')
499
+ parser.add_argument('--aekl_ckpt', type=str, default=None, help='Path to the autoencoder checkpoint.')
500
+ parser.add_argument('--disc_ckpt', type=str, default=None, help='Path to the discriminator checkpoint.')
501
+ parser.add_argument('--num_workers', type=int, default=8, help='Number of data loader workers.')
502
+ parser.add_argument('--n_epochs', type=int, default=5, help='Number of training epochs.')
503
+ parser.add_argument('--max_batch_size', type=int, default=2, help='Actual batch size per iteration.')
504
+ parser.add_argument('--batch_size', type=int, default=16, help='Expected (effective) batch size.')
505
+ parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
506
+ parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')
507
+
508
+ args = parser.parse_args()
509
+
510
+ train(
511
+ dataset_csv=args.dataset_csv,
512
+ cache_dir=args.cache_dir,
513
+ output_dir=args.output_dir,
514
+ aekl_ckpt=args.aekl_ckpt,
515
+ disc_ckpt=args.disc_ckpt,
516
+ num_workers=args.num_workers,
517
+ n_epochs=args.n_epochs,
518
+ max_batch_size=args.max_batch_size,
519
+ batch_size=args.batch_size,
520
+ lr=args.lr,
521
+ aug_p=args.aug_p,
522
+ )
523
+
524
+
525
+ if __name__ == '__main__':
526
+ main()