|
|
|
import os |
|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
from monai.transforms import ( |
|
Compose, |
|
CopyItemsD, |
|
LoadImageD, |
|
EnsureChannelFirstD, |
|
SpacingD, |
|
ResizeWithPadOrCropD, |
|
ScaleIntensityD, |
|
) |
|
from generative.networks.nets import AutoencoderKL |
|
|
|
|
|
# Voxel spacing handed to SpacingD as `pixdim`.
# NOTE(review): presumably millimetres and applied to every spatial axis; how a
# scalar `pixdim` is expanded depends on the installed MONAI version - confirm.
RESOLUTION = 2

# Spatial size each volume is padded/cropped to by ResizeWithPadOrCropD.
INPUT_SHAPE_AE = (80, 96, 80)

# Dictionary-based preprocessing pipeline. It expects a dict with an
# 'image_path' entry and produces the preprocessed volume under 'image'.
transforms_fn = Compose([
    # Duplicate the path under 'image' so the loader can overwrite that entry
    # while 'image_path' is kept as-is. A list (not a set) is used for `keys`
    # so key order is well defined and matches the other transforms.
    CopyItemsD(keys=['image_path'], names=['image']),
    LoadImageD(image_only=True, keys=['image']),
    EnsureChannelFirstD(keys=['image']),
    SpacingD(pixdim=RESOLUTION, keys=['image']),
    # Pad (with the 'minimum' padding mode) or crop to the fixed input shape.
    ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
    # Rescale intensities into the [0, 1] range.
    ScaleIntensityD(minv=0, maxv=1, keys=['image']),
])
|
|
|
def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
    """
    Run the module-level MONAI pipeline on a single MRI file and batch the result.

    Args:
        image_path: location of the image file handed to the transform pipeline.
        device: torch device identifier the resulting tensor is moved to.

    Returns:
        The pipeline's "image" output with a leading batch axis of size 1 added,
        placed on `device`. For a single-channel 3D volume this is expected to
        be a 5D tensor shaped (1, 1, D, H, W).
    """
    # The pipeline reads the path from the 'image_path' entry and writes the
    # processed volume under 'image'.
    sample = transforms_fn({"image_path": image_path})

    # Prepend the batch dimension before moving to the requested device.
    batched = sample["image"].unsqueeze(0)
    return batched.to(device)
|
|
|
|
|
class Brain2vec(AutoencoderKL):
    """
    Subclass of AutoencoderKL that adds:
      - a from_pretrained(...) alternate constructor for loading a .pth checkpoint
      - nothing else: the inherited forward(...) is used as-is and returns
        (reconstruction, z_mu, z_sigma)

    Usage:
        >>> model = Brain2vec.from_pretrained("my_checkpoint.pth", device="cuda")
        >>> image_tensor = preprocess_mri("/path/to/mri.nii.gz", device="cuda")
        >>> reconstruction, z_mu, z_sigma = model.forward(image_tensor)
    """

    @classmethod
    def from_pretrained(
        cls,
        checkpoint_path: Optional[str] = None,
        device: str = "cpu"
    ) -> "Brain2vec":
        """
        Build the model with its fixed architecture and optionally load weights.

        If `checkpoint_path` is given, its state dict is loaded into the model.
        Otherwise the model keeps the weights it was constructed with, i.e. it
        is untrained.

        Declared as a classmethod so that calling it on a subclass builds an
        instance of that subclass rather than a plain Brain2vec.

        Args:
            checkpoint_path (Optional[str]): path to a .pth checkpoint holding a
                state dict compatible with this architecture.
            device (str): "cpu", "cuda", "mps", etc.

        Returns:
            Brain2vec: the model, moved to `device` and switched to eval mode.

        Raises:
            FileNotFoundError: if `checkpoint_path` is given but does not exist.
        """
        model = cls(
            spatial_dims=3,
            in_channels=1,
            out_channels=1,
            latent_channels=1,
            num_channels=(64, 128, 128, 128),
            num_res_blocks=2,
            norm_num_groups=32,
            norm_eps=1e-06,
            attention_levels=(False, False, False, False),
            with_decoder_nonlocal_attn=False,
            with_encoder_nonlocal_attn=False,
        )

        if checkpoint_path is not None:
            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
            # SECURITY NOTE: torch.load unpickles the file, which can execute
            # arbitrary code. Only load checkpoints from trusted sources, or
            # pass weights_only=True if the installed torch version supports it.
            state_dict = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(state_dict)

        model.to(device)
        # Inference-oriented default: eval mode for layers that behave
        # differently in training.
        model.eval()
        return model