# model.py
import os
from typing import Optional

import torch
import torch.nn as nn
from monai.transforms import (
    Compose,
    CopyItemsD,
    LoadImageD,
    EnsureChannelFirstD,
    SpacingD,
    ResizeWithPadOrCropD,
    ScaleIntensityD,
)

# Constants for the typical config
RESOLUTION = 2
INPUT_SHAPE_AE = (80, 96, 80)

# Exact transform pipeline applied to an input MRI before inference
transforms_fn = Compose([
    CopyItemsD(keys=['image_path'], names=['image']),
    LoadImageD(image_only=True, keys=['image']),
    EnsureChannelFirstD(keys=['image']),
    SpacingD(pixdim=RESOLUTION, keys=['image']),
    ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
    ScaleIntensityD(minv=0, maxv=1, keys=['image']),
])


def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
    """
    Preprocess an MRI using MONAI transforms to produce a 5D tensor
    (batch=1, channel=1, D, H, W) for inference.
    """
    data_dict = {"image_path": image_path}
    output_dict = transforms_fn(data_dict)
    image_tensor = output_dict["image"]        # shape: (1, D, H, W)
    image_tensor = image_tensor.unsqueeze(0)   # shape: (1, 1, D, H, W)
    return image_tensor.to(device)


class ShallowLinearAutoencoder(nn.Module):
    """
    A purely linear autoencoder with one hidden layer.

    - Flatten the input into a vector
    - Linear encoder (no activation)
    - Linear decoder (no activation)
    - Reshape the output to the original volume shape
    """
    def __init__(self, input_shape=(80, 96, 80), hidden_size=1200):
        super().__init__()
        self.input_shape = input_shape
        self.input_dim = input_shape[0] * input_shape[1] * input_shape[2]
        self.hidden_size = hidden_size

        # Encoder (no activation, for PCA-like behavior)
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.input_dim, self.hidden_size),
        )

        # Decoder (no activation)
        self.decoder = nn.Sequential(
            nn.Linear(self.hidden_size, self.input_dim),
        )

    def encode(self, x: torch.Tensor):
        return self.encoder(x)

    def decode(self, z: torch.Tensor):
        out = self.decoder(z)
        # Reshape to (N, 1, D, H, W)
        return out.view(-1, 1, *self.input_shape)

    def forward(self, x: torch.Tensor):
        """
        Return (reconstruction, embedding, None) to keep an API similar to
        the old VAE-based code, though there's no σ for sampling.
        """
        z = self.encode(x)
        reconstruction = self.decode(z)
        return reconstruction, z, None


class Brain2vec(nn.Module):
    """
    A wrapper around the ShallowLinearAutoencoder, providing a
    from_pretrained(...) method for model loading, mirroring the old
    usage with AutoencoderKL.
    """
    def __init__(self, device: str = "cpu"):
        super().__init__()
        # Instantiate the shallow linear model
        self.model = ShallowLinearAutoencoder(input_shape=INPUT_SHAPE_AE, hidden_size=1200)
        self.to(device)

    def forward(self, x: torch.Tensor):
        """
        Forward pass that returns (reconstruction, embedding, None).
        """
        return self.model(x)

    @staticmethod
    def from_pretrained(
        checkpoint_path: Optional[str] = None,
        device: str = "cpu",
    ) -> nn.Module:
        """
        Load a pretrained ShallowLinearAutoencoder if a checkpoint path is provided.

        Args:
            checkpoint_path (Optional[str]): path to a .pth checkpoint
            device (str): "cpu", "cuda", etc.
        """
        model = Brain2vec(device=device)
        if checkpoint_path is not None:
            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
            state_dict = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(state_dict)
        model.eval()
        return model
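

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module API). The scan path below is
# a hypothetical placeholder; substitute your own NIfTI file and, optionally,
# a trained .pth checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the model; pass checkpoint_path="..." to restore trained weights.
    model = Brain2vec.from_pretrained(checkpoint_path=None, device=device)

    # Preprocess a scan into a (1, 1, 80, 96, 80) tensor and run inference.
    x = preprocess_mri("/path/to/scan.nii.gz", device=device)  # placeholder path
    with torch.no_grad():
        reconstruction, embedding, _ = model(x)

    print(reconstruction.shape)  # torch.Size([1, 1, 80, 96, 80])
    print(embedding.shape)       # torch.Size([1, 1200])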