import os

import numpy as np
import torch
import torch.nn as nn

from monai.transforms import (
    Compose,
    CopyItemsD,
    LoadImageD,
    EnsureChannelFirstD,
    SpacingD,
    ResizeWithPadOrCropD,
    ScaleIntensityD,
)

from joblib import load

# Preprocessing constants: target voxel spacing for resampling and the fixed
# input shape expected by the PCA model (80 * 96 * 80 = 614,400 voxels flat).
RESOLUTION = 2
INPUT_SHAPE_AE = (80, 96, 80)
FLATTENED_DIM = INPUT_SHAPE_AE[0] * INPUT_SHAPE_AE[1] * INPUT_SHAPE_AE[2]

# MONAI preprocessing pipeline: copy the path into the 'image' key, load the
# volume, move channels first, resample to the target spacing, pad/crop to the
# model's input shape, and rescale intensities to [0, 1].
transforms_fn = Compose([
    CopyItemsD(keys=['image_path'], names=['image']),
    LoadImageD(image_only=True, keys=['image']),
    EnsureChannelFirstD(keys=['image']),
    SpacingD(pixdim=RESOLUTION, keys=['image']),
    ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
    ScaleIntensityD(minv=0, maxv=1, keys=['image']),
])


def preprocess_mri(image_path: str) -> torch.Tensor:
    """
    Preprocess an MRI using MONAI transforms to produce
    a 5D Torch tensor: (batch=1, channel=1, D, H, W).
    """
    data_dict = {"image_path": image_path}
    output_dict = transforms_fn(data_dict)

    # Add the batch dimension: (1, D, H, W) -> (1, 1, D, H, W).
    image_tensor = output_dict["image"].unsqueeze(0)
    return image_tensor.float()
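

# Illustrative usage sketch (the path below is a placeholder, not a real file):
#
#     mri_tensor = preprocess_mri("/path/to/scan.nii.gz")
#     print(mri_tensor.shape)  # torch.Size([1, 1, 80, 96, 80])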


class PCABrain2vec(nn.Module):
    """
    A PCA-based 'autoencoder' that mimics the old interface:
      - from_pretrained(...) to load a PCA model from disk
      - forward(...) returns (reconstruction, embedding, None)

    Under the hood, it:
      - takes a torch tensor of shape (N, 1, D, H, W)
      - flattens it to (N, 614400)
      - uses the PCA's transform(...) to get embeddings => shape (N, n_components)
      - uses inverse_transform(...) to get reconstructions => shape (N, 614400)
      - reshapes back to (N, 1, D, H, W)
    """

    def __init__(self, pca_model=None):
        super().__init__()
        self.pca_model = pca_model

    def forward(self, x: torch.Tensor):
        """
        Returns (reconstruction, embedding, None).

        1) Convert x => numpy array => flatten => (N, 614400)
        2) embedding = pca_model.transform(flat_x)
        3) reconstruction_np = pca_model.inverse_transform(embedding)
        4) reshape => (N, 1, 80, 96, 80)
        5) convert to torch => return (recon, embed, None)
        """
        n_samples = x.shape[0]

        # Flatten each volume into a single feature vector for the PCA model.
        x_cpu = x.detach().cpu().numpy()
        x_flat = x_cpu.reshape(n_samples, -1)

        # Project into the latent space, then reconstruct back to voxel space.
        embedding_np = self.pca_model.transform(x_flat)
        recon_np = self.pca_model.inverse_transform(embedding_np)
        recon_np = recon_np.reshape(n_samples, 1, *INPUT_SHAPE_AE)

        reconstruction_torch = torch.from_numpy(recon_np).float()
        embedding_torch = torch.from_numpy(embedding_np).float()
        return reconstruction_torch, embedding_torch, None

    @staticmethod
    def from_pretrained(pca_path: str):
        """
        Load a pre-trained PCA model (pickled or joblib).
        Returns an instance of PCABrain2vec with that model.
        """
        if not os.path.exists(pca_path):
            raise FileNotFoundError(f"Could not find PCA model at {pca_path}")

        pca_model = load(pca_path)
        return PCABrain2vec(pca_model=pca_model)
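

# Illustrative end-to-end sketch (paths are placeholders): load a fitted PCA
# model, preprocess a scan, and run it through the wrapper.
#
#     brain2vec = PCABrain2vec.from_pretrained("/path/to/pca_model.joblib")
#     mri_tensor = preprocess_mri("/path/to/scan.nii.gz")
#     recon, embedding, _ = brain2vec(mri_tensor)
#     print(embedding.shape)  # (1, n_components)
#     print(recon.shape)      # torch.Size([1, 1, 80, 96, 80])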