brain2vec_PCA / model.py
jesseab's picture
Updated model.py and added .joblib
34517cd
raw
history blame
4.18 kB
# model.py
import os
import numpy as np
import torch
import torch.nn as nn
from monai.transforms import (
Compose,
CopyItemsD,
LoadImageD,
EnsureChannelFirstD,
SpacingD,
ResizeWithPadOrCropD,
ScaleIntensityD,
)
# If you used joblib or pickle to save your PCA model:
from joblib import load # or "import pickle"
#################################################
# Constants
#################################################
# Value handed to SpacingD as `pixdim` when resampling the input volume.
RESOLUTION = 2
# Spatial size every volume is padded/cropped to before being flattened.
INPUT_SHAPE_AE = (80, 96, 80)
# Number of voxels in one volume once it is flattened into a feature vector.
_n0, _n1, _n2 = INPUT_SHAPE_AE
FLATTENED_DIM = _n0 * _n1 * _n2
#################################################
# Define MONAI Transforms for Preprocessing
#################################################
# Dictionary-based MONAI pipeline applied, in order, to {"image_path": <path>}.
# The result carries the preprocessed volume under the "image" key.
transforms_fn = Compose(
    [
        # Duplicate the path entry under "image" so the loader can overwrite it.
        CopyItemsD(keys={"image_path"}, names=["image"]),
        # Read the file referenced by "image"; keep only the image data.
        LoadImageD(keys=["image"], image_only=True),
        # Put the channel axis first.
        EnsureChannelFirstD(keys=["image"]),
        # Resample to the spacing given by RESOLUTION.
        SpacingD(keys=["image"], pixdim=RESOLUTION),
        # Pad or crop to INPUT_SHAPE_AE; padding uses the "minimum" mode.
        ResizeWithPadOrCropD(
            keys=["image"],
            spatial_size=INPUT_SHAPE_AE,
            mode="minimum",
        ),
        # Rescale intensities into the range [0, 1].
        ScaleIntensityD(keys=["image"], minv=0, maxv=1),
    ]
)
def preprocess_mri(image_path: str) -> torch.Tensor:
    """
    Run the module-level MONAI pipeline on a single image file.

    Args:
        image_path: Path of the image to load; passed to ``transforms_fn``
            under the ``"image_path"`` key.

    Returns:
        The transformed ``"image"`` entry with a leading batch axis added
        and cast to float, i.e. (batch=1, channel=1, D, H, W) when the
        pipeline yields a (1, D, H, W) volume.
    """
    sample = transforms_fn({"image_path": image_path})
    volume = sample["image"]  # expected (1, D, H, W) after the pipeline
    batched = volume.unsqueeze(0)  # prepend the batch axis
    return batched.float()
#################################################
# PCA "Autoencoder" Wrapper
#################################################
class PCABrain2vec(nn.Module):
    """
    A PCA-based 'autoencoder' that mimics the old interface:
      - from_pretrained(...) loads a fitted PCA model from disk
      - forward(...) returns (reconstruction, embedding, None)

    Under the hood, forward:
      - takes a torch tensor, normally shaped (N, 1, D, H, W)
      - flattens it to (N, D*H*W)
      - calls the PCA model's transform(...) to get the embeddings
      - calls inverse_transform(...) to get the flat reconstructions
      - reshapes the reconstructions back to a 5D volume batch
    """

    def __init__(self, pca_model=None):
        """
        Args:
            pca_model: Fitted object exposing ``transform`` and
                ``inverse_transform`` (e.g. scikit-learn PCA or
                IncrementalPCA). May be None, but forward() then raises.
        """
        super().__init__()
        self.pca_model = pca_model

    def forward(self, x: torch.Tensor):
        """
        Encode and reconstruct a batch of volumes with the stored PCA model.

        Args:
            x: Input batch; the first axis is the sample axis and all
                remaining axes are flattened into the feature vector.

        Returns:
            Tuple ``(reconstruction, embedding, None)``. Both tensors are
            float32 and live on the same device as ``x``. The reconstruction
            has the shape of ``x`` when ``x`` is (N, 1, D, H, W); otherwise
            it is reshaped to (N, 1, *INPUT_SHAPE_AE).

        Raises:
            RuntimeError: If no PCA model has been attached.
        """
        if self.pca_model is None:
            raise RuntimeError(
                "PCABrain2vec has no PCA model attached; use "
                "PCABrain2vec.from_pretrained(...) or pass pca_model=..."
            )

        n_samples = x.shape[0]

        # The PCA model works on CPU NumPy arrays: one row per sample.
        x_flat = x.detach().cpu().numpy().reshape(n_samples, -1)

        embedding_np = self.pca_model.transform(x_flat)
        recon_np = self.pca_model.inverse_transform(embedding_np)

        # Mirror the input shape for the documented single-channel 5D layout
        # so the module is not tied to one volume size; otherwise keep the
        # legacy fixed output shape.
        if x.dim() == 5 and x.shape[1] == 1:
            out_shape = tuple(x.shape)
        else:
            out_shape = (n_samples, 1, *INPUT_SHAPE_AE)
        recon_np = recon_np.reshape(out_shape)

        # Return results on the caller's device rather than always on CPU.
        reconstruction_torch = torch.from_numpy(recon_np).float().to(x.device)
        embedding_torch = torch.from_numpy(embedding_np).float().to(x.device)
        return reconstruction_torch, embedding_torch, None

    @classmethod
    def from_pretrained(cls, pca_path: str):
        """
        Load a pre-trained PCA model saved with joblib and wrap it.

        Args:
            pca_path: Filesystem path of the serialized PCA model.

        Returns:
            An instance of the class this is called on, holding the model.

        Raises:
            FileNotFoundError: If ``pca_path`` does not exist.
        """
        if not os.path.exists(pca_path):
            raise FileNotFoundError(f"Could not find PCA model at {pca_path}")

        # SECURITY: joblib.load unpickles the file and can execute arbitrary
        # code; only load model files from trusted sources.
        pca_model = load(pca_path)
        return cls(pca_model=pca_model)