jesseab committed
Commit 653468a · 1 Parent(s): bb04d63

Code changes

Files changed (2):
  1. brain2vec_PCA.py   +108 -55
  2. requirements.txt   +2 -1
brain2vec_PCA.py CHANGED
@@ -3,12 +3,16 @@
 """
 pca_autoencoder.py
 
-This script demonstrates how to:
-  1) Load a dataset of MRI volumes using MONAI transforms (as in brain2vec_linearAE.py).
-  2) Flatten each 3D volume into a 1D vector (614,400 features if 80x96x80).
-  3) Perform IncrementalPCA to reduce dimensionality to 1200 components.
-  4) Provide a 'forward()' method that returns (reconstruction, embedding),
-     mimicking the interface of a linear autoencoder.
+Adjustments requested:
+  1. Only fit on scans with a 'train' label in the inputs.csv 'split' column.
+  2. An option to either run incremental PCA or standard PCA.
+
+Example usage:
+  python pca_autoencoder.py \
+      --inputs_csv /path/to/inputs.csv \
+      --output_dir ./pca_outputs \
+      --pca_type standard \
+      --n_components 100
 """
 
 import os
@@ -22,17 +26,20 @@ from torch.utils.data import DataLoader
 from monai import transforms
 from monai.data import Dataset, PersistentDataset
 
-from sklearn.decomposition import IncrementalPCA
+# We'll import both PCA classes, and decide which to use based on CLI arg.
+from sklearn.decomposition import PCA, IncrementalPCA
+
 
 ###################################################################
 # Constants for your typical config
 ###################################################################
 RESOLUTION = 2
 INPUT_SHAPE_AE = (80, 96, 80)
-N_COMPONENTS = 1200
+DEFAULT_N_COMPONENTS = 1200
+
 
 ###################################################################
-# Helper classes/functions
+# Helper: get_dataset_from_pd (same as in brain2vec_linearAE.py)
 ###################################################################
 def get_dataset_from_pd(df: pd.DataFrame, transforms_fn, cache_dir: str):
     """
@@ -50,35 +57,57 @@ def get_dataset_from_pd(df: pd.DataFrame, transforms_fn, cache_dir: str):
     return dataset
 
 
+###################################################################
+# PCAAutoencoder
+###################################################################
 class PCAAutoencoder:
     """
-    A PCA 'autoencoder' using IncrementalPCA for memory efficiency,
-    providing:
-      - fit(X): partial fit on batches
+    A PCA 'autoencoder' that can use either standard PCA or IncrementalPCA:
+      - fit(X): trains the model
       - transform(X): get embeddings
-      - inverse_transform(Z): reconstruct from embeddings
-      - forward(X): returns (X_recon, Z) for a direct API
-        similar to a shallow linear AE.
+      - inverse_transform(Z): reconstruct data from embeddings
+      - forward(X): returns (X_recon, Z)
+
+    If using standard PCA, we do a single call to .fit(X).
+    If using incremental PCA, we do .partial_fit on data in batches.
     """
-    def __init__(self, n_components=N_COMPONENTS, batch_size=128):
+    def __init__(self, n_components=DEFAULT_N_COMPONENTS, batch_size=128, pca_type='incremental'):
+        """
+        Args:
+            n_components (int): number of principal components to keep
+            batch_size (int): chunk size for either partial_fit or chunked .transform
+            pca_type (str): 'incremental' or 'standard'
+        """
         self.n_components = n_components
         self.batch_size = batch_size
-        self.ipca = IncrementalPCA(n_components=self.n_components)
+        self.pca_type = pca_type.lower()
+
+        if self.pca_type == 'standard':
+            self.ipca = PCA(n_components=self.n_components, svd_solver='randomized')
+        else:
+            # default to incremental
+            self.ipca = IncrementalPCA(n_components=self.n_components)
 
     def fit(self, X: np.ndarray):
         """
-        Incrementally fit the PCA model on batches of data.
-        X: shape (n_samples, n_features).
+        Fit the PCA model. If incremental, calls partial_fit in batches.
+        If standard, calls .fit once on the entire data matrix.
+        X: shape (n_samples, n_features)
         """
-        n_samples = X.shape[0]
-        for start_idx in range(0, n_samples, self.batch_size):
-            end_idx = min(start_idx + self.batch_size, n_samples)
-            self.ipca.partial_fit(X[start_idx:end_idx])
+        if self.pca_type == 'standard':
+            # Potentially large memory usage, so be sure your system can handle it.
+            self.ipca.fit(X)
+        else:
+            # IncrementalPCA
+            n_samples = X.shape[0]
+            for start_idx in range(0, n_samples, self.batch_size):
+                end_idx = min(start_idx + self.batch_size, n_samples)
+                self.ipca.partial_fit(X[start_idx:end_idx])
 
     def transform(self, X: np.ndarray) -> np.ndarray:
         """
-        Projects data into the PCA latent space in batches.
-        Returns Z: shape (n_samples, n_components).
+        Project data into the PCA latent space in batches for memory efficiency.
+        Returns Z with shape (n_samples, n_components)
         """
        results = []
        n_samples = X.shape[0]
@@ -91,7 +120,7 @@ class PCAAutoencoder:
     def inverse_transform(self, Z: np.ndarray) -> np.ndarray:
         """
         Reconstruct data from PCA latent space in batches.
-        Returns X_recon: shape (n_samples, n_features).
+        Returns X_recon with shape (n_samples, n_features).
         """
         results = []
         n_samples = Z.shape[0]
@@ -110,46 +139,65 @@ class PCAAutoencoder:
         return X_recon, Z
 
 
+###################################################################
+# Load and Flatten Data
+###################################################################
 def load_and_flatten_dataset(csv_path: str, cache_dir: str, transforms_fn) -> np.ndarray:
     """
-    Loads the dataset from csv_path, applies the monai transforms,
-    and flattens each 3D MRI into a 1D vector of shape (80*96*80).
-    Returns a numpy array X with shape (n_samples, 614400).
+    1) Reads CSV.
+    2) Filters rows if 'split' in columns => only keep 'split' == 'train'.
+    3) Applies transforms to each image, flattening them into a 1D vector (614,400).
+    4) Returns a NumPy array X: shape (n_samples, 614400).
     """
     df = pd.read_csv(csv_path)
-    dataset = get_dataset_from_pd(df, transforms_fn, cache_dir)
 
-    # We'll put the flattened data into this list, then stack.
-    X_list = []
+    # Filter only 'train' if the split column exists
+    if 'split' in df.columns:
+        df = df[df['split'] == 'train']
+    # If there is no 'split' column, we assume the entire CSV is for training.
 
-    # If memory allows, you can simply do a single-threaded loop
-    # or multi-worker DataLoader for speed.
-    # We'll demonstrate a simple single-worker here for clarity.
+    dataset = get_dataset_from_pd(df, transforms_fn, cache_dir)
     loader = DataLoader(dataset, batch_size=1, num_workers=0)
 
+    # We'll store each flattened volume in a list, then stack
+    X_list = []
     for batch in loader:
-        # batch["image"] shape: (1, 1, 80, 96, 80)
-        img = batch["image"].squeeze(0)  # shape: (1, 80, 96, 80)
-        img_np = img.numpy()  # convert to np array, shape: (1, D, H, W)
-        flattened = img_np.flatten()  # shape: (614400,)
+        # batch["image"] shape => (1, 1, 80, 96, 80)
+        img = batch["image"].squeeze(0)  # => (1, 80, 96, 80)
+        img_np = img.numpy()
+        flattened = img_np.flatten()     # => (614400,)
         X_list.append(flattened)
 
-    X = np.vstack(X_list)  # shape: (n_samples, 614400)
+    if len(X_list) == 0:
+        raise ValueError("No training samples found (split='train'). Check your CSV or 'split' values.")
+
+    X = np.vstack(X_list)
     return X
 
 
+###################################################################
+# Main
+###################################################################
 def main():
-    parser = argparse.ArgumentParser(description="PCA Autoencoder with MONAI transforms example.")
-    parser.add_argument("--inputs_csv", type=str, required=True, help="CSV with 'image_path' column.")
-    parser.add_argument("--cache_dir", type=str, default="", help="Cache directory for MONAI PersistentDataset.")
-    parser.add_argument("--output_dir", type=str, default="./pca_outputs", help="Where to save PCA model and embeddings.")
-    parser.add_argument("--batch_size_ipca", type=int, default=128, help="Batch size for IncrementalPCA partial_fit().")
-    parser.add_argument("--n_components", type=int, default=1200, help="Number of PCA components.")
+    parser = argparse.ArgumentParser(description="PCA Autoencoder with MONAI transforms and 'split' filtering.")
+    parser.add_argument("--inputs_csv", type=str, required=True,
+                        help="Path to CSV with at least 'image_path' column, optional 'split' column.")
+    parser.add_argument("--cache_dir", type=str, default="",
+                        help="Cache directory for MONAI PersistentDataset (optional).")
+    parser.add_argument("--output_dir", type=str, default="./pca_outputs",
+                        help="Where to save PCA model and embeddings.")
+    parser.add_argument("--batch_size_ipca", type=int, default=128,
+                        help="Batch size for partial_fit or chunked transform.")
+    parser.add_argument("--n_components", type=int, default=1200,
+                        help="Number of PCA components to keep.")
+    parser.add_argument("--pca_type", type=str, default="incremental",
+                        choices=["incremental", "standard"],
+                        help="Which PCA algorithm to use: 'incremental' or 'standard'.")
     args = parser.parse_args()
 
     os.makedirs(args.output_dir, exist_ok=True)
 
-    # Same transforms as in brain2vec_linearAE.py
+    # define transforms as in brain2vec_linearAE.py
     transforms_fn = transforms.Compose([
         transforms.CopyItemsD(keys={'image_path'}, names=['image']),
         transforms.LoadImageD(image_only=True, keys=['image']),
@@ -163,27 +211,32 @@ def main():
     X = load_and_flatten_dataset(args.inputs_csv, args.cache_dir, transforms_fn)
     print(f"Dataset shape after flattening: {X.shape}")
 
-    # Build PCAAutoencoder
-    model = PCAAutoencoder(n_components=args.n_components, batch_size=args.batch_size_ipca)
+    # Build the PCAAutoencoder with chosen type
+    model = PCAAutoencoder(
+        n_components=args.n_components,
+        batch_size=args.batch_size_ipca,
+        pca_type=args.pca_type
+    )
 
     # Fit the PCA model
-    print("Fitting IncrementalPCA in batches...")
+    print(f"Fitting {args.pca_type.capitalize()}PCA in batches...")
     model.fit(X)
     print("Done fitting PCA. Transforming data to embeddings...")
 
     # Get embeddings & reconstruction
     X_recon, Z = model.forward(X)
-    print("Embeddings shape:", Z.shape)
-    print("Reconstruction shape:", X_recon.shape)
+    print("Embeddings shape:", Z.shape)            # (n_samples, n_components)
+    print("Reconstruction shape:", X_recon.shape)  # (n_samples, 614400)
 
-    # Optional: Save
+    # Save
     embeddings_path = os.path.join(args.output_dir, "pca_embeddings.npy")
     recons_path = os.path.join(args.output_dir, "pca_reconstructions.npy")
     np.save(embeddings_path, Z)
     np.save(recons_path, X_recon)
-    print(f"Saved embeddings to {embeddings_path} and reconstructions to {recons_path}")
+    print(f"Saved embeddings to {embeddings_path}")
+    print(f"Saved reconstructions to {recons_path}")
 
-    # If you want to store the actual PCA components for future usage:
+    # Optionally save the actual PCA model with joblib
     # from joblib import dump
     # ipca_model_path = os.path.join(args.output_dir, "pca_model.joblib")
     # dump(model.ipca, ipca_model_path)
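
The tail of main() leaves the joblib dump of the fitted model commented out. If those lines are enabled, the model can be reloaded later to embed or reconstruct further scans; a minimal sketch (not part of the commit), assuming the default output paths above and using random data in place of a properly transformed, flattened scan:

    import numpy as np
    from joblib import load

    # Reload the fitted PCA/IncrementalPCA object and the stored embeddings
    # (paths assume the defaults written by main(), with the dump lines uncommented).
    ipca = load("./pca_outputs/pca_model.joblib")
    Z = np.load("./pca_outputs/pca_embeddings.npy")   # (n_samples, n_components)

    # Stand-in for one new scan, already transformed and flattened to 614,400 voxels.
    X_new = np.random.rand(1, 80 * 96 * 80).astype(np.float32)

    Z_new = ipca.transform(X_new)                     # (1, n_components)
    X_rec = ipca.inverse_transform(Z_new)             # (1, 614400)
    volume = X_rec[0].reshape(80, 96, 80)             # back to the volume grid
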
requirements.txt CHANGED
@@ -12,4 +12,5 @@ pandas
 numpy
 nibabel
 matplotlib
-datasets
+datasets
+scikit-learn
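
requirements.txt gains scikit-learn because the script now imports PCA and IncrementalPCA from sklearn.decomposition (the datasets entry is only re-added, apparently to restore the trailing newline). A quick import check after pip install -r requirements.txt:

    # Confirm the new dependency resolves.
    from sklearn.decomposition import PCA, IncrementalPCA
    print(PCA.__name__, IncrementalPCA.__name__)
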