Initial commit
- README.md +103 -3
- brain2vec_PCA.py +194 -0
- create_csv.py +39 -0
- inputs_example.csv +6 -0
- model.py +121 -0
- requirements.txt +15 -0
README.md
CHANGED
@@ -1,3 +1,103 @@
---
license: mit
language:
- en
task_categories:
- image-classification
tags:
- medical
- brain-data
- mri
pretty_name: 3D Brain Structure MRI PCA
---

## 🧠 Model Summary
# brain2vec
A linear PCA model for brain structure T1 MRIs. The model takes a 3D MRI NIfTI file, compresses it to 1200 latent dimensions, and reconstructs the image.


# Training data
[Radiata brain-structure](https://huggingface.co/datasets/radiata-ai/brain-structure): 3,066 scans from 2,085 individuals in the 'train' split. Mean age = 45.1 ± 24.5 years, including 2,847 scans from cognitively normal subjects and 219 scans from individuals with an Alzheimer's disease clinical diagnosis.

# Example usage
```
# get the brain2vec model repository
git clone https://huggingface.co/radiata-ai/brain2vec
cd brain2vec

# set up a virtual environment
python3 -m venv venv_brain2vec
source venv_brain2vec/bin/activate

# install Python libraries
pip install -r requirements.txt

# create inputs.csv, listing the scan paths and other info
# (this script downloads the radiata-ai/brain-structure dataset)
python create_csv.py

mkdir ae_cache
mkdir ae_output

# fit the PCA model and run inference in one pass
# (saves pca_embeddings.npy and pca_reconstructions.npy to --output_dir)
nohup python brain2vec_PCA.py \
    --inputs_csv ./inputs.csv \
    --cache_dir ./ae_cache \
    --output_dir ./ae_output \
    > train_log.txt 2>&1 &
```
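
The run writes its outputs as NumPy arrays into `--output_dir`. A minimal sketch for loading them afterwards (file names as written by brain2vec_PCA.py):

```
import numpy as np

Z = np.load("ae_output/pca_embeddings.npy")             # (n_scans, 1200)
X_recon = np.load("ae_output/pca_reconstructions.npy")  # (n_scans, 614400)

# reshape one flattened reconstruction back to the volume grid
vol = X_recon[0].reshape(80, 96, 80)
print(Z.shape, vol.shape)
```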
# Methods
Input transform: volumes resampled to pixdim=2 and padded/cropped to shape (80, 96, 80).

Training configuration:
- 10 epochs
- max_batch_size: 2
- batch_size: 16
- lr: 1e-4

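The PCA "autoencoder" is fully linear, so encode and decode reduce to mean-centered matrix products. A toy NumPy sketch of the algebra (stand-in arrays, not the fitted model):

```
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 10))   # stand-in for flattened MRIs
mean = X.mean(axis=0)
Q, _ = np.linalg.qr(rng.standard_normal((10, 3)))
components = Q.T                   # stand-in for ipca.components_, shape (k, n_features)

Z = (X - mean) @ components.T      # encode: project onto the components
X_recon = Z @ components + mean    # decode: linear reconstruction
print(Z.shape, X_recon.shape)      # (4, 3) (4, 10)
```
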
# References
- Puglisi
- Pinaya

# Citation
```
@dataset{Radiata-Brain-Structure,
  author    = {Jesse Brown and Clayton Young},
  title     = {Brain-Structure: Processed Structural MRI Brain Scans Across the Lifespan},
  year      = {2025},
  url       = {https://huggingface.co/datasets/radiata-ai/brain-structure},
  note      = {Version 1.0},
  publisher = {Hugging Face}
}
```

# License
MIT License

Copyright (c) 2025

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
brain2vec_PCA.py
ADDED
@@ -0,0 +1,194 @@
#!/usr/bin/env python3

"""
brain2vec_PCA.py

This script demonstrates how to:
  1) Load a dataset of MRI volumes using MONAI transforms (as in brain2vec_linearAE.py).
  2) Flatten each 3D volume into a 1D vector (614,400 features for 80x96x80).
  3) Perform IncrementalPCA to reduce dimensionality to 1200 components.
  4) Provide a 'forward()' method that returns (reconstruction, embedding),
     mimicking the interface of a linear autoencoder.
"""

import os
import argparse
import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader

from monai import transforms
from monai.data import Dataset, PersistentDataset

from sklearn.decomposition import IncrementalPCA

###################################################################
# Constants for the typical config
###################################################################
RESOLUTION = 2
INPUT_SHAPE_AE = (80, 96, 80)
N_COMPONENTS = 1200

###################################################################
# Helper classes/functions
###################################################################
def get_dataset_from_pd(df: pd.DataFrame, transforms_fn, cache_dir: str):
    """
    Returns a monai.data.PersistentDataset if `cache_dir` is defined
    (to speed up repeated loading), otherwise a plain monai.data.Dataset.
    """
    if cache_dir and cache_dir.strip():
        os.makedirs(cache_dir, exist_ok=True)
        dataset = PersistentDataset(data=df.to_dict(orient='records'),
                                    transform=transforms_fn,
                                    cache_dir=cache_dir)
    else:
        dataset = Dataset(data=df.to_dict(orient='records'),
                          transform=transforms_fn)
    return dataset


class PCAAutoencoder:
    """
    A PCA 'autoencoder' using IncrementalPCA for memory efficiency, providing:
      - fit(X): partial fit on batches
      - transform(X): get embeddings
      - inverse_transform(Z): reconstruct from embeddings
      - forward(X): returns (X_recon, Z) for a direct API
        similar to a shallow linear AE.
    """
    def __init__(self, n_components=N_COMPONENTS, batch_size=128):
        self.n_components = n_components
        self.batch_size = batch_size
        self.ipca = IncrementalPCA(n_components=self.n_components)

    def fit(self, X: np.ndarray):
        """
        Incrementally fit the PCA model on batches of data.
        X: shape (n_samples, n_features).
        """
        n_samples = X.shape[0]
        for start_idx in range(0, n_samples, self.batch_size):
            end_idx = min(start_idx + self.batch_size, n_samples)
            self.ipca.partial_fit(X[start_idx:end_idx])

    def transform(self, X: np.ndarray) -> np.ndarray:
        """
        Projects data into the PCA latent space in batches.
        Returns Z: shape (n_samples, n_components).
        """
        results = []
        n_samples = X.shape[0]
        for start_idx in range(0, n_samples, self.batch_size):
            end_idx = min(start_idx + self.batch_size, n_samples)
            Z_chunk = self.ipca.transform(X[start_idx:end_idx])
            results.append(Z_chunk)
        return np.vstack(results)

    def inverse_transform(self, Z: np.ndarray) -> np.ndarray:
        """
        Reconstructs data from the PCA latent space in batches.
        Returns X_recon: shape (n_samples, n_features).
        """
        results = []
        n_samples = Z.shape[0]
        for start_idx in range(0, n_samples, self.batch_size):
            end_idx = min(start_idx + self.batch_size, n_samples)
            X_chunk = self.ipca.inverse_transform(Z[start_idx:end_idx])
            results.append(X_chunk)
        return np.vstack(results)

    def forward(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """
        Mimics a linear AE's forward(), returning (X_recon, Z).
        """
        Z = self.transform(X)
        X_recon = self.inverse_transform(Z)
        return X_recon, Z


def load_and_flatten_dataset(csv_path: str, cache_dir: str, transforms_fn) -> np.ndarray:
    """
    Loads the dataset from csv_path, applies the MONAI transforms,
    and flattens each 3D MRI into a 1D vector of length 80*96*80.
    Returns a numpy array X with shape (n_samples, 614400).
    """
    df = pd.read_csv(csv_path)
    dataset = get_dataset_from_pd(df, transforms_fn, cache_dir)

    # Collect the flattened volumes in a list, then stack.
    X_list = []

    # A single-worker DataLoader for clarity; if memory allows,
    # a multi-worker DataLoader can speed this up.
    loader = DataLoader(dataset, batch_size=1, num_workers=0)

    for batch in loader:
        # batch["image"] shape: (1, 1, 80, 96, 80)
        img = batch["image"].squeeze(0)  # shape: (1, 80, 96, 80)
        img_np = img.numpy()             # convert to np array, shape: (1, D, H, W)
        flattened = img_np.flatten()     # shape: (614400,)
        X_list.append(flattened)

    X = np.vstack(X_list)  # shape: (n_samples, 614400)
    return X


def main():
    parser = argparse.ArgumentParser(description="PCA Autoencoder with MONAI transforms example.")
    parser.add_argument("--inputs_csv", type=str, required=True, help="CSV with 'image_path' column.")
    parser.add_argument("--cache_dir", type=str, default="", help="Cache directory for MONAI PersistentDataset.")
    parser.add_argument("--output_dir", type=str, default="./pca_outputs", help="Where to save PCA model and embeddings.")
    parser.add_argument("--batch_size_ipca", type=int, default=128, help="Batch size for IncrementalPCA partial_fit().")
    parser.add_argument("--n_components", type=int, default=1200, help="Number of PCA components.")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Same transforms as in brain2vec_linearAE.py
    transforms_fn = transforms.Compose([
        transforms.CopyItemsD(keys={'image_path'}, names=['image']),
        transforms.LoadImageD(image_only=True, keys=['image']),
        transforms.EnsureChannelFirstD(keys=['image']),
        transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
        transforms.ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
        transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image']),
    ])

    print("Loading and flattening dataset from:", args.inputs_csv)
    X = load_and_flatten_dataset(args.inputs_csv, args.cache_dir, transforms_fn)
    print(f"Dataset shape after flattening: {X.shape}")

    # Build the PCAAutoencoder
    model = PCAAutoencoder(n_components=args.n_components, batch_size=args.batch_size_ipca)

    # Fit the PCA model
    print("Fitting IncrementalPCA in batches...")
    model.fit(X)
    print("Done fitting PCA. Transforming data to embeddings...")

    # Get embeddings & reconstructions
    X_recon, Z = model.forward(X)
    print("Embeddings shape:", Z.shape)
    print("Reconstruction shape:", X_recon.shape)

    # Save outputs
    embeddings_path = os.path.join(args.output_dir, "pca_embeddings.npy")
    recons_path = os.path.join(args.output_dir, "pca_reconstructions.npy")
    np.save(embeddings_path, Z)
    np.save(recons_path, X_recon)
    print(f"Saved embeddings to {embeddings_path} and reconstructions to {recons_path}")

    # To store the fitted PCA components for future use:
    # from joblib import dump
    # ipca_model_path = os.path.join(args.output_dir, "pca_model.joblib")
    # dump(model.ipca, ipca_model_path)
    # print(f"Saved PCA model to {ipca_model_path}")


if __name__ == "__main__":
    main()
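
A quick way to sanity-check the PCAAutoencoder interface on synthetic data (a sketch only; real inputs come from load_and_flatten_dataset, and partial_fit requires the batch size to be at least n_components):

```
import numpy as np
from brain2vec_PCA import PCAAutoencoder

X = np.random.rand(64, 1000).astype(np.float32)  # 64 fake flattened scans
model = PCAAutoencoder(n_components=16, batch_size=32)
model.fit(X)
X_recon, Z = model.forward(X)
print(Z.shape, X_recon.shape)                    # (64, 16) (64, 1000)
```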
create_csv.py
ADDED
@@ -0,0 +1,39 @@
#!/usr/bin/env python3
import os
import pandas as pd
from datasets import load_dataset

def row_to_dict(row, split_name):
    return {
        "image_uid": row["id"],
        "age": int(row["metadata"]["age"]),
        "sex": 1 if row["metadata"]["sex"].lower() == "male" else 2,
        "image_path": os.path.abspath(row["nii_filepath"]),
        "split": split_name
    }

def main():
    # Load the datasets
    ds_train = load_dataset("radiata-ai/brain-structure", split="train", trust_remote_code=True)
    ds_val = load_dataset("radiata-ai/brain-structure", split="validation", trust_remote_code=True)
    ds_test = load_dataset("radiata-ai/brain-structure", split="test", trust_remote_code=True)

    rows = []

    # Process each split
    for data_row in ds_train:
        rows.append(row_to_dict(data_row, "train"))
    for data_row in ds_val:
        rows.append(row_to_dict(data_row, "validation"))
    for data_row in ds_test:
        rows.append(row_to_dict(data_row, "test"))

    # Create a DataFrame and write it to CSV
    df = pd.DataFrame(rows)
    output_csv = "inputs.csv"
    df.to_csv(output_csv, index=False)
    print(f"CSV file created: {output_csv}")

if __name__ == "__main__":
    main()
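
The resulting inputs.csv has the columns image_uid, age, sex, image_path, and split (see inputs_example.csv below). A small sketch for selecting one split before fitting:

```
import pandas as pd

df = pd.read_csv("inputs.csv")
train_df = df[df["split"] == "train"]  # columns as written by create_csv.py
train_df.to_csv("inputs_train.csv", index=False)
print(len(train_df), "training scans")
```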
inputs_example.csv
ADDED
@@ -0,0 +1,6 @@
image_uid,age,sex,image_path,split
0,81,2,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20133/ses-03/anat/msub-OASIS20133_ses-03_T1w_brain_affine_mni.nii.gz,train
1,78,2,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20133/ses-01/anat/msub-OASIS20133_ses-01_T1w_brain_affine_mni.nii.gz,train
2,87,1,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20105/ses-02/anat/msub-OASIS20105_ses-02_T1w_brain_affine_mni.nii.gz,train
3,86,1,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20105/ses-01/anat/msub-OASIS20105_ses-01_T1w_brain_affine_mni.nii.gz,train
4,84,1,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20102/ses-02/anat/msub-OASIS20102_ses-02_T1w_brain_affine_mni.nii.gz,train
model.py
ADDED
@@ -0,0 +1,121 @@
# model.py
import os
from typing import Optional

import torch
import torch.nn as nn
from monai.transforms import (
    Compose,
    CopyItemsD,
    LoadImageD,
    EnsureChannelFirstD,
    SpacingD,
    ResizeWithPadOrCropD,
    ScaleIntensityD,
)

# Constants for the typical config
RESOLUTION = 2
INPUT_SHAPE_AE = (80, 96, 80)

# Define the exact transform pipeline for input MRI
transforms_fn = Compose([
    CopyItemsD(keys={'image_path'}, names=['image']),
    LoadImageD(image_only=True, keys=['image']),
    EnsureChannelFirstD(keys=['image']),
    SpacingD(pixdim=RESOLUTION, keys=['image']),
    ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
    ScaleIntensityD(minv=0, maxv=1, keys=['image']),
])

def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
    """
    Preprocess an MRI using MONAI transforms to produce
    a 5D tensor (batch=1, channels=1, D, H, W) for inference.
    """
    data_dict = {"image_path": image_path}
    output_dict = transforms_fn(data_dict)
    image_tensor = output_dict["image"]       # shape: (1, D, H, W)
    image_tensor = image_tensor.unsqueeze(0)  # => (batch=1, channel=1, D, H, W)
    return image_tensor.to(device)


class ShallowLinearAutoencoder(nn.Module):
    """
    A purely linear autoencoder with one hidden layer.
    - Flatten input into a vector
    - Linear encoder (no activation)
    - Linear decoder (no activation)
    - Reshape output to original volume shape
    """
    def __init__(self, input_shape=(80, 96, 80), hidden_size=1200):
        super().__init__()
        self.input_shape = input_shape
        self.input_dim = input_shape[0] * input_shape[1] * input_shape[2]
        self.hidden_size = hidden_size

        # Encoder (no activation for PCA-like behavior)
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.input_dim, self.hidden_size),
        )

        # Decoder (no activation)
        self.decoder = nn.Sequential(
            nn.Linear(self.hidden_size, self.input_dim),
        )

    def encode(self, x: torch.Tensor):
        return self.encoder(x)

    def decode(self, z: torch.Tensor):
        out = self.decoder(z)
        # Reshape to (N, 1, D, H, W)
        return out.view(-1, 1, *self.input_shape)

    def forward(self, x: torch.Tensor):
        """
        Return (reconstruction, embedding, None) to keep a similar API
        to the old VAE-based code, though there's no σ for sampling.
        """
        z = self.encode(x)
        reconstruction = self.decode(z)
        return reconstruction, z, None


class Brain2vec(nn.Module):
    """
    A wrapper around the ShallowLinearAutoencoder, providing a from_pretrained(...)
    method for model loading, mirroring the old usage with AutoencoderKL.
    """
    def __init__(self, device: str = "cpu"):
        super().__init__()
        # Instantiate the shallow linear model
        self.model = ShallowLinearAutoencoder(input_shape=INPUT_SHAPE_AE, hidden_size=1200)
        self.to(device)

    def forward(self, x: torch.Tensor):
        """
        Forward pass that returns (reconstruction, embedding, None).
        """
        return self.model(x)

    @staticmethod
    def from_pretrained(
        checkpoint_path: Optional[str] = None,
        device: str = "cpu"
    ) -> nn.Module:
        """
        Load a pretrained ShallowLinearAutoencoder if a checkpoint path is provided.
        Args:
            checkpoint_path (Optional[str]): path to a .pth checkpoint
            device (str): "cpu", "cuda", etc.
        """
        model = Brain2vec(device=device)
        if checkpoint_path is not None:
            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
            state_dict = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(state_dict)
        model.eval()
        return model
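
A short sketch of the intended inference path with these helpers (the scan path is a placeholder; passing checkpoint_path=None builds an untrained model):

```
import torch
from model import Brain2vec, preprocess_mri

# load weights if a checkpoint is available; None builds an untrained model
model = Brain2vec.from_pretrained(checkpoint_path=None, device="cpu")

x = preprocess_mri("/path/to/scan_T1w_brain_affine_mni.nii.gz")  # (1, 1, 80, 96, 80)
with torch.no_grad():
    reconstruction, z, _ = model(x)
print(z.shape)  # (1, 1200)
```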
requirements.txt
ADDED
@@ -0,0 +1,15 @@
# requirements.txt

# PyTorch (CUDA or CPU version). For GPU install, see PyTorch docs for the correct wheel.
torch>=1.12

# MONAI v1.2+ has the 'generative' subpackage with AutoencoderKL, PatchDiscriminator, etc.
monai-weekly
monai-generative

# scikit-learn provides IncrementalPCA (used by brain2vec_PCA.py)
scikit-learn

# Common Python libraries
pandas
numpy
nibabel
matplotlib
datasets