|
|
|
|
|
""" |
|
pca_autoencoder.py |
|
|
|
This script demonstrates how to: |
|
1) Load a dataset of MRI volumes using MONAI transforms (as in brain2vec_linearAE.py). |
|
2) Flatten each 3D volume into a 1D vector (614,400 features if 80x96x80). |
|
3) Perform IncrementalPCA to reduce dimensionality to 1200 components. |
|
4) Provide a 'forward()' method that returns (reconstruction, embedding), |
|
mimicking the interface of a linear autoencoder. |
|
""" |
|
|
|
import os |
|
import argparse |
|
import numpy as np |
|
import pandas as pd |
|
|
|
import torch |
|
from torch.utils.data import DataLoader |
|
|
|
from monai import transforms |
|
from monai.data import Dataset, PersistentDataset |
|
|
|
from sklearn.decomposition import IncrementalPCA |
|
|
|
|
|
|
|
|
|
# Voxel spacing handed to SpacingD as `pixdim` in main()
# (presumably millimetres, isotropic — TODO confirm).
RESOLUTION = 2 |
|
# Spatial size that ResizeWithPadOrCropD pads/crops every volume to in main();
# 80 * 96 * 80 = 614,400 values per volume once flattened.
INPUT_SHAPE_AE = (80, 96, 80) |
|
# Default number of principal components used by PCAAutoencoder.
N_COMPONENTS = 1200 |
|
|
|
|
|
|
|
|
|
def get_dataset_from_pd(df: pd.DataFrame, transforms_fn, cache_dir: str):
    """
    Build a MONAI dataset whose items are the rows of `df`.

    Every row is turned into one record dict and `transforms_fn` is attached
    as the dataset transform. If `cache_dir` is a non-blank string, the
    directory is created when missing and a `PersistentDataset` backed by it
    is returned; otherwise a plain `Dataset` is returned.
    """
    # Guard clause: no usable cache directory -> plain dataset.
    if not (cache_dir and cache_dir.strip()):
        return Dataset(data=df.to_dict(orient='records'), transform=transforms_fn)

    os.makedirs(cache_dir, exist_ok=True)
    records = df.to_dict(orient='records')
    return PersistentDataset(
        data=records,
        transform=transforms_fn,
        cache_dir=cache_dir,
    )
|
|
|
|
|
class PCAAutoencoder:
    """
    A PCA 'autoencoder' built on IncrementalPCA for memory efficiency.

    Provides:
      - fit(X): partial fit on batches
      - transform(X): get embeddings
      - inverse_transform(Z): reconstruct from embeddings
      - forward(X): returns (X_recon, Z), an API similar to a shallow
        linear autoencoder.
    """

    def __init__(self, n_components=N_COMPONENTS, batch_size=128):
        """
        n_components: number of principal components given to IncrementalPCA.
        batch_size: requested number of rows per batch. transform() and
            inverse_transform() use it as given; fit() may use larger
            batches (see fit()).
        """
        self.n_components = n_components
        self.batch_size = batch_size
        self.ipca = IncrementalPCA(n_components=self.n_components)

    def _fit_bounds(self, n_samples, min_rows):
        """Yield (start, stop) row ranges for fit(), each at least `min_rows` long."""
        step = max(self.batch_size, min_rows)
        start = 0
        while start < n_samples:
            stop = start + step
            # If what would remain after this batch is too short to form a
            # valid batch on its own, absorb it into the current one.
            if n_samples - stop < min_rows:
                stop = n_samples
            yield start, stop
            start = stop

    def _apply_in_batches(self, func, data: np.ndarray) -> np.ndarray:
        """Apply `func` to consecutive `batch_size`-row slices of `data` and stack the results."""
        n_rows = data.shape[0]
        chunks = [
            func(data[start:start + self.batch_size])
            for start in range(0, n_rows, self.batch_size)
        ]
        return np.vstack(chunks)

    def fit(self, X: np.ndarray):
        """
        Incrementally fit the PCA model on batches of data.

        X: shape (n_samples, n_features).

        IncrementalPCA.partial_fit() rejects a batch that has fewer rows than
        `n_components`, so slicing X into plain `batch_size` chunks fails
        whenever batch_size < n_components (true for the defaults, 128 vs.
        1200). Batches are therefore widened to at least `n_components` rows
        and a short trailing remainder is merged into the last batch.

        Raises:
            ValueError: if X has fewer rows than `n_components`.
        """
        n_samples = X.shape[0]
        # n_components may be None (IncrementalPCA then picks it itself);
        # in that case any non-empty batch is acceptable.
        min_rows = self.n_components or 1
        if n_samples < min_rows:
            raise ValueError(
                f"Need at least {min_rows} samples to fit "
                f"n_components={self.n_components}, got {n_samples}."
            )
        for start, stop in self._fit_bounds(n_samples, min_rows):
            self.ipca.partial_fit(X[start:stop])

    def transform(self, X: np.ndarray) -> np.ndarray:
        """
        Projects data into the PCA latent space in batches.

        Returns Z: shape (n_samples, n_components).
        """
        return self._apply_in_batches(self.ipca.transform, X)

    def inverse_transform(self, Z: np.ndarray) -> np.ndarray:
        """
        Reconstruct data from PCA latent space in batches.

        Returns X_recon: shape (n_samples, n_features).
        """
        return self._apply_in_batches(self.ipca.inverse_transform, Z)

    def forward(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """
        Mimics a linear AE's forward() returning (X_recon, Z).
        """
        Z = self.transform(X)
        X_recon = self.inverse_transform(Z)
        return X_recon, Z
|
|
|
|
|
def load_and_flatten_dataset(csv_path: str, cache_dir: str, transforms_fn) -> np.ndarray:
    """
    Load the dataset listed in `csv_path`, apply the MONAI transforms and
    flatten every sample's "image" into a 1D vector.

    Returns a numpy array X with shape (n_samples, n_features); with the
    80x96x80 volumes produced by this script's transforms that is
    (n_samples, 614400).

    The output matrix is allocated once and filled row by row, so only one
    copy of the data is held (collecting rows in a list and stacking them
    afterwards would briefly need twice the memory).

    Raises:
        ValueError: if the CSV yields no samples, or if a sample flattens to
            a different length than the first one.
    """
    df = pd.read_csv(csv_path)
    dataset = get_dataset_from_pd(df, transforms_fn, cache_dir)

    n_samples = len(dataset)
    if n_samples == 0:
        raise ValueError(f"No samples found in {csv_path!r}.")

    # One sample per batch, loaded in the main process, in dataset order.
    loader = DataLoader(dataset, batch_size=1, num_workers=0)

    X = None
    for row, batch in enumerate(loader):
        # Drop the batch dimension, then flatten everything that is left.
        flattened = batch["image"].squeeze(0).numpy().reshape(-1)
        if X is None:
            # Size and dtype of the matrix are taken from the first sample.
            X = np.empty((n_samples, flattened.size), dtype=flattened.dtype)
        elif flattened.size != X.shape[1]:
            raise ValueError(
                f"Sample {row} has {flattened.size} features, "
                f"expected {X.shape[1]}."
            )
        X[row] = flattened

    return X
|
|
|
|
|
def main():
    """
    Command-line entry point.

    Parses the arguments, loads and flattens the MRI volumes listed in
    --inputs_csv, fits a PCAAutoencoder on them, and writes the embeddings
    and reconstructions as .npy files into --output_dir.
    """
    parser = argparse.ArgumentParser(description="PCA Autoencoder with MONAI transforms example.")
    parser.add_argument("--inputs_csv", type=str, required=True, help="CSV with 'image_path' column.")
    parser.add_argument("--cache_dir", type=str, default="", help="Cache directory for MONAI PersistentDataset.")
    # Only embeddings and reconstructions are written below; the help text says so.
    parser.add_argument("--output_dir", type=str, default="./pca_outputs", help="Where to save PCA embeddings and reconstructions.")
    parser.add_argument("--batch_size_ipca", type=int, default=128, help="Batch size for IncrementalPCA partial_fit().")
    # Use the module constant so the CLI default and the class default cannot drift apart.
    parser.add_argument("--n_components", type=int, default=N_COMPONENTS, help="Number of PCA components.")
    args = parser.parse_args()

    if args.n_components < 1:
        parser.error("--n_components must be a positive integer.")
    if args.batch_size_ipca < 1:
        parser.error("--batch_size_ipca must be a positive integer.")

    os.makedirs(args.output_dir, exist_ok=True)

    # Preprocessing: copy the path into 'image', load it, move channels first,
    # resample to RESOLUTION spacing, pad/crop to INPUT_SHAPE_AE, and rescale
    # intensities to [0, 1].
    transforms_fn = transforms.Compose([
        transforms.CopyItemsD(keys=['image_path'], names=['image']),
        transforms.LoadImageD(image_only=True, keys=['image']),
        transforms.EnsureChannelFirstD(keys=['image']),
        transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
        transforms.ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
        transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image']),
    ])

    print("Loading and flattening dataset from:", args.inputs_csv)
    X = load_and_flatten_dataset(args.inputs_csv, args.cache_dir, transforms_fn)
    print(f"Dataset shape after flattening: {X.shape}")

    # IncrementalPCA cannot extract more components than it has samples.
    if X.shape[0] < args.n_components:
        raise SystemExit(
            f"Dataset has {X.shape[0]} samples but --n_components={args.n_components}; "
            "provide more samples or lower --n_components."
        )

    # partial_fit() rejects batches with fewer rows than n_components, so a
    # smaller requested batch size is raised to n_components.
    batch_size = max(args.batch_size_ipca, args.n_components)
    if batch_size != args.batch_size_ipca:
        print(
            f"--batch_size_ipca={args.batch_size_ipca} is smaller than "
            f"--n_components={args.n_components}; using batch size {batch_size}."
        )

    model = PCAAutoencoder(n_components=args.n_components, batch_size=batch_size)

    print("Fitting IncrementalPCA in batches...")
    model.fit(X)
    print("Done fitting PCA. Transforming data to embeddings...")

    X_recon, Z = model.forward(X)
    print("Embeddings shape:", Z.shape)
    print("Reconstruction shape:", X_recon.shape)

    embeddings_path = os.path.join(args.output_dir, "pca_embeddings.npy")
    recons_path = os.path.join(args.output_dir, "pca_reconstructions.npy")
    np.save(embeddings_path, Z)
    np.save(recons_path, X_recon)
    print(f"Saved embeddings to {embeddings_path} and reconstructions to {recons_path}")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()