Code changes
Browse files- brain2vec_PCA.py +108 -55
- requirements.txt +2 -1
brain2vec_PCA.py
CHANGED
@@ -3,12 +3,16 @@
|
|
3 |
"""
|
4 |
pca_autoencoder.py
|
5 |
|
6 |
-
|
7 |
-
1
|
8 |
-
2
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
12 |
"""
|
13 |
|
14 |
import os
|
@@ -22,17 +26,20 @@ from torch.utils.data import DataLoader
|
|
22 |
from monai import transforms
|
23 |
from monai.data import Dataset, PersistentDataset
|
24 |
|
25 |
-
|
|
|
|
|
26 |
|
27 |
###################################################################
|
28 |
# Constants for your typical config
|
29 |
###################################################################
|
30 |
RESOLUTION = 2
|
31 |
INPUT_SHAPE_AE = (80, 96, 80)
|
32 |
-
|
|
|
33 |
|
34 |
###################################################################
|
35 |
-
# Helper
|
36 |
###################################################################
|
37 |
def get_dataset_from_pd(df: pd.DataFrame, transforms_fn, cache_dir: str):
|
38 |
"""
|
@@ -50,35 +57,57 @@ def get_dataset_from_pd(df: pd.DataFrame, transforms_fn, cache_dir: str):
|
|
50 |
return dataset
|
51 |
|
52 |
|
|
|
|
|
|
|
53 |
class PCAAutoencoder:
|
54 |
"""
|
55 |
-
A PCA 'autoencoder'
|
56 |
-
|
57 |
-
- fit(X): partial fit on batches
|
58 |
- transform(X): get embeddings
|
59 |
-
- inverse_transform(Z): reconstruct from embeddings
|
60 |
-
- forward(X): returns (X_recon, Z)
|
61 |
-
|
|
|
|
|
62 |
"""
|
63 |
-
def __init__(self, n_components=
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
self.n_components = n_components
|
65 |
self.batch_size = batch_size
|
66 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
def fit(self, X: np.ndarray):
|
69 |
"""
|
70 |
-
|
71 |
-
|
|
|
72 |
"""
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
def transform(self, X: np.ndarray) -> np.ndarray:
|
79 |
"""
|
80 |
-
|
81 |
-
Returns Z
|
82 |
"""
|
83 |
results = []
|
84 |
n_samples = X.shape[0]
|
@@ -91,7 +120,7 @@ class PCAAutoencoder:
|
|
91 |
def inverse_transform(self, Z: np.ndarray) -> np.ndarray:
|
92 |
"""
|
93 |
Reconstruct data from PCA latent space in batches.
|
94 |
-
Returns X_recon
|
95 |
"""
|
96 |
results = []
|
97 |
n_samples = Z.shape[0]
|
@@ -110,46 +139,65 @@ class PCAAutoencoder:
|
|
110 |
return X_recon, Z
|
111 |
|
112 |
|
|
|
|
|
|
|
113 |
def load_and_flatten_dataset(csv_path: str, cache_dir: str, transforms_fn) -> np.ndarray:
|
114 |
"""
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
118 |
"""
|
119 |
df = pd.read_csv(csv_path)
|
120 |
-
dataset = get_dataset_from_pd(df, transforms_fn, cache_dir)
|
121 |
|
122 |
-
#
|
123 |
-
|
|
|
|
|
124 |
|
125 |
-
|
126 |
-
# or multi-worker DataLoader for speed.
|
127 |
-
# We'll demonstrate a simple single-worker here for clarity.
|
128 |
loader = DataLoader(dataset, batch_size=1, num_workers=0)
|
129 |
|
|
|
|
|
130 |
for batch in loader:
|
131 |
-
# batch["image"] shape
|
132 |
-
img = batch["image"].squeeze(0) #
|
133 |
-
img_np = img.numpy()
|
134 |
-
flattened = img_np.flatten()
|
135 |
X_list.append(flattened)
|
136 |
|
137 |
-
|
|
|
|
|
|
|
138 |
return X
|
139 |
|
140 |
|
|
|
|
|
|
|
141 |
def main():
|
142 |
-
parser = argparse.ArgumentParser(description="PCA Autoencoder with MONAI transforms
|
143 |
-
parser.add_argument("--inputs_csv", type=str, required=True,
|
144 |
-
|
145 |
-
parser.add_argument("--
|
146 |
-
|
147 |
-
parser.add_argument("--
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
args = parser.parse_args()
|
149 |
|
150 |
os.makedirs(args.output_dir, exist_ok=True)
|
151 |
|
152 |
-
#
|
153 |
transforms_fn = transforms.Compose([
|
154 |
transforms.CopyItemsD(keys={'image_path'}, names=['image']),
|
155 |
transforms.LoadImageD(image_only=True, keys=['image']),
|
@@ -163,27 +211,32 @@ def main():
|
|
163 |
X = load_and_flatten_dataset(args.inputs_csv, args.cache_dir, transforms_fn)
|
164 |
print(f"Dataset shape after flattening: {X.shape}")
|
165 |
|
166 |
-
# Build PCAAutoencoder
|
167 |
-
model = PCAAutoencoder(
|
|
|
|
|
|
|
|
|
168 |
|
169 |
# Fit the PCA model
|
170 |
-
print("Fitting
|
171 |
model.fit(X)
|
172 |
print("Done fitting PCA. Transforming data to embeddings...")
|
173 |
|
174 |
# Get embeddings & reconstruction
|
175 |
X_recon, Z = model.forward(X)
|
176 |
-
print("Embeddings shape:", Z.shape)
|
177 |
-
print("Reconstruction shape:", X_recon.shape)
|
178 |
|
179 |
-
#
|
180 |
embeddings_path = os.path.join(args.output_dir, "pca_embeddings.npy")
|
181 |
recons_path = os.path.join(args.output_dir, "pca_reconstructions.npy")
|
182 |
np.save(embeddings_path, Z)
|
183 |
np.save(recons_path, X_recon)
|
184 |
-
print(f"Saved embeddings to {embeddings_path}
|
|
|
185 |
|
186 |
-
#
|
187 |
# from joblib import dump
|
188 |
# ipca_model_path = os.path.join(args.output_dir, "pca_model.joblib")
|
189 |
# dump(model.ipca, ipca_model_path)
|
|
|
3 |
"""
|
4 |
pca_autoencoder.py
|
5 |
|
6 |
+
Adjustments requested:
|
7 |
+
1. Only fit on scans with a 'train' label in the inputs.csv 'split' column.
|
8 |
+
2. An option to either run incremental PCA or standard PCA.
|
9 |
+
|
10 |
+
Example usage:
|
11 |
+
python pca_autoencoder.py \
|
12 |
+
--inputs_csv /path/to/inputs.csv \
|
13 |
+
--output_dir ./pca_outputs \
|
14 |
+
--pca_type standard \
|
15 |
+
--n_components 100
|
16 |
"""
|
17 |
|
18 |
import os
|
|
|
26 |
from monai import transforms
|
27 |
from monai.data import Dataset, PersistentDataset
|
28 |
|
29 |
+
# We'll import both PCA classes, and decide which to use based on CLI arg.
|
30 |
+
from sklearn.decomposition import PCA, IncrementalPCA
|
31 |
+
|
32 |
|
33 |
###################################################################
|
34 |
# Constants for your typical config
|
35 |
###################################################################
|
36 |
RESOLUTION = 2
|
37 |
INPUT_SHAPE_AE = (80, 96, 80)
|
38 |
+
DEFAULT_N_COMPONENTS = 1200
|
39 |
+
|
40 |
|
41 |
###################################################################
|
42 |
+
# Helper: get_dataset_from_pd (same as in brain2vec_linearAE.py)
|
43 |
###################################################################
|
44 |
def get_dataset_from_pd(df: pd.DataFrame, transforms_fn, cache_dir: str):
|
45 |
"""
|
|
|
57 |
return dataset
|
58 |
|
59 |
|
60 |
+
###################################################################
|
61 |
+
# PCAAutoencoder
|
62 |
+
###################################################################
|
63 |
class PCAAutoencoder:
|
64 |
"""
|
65 |
+
A PCA 'autoencoder' that can use either standard PCA or IncrementalPCA:
|
66 |
+
- fit(X): trains the model
|
|
|
67 |
- transform(X): get embeddings
|
68 |
+
- inverse_transform(Z): reconstruct data from embeddings
|
69 |
+
- forward(X): returns (X_recon, Z)
|
70 |
+
|
71 |
+
If using standard PCA, we do a single call to .fit(X).
|
72 |
+
If using incremental PCA, we do .partial_fit on data in batches.
|
73 |
"""
|
74 |
+
def __init__(self, n_components=DEFAULT_N_COMPONENTS, batch_size=128, pca_type='incremental'):
|
75 |
+
"""
|
76 |
+
Args:
|
77 |
+
n_components (int): number of principal components to keep
|
78 |
+
batch_size (int): chunk size for either partial_fit or chunked .transform
|
79 |
+
pca_type (str): 'incremental' or 'standard'
|
80 |
+
"""
|
81 |
self.n_components = n_components
|
82 |
self.batch_size = batch_size
|
83 |
+
self.pca_type = pca_type.lower()
|
84 |
+
|
85 |
+
if self.pca_type == 'standard':
|
86 |
+
self.ipca = PCA(n_components=self.n_components, svd_solver='randomized')
|
87 |
+
else:
|
88 |
+
# default to incremental
|
89 |
+
self.ipca = IncrementalPCA(n_components=self.n_components)
|
90 |
|
91 |
def fit(self, X: np.ndarray):
|
92 |
"""
|
93 |
+
Fit the PCA model. If incremental, calls partial_fit in batches.
|
94 |
+
If standard, calls .fit once on the entire data matrix.
|
95 |
+
X: shape (n_samples, n_features)
|
96 |
"""
|
97 |
+
if self.pca_type == 'standard':
|
98 |
+
# Potentially large memory usage, so be sure your system can handle it.
|
99 |
+
self.ipca.fit(X)
|
100 |
+
else:
|
101 |
+
# IncrementalPCA
|
102 |
+
n_samples = X.shape[0]
|
103 |
+
for start_idx in range(0, n_samples, self.batch_size):
|
104 |
+
end_idx = min(start_idx + self.batch_size, n_samples)
|
105 |
+
self.ipca.partial_fit(X[start_idx:end_idx])
|
106 |
|
107 |
def transform(self, X: np.ndarray) -> np.ndarray:
|
108 |
"""
|
109 |
+
Project data into the PCA latent space in batches for memory efficiency.
|
110 |
+
Returns Z with shape (n_samples, n_components)
|
111 |
"""
|
112 |
results = []
|
113 |
n_samples = X.shape[0]
|
|
|
120 |
def inverse_transform(self, Z: np.ndarray) -> np.ndarray:
|
121 |
"""
|
122 |
Reconstruct data from PCA latent space in batches.
|
123 |
+
Returns X_recon with shape (n_samples, n_features).
|
124 |
"""
|
125 |
results = []
|
126 |
n_samples = Z.shape[0]
|
|
|
139 |
return X_recon, Z
|
140 |
|
141 |
|
142 |
+
###################################################################
|
143 |
+
# Load and Flatten Data
|
144 |
+
###################################################################
|
145 |
def load_and_flatten_dataset(csv_path: str, cache_dir: str, transforms_fn) -> np.ndarray:
|
146 |
"""
|
147 |
+
1) Reads CSV.
|
148 |
+
2) Filters rows if 'split' in columns => only keep 'split' == 'train'.
|
149 |
+
3) Applies transforms to each image, flattening them into a 1D vector (614,400).
|
150 |
+
4) Returns a NumPy array X: shape (n_samples, 614400).
|
151 |
"""
|
152 |
df = pd.read_csv(csv_path)
|
|
|
153 |
|
154 |
+
# Filter only 'train' if the split column exists
|
155 |
+
if 'split' in df.columns:
|
156 |
+
df = df[df['split'] == 'train']
|
157 |
+
# If there is no 'split' column, we assume the entire CSV is for training.
|
158 |
|
159 |
+
dataset = get_dataset_from_pd(df, transforms_fn, cache_dir)
|
|
|
|
|
160 |
loader = DataLoader(dataset, batch_size=1, num_workers=0)
|
161 |
|
162 |
+
# We'll store each flattened volume in a list, then stack
|
163 |
+
X_list = []
|
164 |
for batch in loader:
|
165 |
+
# batch["image"] shape => (1, 1, 80, 96, 80)
|
166 |
+
img = batch["image"].squeeze(0) # => (1, 80, 96, 80)
|
167 |
+
img_np = img.numpy()
|
168 |
+
flattened = img_np.flatten() # => (614400,)
|
169 |
X_list.append(flattened)
|
170 |
|
171 |
+
if len(X_list) == 0:
|
172 |
+
raise ValueError("No training samples found (split='train'). Check your CSV or 'split' values.")
|
173 |
+
|
174 |
+
X = np.vstack(X_list)
|
175 |
return X
|
176 |
|
177 |
|
178 |
+
###################################################################
|
179 |
+
# Main
|
180 |
+
###################################################################
|
181 |
def main():
|
182 |
+
parser = argparse.ArgumentParser(description="PCA Autoencoder with MONAI transforms and 'split' filtering.")
|
183 |
+
parser.add_argument("--inputs_csv", type=str, required=True,
|
184 |
+
help="Path to CSV with at least 'image_path' column, optional 'split' column.")
|
185 |
+
parser.add_argument("--cache_dir", type=str, default="",
|
186 |
+
help="Cache directory for MONAI PersistentDataset (optional).")
|
187 |
+
parser.add_argument("--output_dir", type=str, default="./pca_outputs",
|
188 |
+
help="Where to save PCA model and embeddings.")
|
189 |
+
parser.add_argument("--batch_size_ipca", type=int, default=128,
|
190 |
+
help="Batch size for partial_fit or chunked transform.")
|
191 |
+
parser.add_argument("--n_components", type=int, default=1200,
|
192 |
+
help="Number of PCA components to keep.")
|
193 |
+
parser.add_argument("--pca_type", type=str, default="incremental",
|
194 |
+
choices=["incremental", "standard"],
|
195 |
+
help="Which PCA algorithm to use: 'incremental' or 'standard'.")
|
196 |
args = parser.parse_args()
|
197 |
|
198 |
os.makedirs(args.output_dir, exist_ok=True)
|
199 |
|
200 |
+
# define transforms as in brain2vec_linearAE.py
|
201 |
transforms_fn = transforms.Compose([
|
202 |
transforms.CopyItemsD(keys={'image_path'}, names=['image']),
|
203 |
transforms.LoadImageD(image_only=True, keys=['image']),
|
|
|
211 |
X = load_and_flatten_dataset(args.inputs_csv, args.cache_dir, transforms_fn)
|
212 |
print(f"Dataset shape after flattening: {X.shape}")
|
213 |
|
214 |
+
# Build the PCAAutoencoder with chosen type
|
215 |
+
model = PCAAutoencoder(
|
216 |
+
n_components=args.n_components,
|
217 |
+
batch_size=args.batch_size_ipca,
|
218 |
+
pca_type=args.pca_type
|
219 |
+
)
|
220 |
|
221 |
# Fit the PCA model
|
222 |
+
print(f"Fitting {args.pca_type.capitalize()}PCA in batches...")
|
223 |
model.fit(X)
|
224 |
print("Done fitting PCA. Transforming data to embeddings...")
|
225 |
|
226 |
# Get embeddings & reconstruction
|
227 |
X_recon, Z = model.forward(X)
|
228 |
+
print("Embeddings shape:", Z.shape) # (n_samples, n_components)
|
229 |
+
print("Reconstruction shape:", X_recon.shape) # (n_samples, 614400)
|
230 |
|
231 |
+
# Save
|
232 |
embeddings_path = os.path.join(args.output_dir, "pca_embeddings.npy")
|
233 |
recons_path = os.path.join(args.output_dir, "pca_reconstructions.npy")
|
234 |
np.save(embeddings_path, Z)
|
235 |
np.save(recons_path, X_recon)
|
236 |
+
print(f"Saved embeddings to {embeddings_path}")
|
237 |
+
print(f"Saved reconstructions to {recons_path}")
|
238 |
|
239 |
+
# Optionally save the actual PCA model with joblib
|
240 |
# from joblib import dump
|
241 |
# ipca_model_path = os.path.join(args.output_dir, "pca_model.joblib")
|
242 |
# dump(model.ipca, ipca_model_path)
|
requirements.txt
CHANGED
@@ -12,4 +12,5 @@ pandas
|
|
12 |
numpy
|
13 |
nibabel
|
14 |
matplotlib
|
15 |
-
datasets
|
|
|
|
12 |
numpy
|
13 |
nibabel
|
14 |
matplotlib
|
15 |
+
datasets
|
16 |
+
scikit-learn
|