# Source: brain2vec/model.py (Hugging Face model repository, author: jesseab)
# Tags: English, medical, brain-data, mri
# Revision: 3ae8863 ("Updates"), file size 3.11 kB
# model.py
import os
from typing import Optional
import torch
import torch.nn as nn
from monai.transforms import (
Compose,
CopyItemsD,
LoadImageD,
EnsureChannelFirstD,
SpacingD,
ResizeWithPadOrCropD,
ScaleIntensityD,
)
from generative.networks.nets import AutoencoderKL
# Default preprocessing configuration for Brain2vec inputs.
RESOLUTION = 2  # voxel spacing handed to SpacingD as `pixdim`
INPUT_SHAPE_AE = (80, 96, 80)  # spatial size every volume is padded/cropped to

# Dictionary-based MONAI pipeline: it is fed {"image_path": ...} and leaves the
# processed volume under the "image" key. The steps run in the listed order.
transforms_fn = Compose(
    transforms=[
        # Copy the path into "image" so the loader below replaces it with pixel data.
        CopyItemsD(keys={'image_path'}, names=['image']),
        # image_only=True: keep just the loaded image, no separate metadata entry.
        LoadImageD(keys=['image'], image_only=True),
        EnsureChannelFirstD(keys=['image']),
        # Resample the volume to RESOLUTION spacing.
        SpacingD(keys=['image'], pixdim=RESOLUTION),
        # Pad (np.pad-style 'minimum' mode) or crop to the autoencoder input size.
        ResizeWithPadOrCropD(keys=['image'], spatial_size=INPUT_SHAPE_AE, mode='minimum'),
        # Rescale intensities into the [0, 1] range.
        ScaleIntensityD(keys=['image'], minv=0, maxv=1),
    ]
)
def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
    """
    Load one MRI volume and turn it into a model-ready batch of size one.

    The file is run through the module-level ``transforms_fn`` MONAI pipeline,
    and a leading batch axis is then added. The result is the 5D layout
    (batch=1, channels=1, D, H, W) used for inference.

    Args:
        image_path (str): path of the MRI file to load.
        device (str): device the returned tensor is moved to ("cpu", "cuda", ...).

    Returns:
        torch.Tensor: the preprocessed volume with a batch dimension of 1.
    """
    # The pipeline reads "image_path" and writes the processed volume to "image".
    volume = transforms_fn({"image_path": image_path})["image"]  # (1, D, H, W)
    batch = volume.unsqueeze(dim=0)  # -> (1, 1, D, H, W)
    return batch.to(device=device)
class Brain2vec(AutoencoderKL):
    """
    3D KL autoencoder for brain MRI, built on ``AutoencoderKL`` from MONAI
    Generative Models (``generative.networks.nets``).

    Adds:
    - ``from_pretrained(...)``: builds the fixed Brain2vec architecture and
      optionally loads a ``.pth`` state dict into it.
    - It otherwise reuses the parent ``forward(...)``, which returns
      ``(reconstruction, z_mu, z_sigma)``.

    Usage:
        >>> model = Brain2vec.from_pretrained("my_checkpoint.pth", device="cuda")
        >>> image_tensor = preprocess_mri("/path/to/mri.nii.gz", device="cuda")
        >>> with torch.no_grad():
        ...     reconstruction, z_mu, z_sigma = model(image_tensor)
    """

    @classmethod
    def from_pretrained(
        cls,
        checkpoint_path: Optional[str] = None,
        device: str = "cpu",
    ) -> "Brain2vec":
        """
        Build the Brain2vec architecture and optionally load pretrained weights.

        If ``checkpoint_path`` is None, no weights are loaded. The model then
        keeps the network's default (random) parameter initialization.

        Args:
            checkpoint_path (Optional[str]): path to a ``.pth`` file holding a
                state dict for this architecture, or None to skip loading.
            device (str): target device, e.g. "cpu", "cuda", "mps".

        Returns:
            Brain2vec: an instance of ``cls`` on ``device``, in eval mode.

        Raises:
            FileNotFoundError: if ``checkpoint_path`` is given but does not
                point to an existing file.
        """
        # classmethod + cls(...) so subclasses get an instance of themselves.
        model = cls(
            spatial_dims=3,
            in_channels=1,
            out_channels=1,
            latent_channels=1,
            num_channels=(64, 128, 128, 128),
            num_res_blocks=2,
            norm_num_groups=32,
            norm_eps=1e-06,
            attention_levels=(False, False, False, False),
            with_decoder_nonlocal_attn=False,
            with_encoder_nonlocal_attn=False,
        )
        if checkpoint_path is not None:
            # isfile rather than exists: a directory path fails here with a clear
            # error instead of deep inside torch.load.
            if not os.path.isfile(checkpoint_path):
                raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
            # NOTE(security): torch.load unpickles the file. Only load trusted
            # checkpoints, and consider weights_only=True on torch versions that
            # support it.
            # Deserialize on CPU. The freshly built model still sits on the
            # default device (CPU unless changed), so mapping to `device` first
            # would only add an extra transfer and extra device memory.
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            model.load_state_dict(state_dict)
        model.to(device)
        model.eval()  # ready for inference
        return model