# model.py
import os
from typing import Optional
import torch
import torch.nn as nn
from monai.transforms import (
    Compose,
    CopyItemsD,
    LoadImageD,
    EnsureChannelFirstD,
    SpacingD,
    ResizeWithPadOrCropD,
    ScaleIntensityD,
)

# Constants for the expected input configuration
RESOLUTION = 2                 # target voxel spacing passed to SpacingD
INPUT_SHAPE_AE = (80, 96, 80)  # fixed (D, H, W) volume shape fed to the autoencoder

# Transform pipeline applied to each input MRI
transforms_fn = Compose([
    CopyItemsD(keys=['image_path'], names=['image']),  # copy the path under the 'image' key
    LoadImageD(image_only=True, keys=['image']),       # load the volume from disk
    EnsureChannelFirstD(keys=['image']),               # -> (1, D, H, W)
    SpacingD(pixdim=RESOLUTION, keys=['image']),       # resample to RESOLUTION spacing
    ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),  # pad/crop to INPUT_SHAPE_AE
    ScaleIntensityD(minv=0, maxv=1, keys=['image']),   # rescale intensities to [0, 1]
])


def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
    """
    Preprocess an MRI using MONAI transforms to produce
    a 5D tensor (batch=1, channel=1, D, H, W) for inference.
    """
    data_dict = {"image_path": image_path}
    output_dict = transforms_fn(data_dict)
    image_tensor = output_dict["image"]       # shape: (1, D, H, W)
    image_tensor = image_tensor.unsqueeze(0)  # -> (batch=1, channel=1, D, H, W)
    return image_tensor.to(device)
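
# Example (hypothetical path; any NIfTI readable by MONAI works):
#   x = preprocess_mri("t1.nii.gz")  # -> tensor of shape (1, 1, 80, 96, 80)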


class ShallowLinearAutoencoder(nn.Module):
    """
    A purely linear autoencoder with one hidden layer.

    - Flatten input into a vector
    - Linear encoder (no activation)
    - Linear decoder (no activation)
    - Reshape output to the original volume shape
    """
    def __init__(self, input_shape=(80, 96, 80), hidden_size=1200):
        super().__init__()
        self.input_shape = input_shape
        self.input_dim = input_shape[0] * input_shape[1] * input_shape[2]
        self.hidden_size = hidden_size

        # Encoder (no activation, for PCA-like behavior)
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.input_dim, self.hidden_size),
        )

        # Decoder (no activation)
        self.decoder = nn.Sequential(
            nn.Linear(self.hidden_size, self.input_dim),
        )

    def encode(self, x: torch.Tensor):
        return self.encoder(x)

    def decode(self, z: torch.Tensor):
        out = self.decoder(z)
        # Reshape to (N, 1, D, H, W)
        return out.view(-1, 1, *self.input_shape)

    def forward(self, x: torch.Tensor):
        """
        Return (reconstruction, embedding, None) to keep an API similar
        to the old VAE-based code, though there is no σ for sampling.
        """
        z = self.encode(x)
        reconstruction = self.decode(z)
        return reconstruction, z, None
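
# Shape check (a minimal sketch with random data, not a trained model):
#   >>> ae = ShallowLinearAutoencoder()
#   >>> recon, z, _ = ae(torch.randn(2, 1, 80, 96, 80))
#   >>> recon.shape, z.shape
#   (torch.Size([2, 1, 80, 96, 80]), torch.Size([2, 1200]))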


class Brain2vec(nn.Module):
    """
    A wrapper around the ShallowLinearAutoencoder, providing a
    from_pretrained(...) method for model loading, mirroring the old
    usage with AutoencoderKL.
    """
    def __init__(self, device: str = "cpu"):
        super().__init__()
        # Instantiate the shallow linear model
        self.model = ShallowLinearAutoencoder(input_shape=INPUT_SHAPE_AE, hidden_size=1200)
        self.to(device)

    def forward(self, x: torch.Tensor):
        """
        Forward pass that returns (reconstruction, embedding, None).
        """
        return self.model(x)

    @staticmethod
    def from_pretrained(
        checkpoint_path: Optional[str] = None,
        device: str = "cpu",
    ) -> nn.Module:
        """
        Load a pretrained ShallowLinearAutoencoder if a checkpoint path is
        provided; otherwise return a randomly initialized model.

        Args:
            checkpoint_path (Optional[str]): path to a .pth checkpoint
            device (str): "cpu", "cuda", etc.
        """
        model = Brain2vec(device=device)
        if checkpoint_path is not None:
            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
            state_dict = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(state_dict)
        model.eval()
        return model
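

if __name__ == "__main__":
    # Minimal end-to-end sketch. The path below is a placeholder, not a
    # file shipped with this repo; without a checkpoint the weights are
    # randomly initialized.
    model = Brain2vec.from_pretrained(checkpoint_path=None, device="cpu")
    image = preprocess_mri("/path/to/scan.nii.gz", device="cpu")
    with torch.no_grad():
        reconstruction, embedding, _ = model(image)
    print(reconstruction.shape)  # torch.Size([1, 1, 80, 96, 80])
    print(embedding.shape)       # torch.Size([1, 1200])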