"""
pca_autoencoder.py

Adjustments requested:
1. Only fit on scans with a 'train' label in the inputs.csv 'split' column.
2. An option to run either incremental PCA or standard PCA.

Example usage:
python pca_autoencoder.py \
    --inputs_csv /path/to/inputs.csv \
    --output_dir ./pca_outputs \
    --pca_type standard \
    --n_components 100
"""

import os
import argparse
import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader

from monai import transforms
from monai.data import Dataset, PersistentDataset

from sklearn.decomposition import PCA, IncrementalPCA


RESOLUTION = 2
INPUT_SHAPE_AE = (80, 96, 80)
DEFAULT_N_COMPONENTS = 1200


def get_dataset_from_pd(df: pd.DataFrame, transforms_fn, cache_dir: str):
    """
    Returns a monai.data.Dataset, or a monai.data.PersistentDataset when
    `cache_dir` is provided, to speed up repeated loading.
    """
    if cache_dir and cache_dir.strip():
        os.makedirs(cache_dir, exist_ok=True)
        dataset = PersistentDataset(data=df.to_dict(orient='records'),
                                    transform=transforms_fn,
                                    cache_dir=cache_dir)
    else:
        dataset = Dataset(data=df.to_dict(orient='records'),
                          transform=transforms_fn)
    return dataset


class PCAAutoencoder:
    """
    A PCA 'autoencoder' that can use either standard PCA or IncrementalPCA:
      - fit(X): trains the model
      - transform(X): get embeddings
      - inverse_transform(Z): reconstruct data from embeddings
      - forward(X): returns (X_recon, Z)

    If using standard PCA, we make a single call to .fit(X).
    If using incremental PCA, we call .partial_fit on the data in batches.
    """

    def __init__(self, n_components=DEFAULT_N_COMPONENTS, batch_size=128, pca_type='incremental'):
        """
        Args:
            n_components (int): number of principal components to keep
            batch_size (int): chunk size for either partial_fit or chunked .transform
            pca_type (str): 'incremental' or 'standard'
        """
        self.n_components = n_components
        self.batch_size = batch_size
        self.pca_type = pca_type.lower()

        # The fitted estimator is stored as `self.ipca` for both variants.
        if self.pca_type == 'standard':
            self.ipca = PCA(n_components=self.n_components, svd_solver='randomized')
        else:
            self.ipca = IncrementalPCA(n_components=self.n_components)

    def fit(self, X: np.ndarray):
        """
        Fit the PCA model. If incremental, calls partial_fit in batches.
        If standard, calls .fit once on the entire data matrix.
        X: shape (n_samples, n_features)
        """
        if self.pca_type == 'standard':
            self.ipca.fit(X)
        else:
            # IncrementalPCA requires every batch passed to partial_fit to contain
            # at least n_components samples, so batch_size must be >= n_components
            # (and the final, possibly smaller, chunk must satisfy this too).
            n_samples = X.shape[0]
            for start_idx in range(0, n_samples, self.batch_size):
                end_idx = min(start_idx + self.batch_size, n_samples)
                self.ipca.partial_fit(X[start_idx:end_idx])

    def transform(self, X: np.ndarray) -> np.ndarray:
        """
        Project data into the PCA latent space in batches for memory efficiency.
        Returns Z with shape (n_samples, n_components).
        """
        results = []
        n_samples = X.shape[0]
        for start_idx in range(0, n_samples, self.batch_size):
            end_idx = min(start_idx + self.batch_size, n_samples)
            Z_chunk = self.ipca.transform(X[start_idx:end_idx])
            results.append(Z_chunk)
        return np.vstack(results)

    def inverse_transform(self, Z: np.ndarray) -> np.ndarray:
        """
        Reconstruct data from the PCA latent space in batches.
        Returns X_recon with shape (n_samples, n_features).
        """
        results = []
        n_samples = Z.shape[0]
        for start_idx in range(0, n_samples, self.batch_size):
            end_idx = min(start_idx + self.batch_size, n_samples)
            X_chunk = self.ipca.inverse_transform(Z[start_idx:end_idx])
            results.append(X_chunk)
        return np.vstack(results)

    def forward(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """
        Mimics a linear autoencoder's forward() by returning (X_recon, Z).
        """
        Z = self.transform(X)
        X_recon = self.inverse_transform(Z)
        return X_recon, Z
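

# Illustrative usage sketch (not called anywhere in this script): round-trips a
# small random matrix through the PCA 'autoencoder'. The function name and the
# array sizes below are made up for demonstration only.
def _demo_pca_autoencoder():
    rng = np.random.default_rng(0)
    X_demo = rng.normal(size=(32, 512)).astype(np.float32)  # 32 samples, 512 features
    demo_model = PCAAutoencoder(n_components=8, batch_size=16, pca_type='standard')
    demo_model.fit(X_demo)
    X_recon_demo, Z_demo = demo_model.forward(X_demo)
    print("demo embeddings:", Z_demo.shape, "demo reconstruction:", X_recon_demo.shape)

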
def load_and_flatten_dataset(csv_path: str, cache_dir: str, transforms_fn) -> np.ndarray:
    """
    1) Reads the CSV.
    2) If a 'split' column is present, keeps only the rows where 'split' == 'train'.
    3) Applies the transforms to each image and flattens it into a 1D vector
       (80 * 96 * 80 = 614,400 values).
    4) Returns a NumPy array X of shape (n_samples, 614400).
    """
    df = pd.read_csv(csv_path)

    if 'split' in df.columns:
        df = df[df['split'] == 'train']

    dataset = get_dataset_from_pd(df, transforms_fn, cache_dir)
    loader = DataLoader(dataset, batch_size=1, num_workers=0)

    # Load one scan at a time, convert it to NumPy, and flatten it into a row vector.
    X_list = []
    for batch in loader:
        img = batch["image"].squeeze(0)
        img_np = img.numpy()
        flattened = img_np.flatten()
        X_list.append(flattened)

    if len(X_list) == 0:
        raise ValueError("No training samples found (split='train'). Check your CSV or 'split' values.")

    X = np.vstack(X_list)
    return X


def main():
    parser = argparse.ArgumentParser(description="PCA Autoencoder with MONAI transforms and 'split' filtering.")
    parser.add_argument("--inputs_csv", type=str, required=True,
                        help="Path to a CSV with at least an 'image_path' column and an optional 'split' column.")
    parser.add_argument("--cache_dir", type=str, default="",
                        help="Cache directory for MONAI PersistentDataset (optional).")
    parser.add_argument("--output_dir", type=str, default="./pca_outputs",
                        help="Where to save the PCA model and embeddings.")
    parser.add_argument("--batch_size_ipca", type=int, default=128,
                        help="Batch size for partial_fit or chunked transform.")
    parser.add_argument("--n_components", type=int, default=DEFAULT_N_COMPONENTS,
                        help="Number of PCA components to keep.")
    parser.add_argument("--pca_type", type=str, default="incremental",
                        choices=["incremental", "standard"],
                        help="Which PCA algorithm to use: 'incremental' or 'standard'.")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    transforms_fn = transforms.Compose([
        transforms.CopyItemsD(keys={'image_path'}, names=['image']),
        transforms.LoadImageD(image_only=True, keys=['image']),
        transforms.EnsureChannelFirstD(keys=['image']),
        transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
        transforms.ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
        transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image']),
    ])

    print("Loading and flattening dataset from:", args.inputs_csv)
    X = load_and_flatten_dataset(args.inputs_csv, args.cache_dir, transforms_fn)
    print(f"Dataset shape after flattening: {X.shape}")

    model = PCAAutoencoder(
        n_components=args.n_components,
        batch_size=args.batch_size_ipca,
        pca_type=args.pca_type
    )

    print(f"Fitting PCA (pca_type='{args.pca_type}') on the training data...")
    model.fit(X)
    print("Done fitting PCA. Transforming data to embeddings...")

    X_recon, Z = model.forward(X)
    print("Embeddings shape:", Z.shape)
    print("Reconstruction shape:", X_recon.shape)

    embeddings_path = os.path.join(args.output_dir, "pca_embeddings.npy")
    recons_path = os.path.join(args.output_dir, "pca_reconstructions.npy")
    np.save(embeddings_path, Z)
    np.save(recons_path, X_recon)
    print(f"Saved embeddings to {embeddings_path}")
    print(f"Saved reconstructions to {recons_path}")


if __name__ == "__main__":
    main()