|
|
|
import os |
|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
from monai.transforms import ( |
|
Compose, |
|
CopyItemsD, |
|
LoadImageD, |
|
EnsureChannelFirstD, |
|
SpacingD, |
|
ResizeWithPadOrCropD, |
|
ScaleIntensityD, |
|
) |
|
|
|
|
|
# Voxel spacing handed to SpacingD as ``pixdim`` (units presumably millimetres — confirm).
RESOLUTION = 2
# Spatial size every volume is padded/cropped to; also used as the autoencoder input shape.
INPUT_SHAPE_AE = (80, 96, 80)


# Dictionary-based preprocessing pipeline. The input dict must provide
# "image_path"; the processed volume is written to "image".
transforms_fn = Compose([
    # Keys are given as a list (not a set): ``names`` is matched to ``keys``,
    # so an ordered container is required as soon as more than one key is copied.
    CopyItemsD(keys=['image_path'], names=['image']),
    # Replace the copied path with the loaded image data.
    LoadImageD(image_only=True, keys=['image']),
    EnsureChannelFirstD(keys=['image']),
    # Resample to the target voxel spacing.
    SpacingD(pixdim=RESOLUTION, keys=['image']),
    # Pad (with the volume minimum) or crop to the fixed input size.
    ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
    # Rescale intensities to the [0, 1] range.
    ScaleIntensityD(minv=0, maxv=1, keys=['image']),
])
|
|
|
def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
    """
    Run one MRI file through the module-level ``transforms_fn`` pipeline.

    Args:
        image_path: Path of the image file; passed to the pipeline under
            the "image_path" key.
        device: Device the resulting tensor is moved to.

    Returns:
        The pipeline's "image" output with a leading batch axis added,
        placed on ``device``. Intended to be 5D (batch=1, channels=1, D, H, W)
        for inference (a single channel is assumed, not enforced here).
    """
    # The pipeline is dictionary based: feed the path in, read the volume out.
    transformed = transforms_fn({"image_path": image_path})
    # Prepend the batch axis before moving to the requested device.
    batched = transformed["image"].unsqueeze(0)
    return batched.to(device)
|
|
|
|
|
class ShallowLinearAutoencoder(nn.Module):
    """
    A purely linear autoencoder with one hidden layer.

    - Flatten the input into a vector
    - Linear encoder (no activation)
    - Linear decoder (no activation)
    - Reshape the output to ``(batch, 1, *input_shape)``

    Args:
        input_shape: Spatial shape of one single-channel volume. Any number
            of dimensions is accepted; the default is ``(80, 96, 80)``.
        hidden_size: Width of the embedding produced by the encoder.

    Raises:
        ValueError: If ``input_shape`` is empty or has a non-positive extent.
    """

    def __init__(self, input_shape=(80, 96, 80), hidden_size=1200):
        super().__init__()
        # Normalise to a tuple so lists and torch.Size are handled uniformly.
        self.input_shape = tuple(input_shape)
        if not self.input_shape or any(extent <= 0 for extent in self.input_shape):
            raise ValueError(
                f"input_shape must be a non-empty sequence of positive ints, got {input_shape!r}"
            )
        # Product over *all* dimensions, so the layer sizes are correct for
        # any rank instead of assuming exactly three spatial axes.
        self.input_dim = torch.Size(self.input_shape).numel()
        self.hidden_size = hidden_size

        # nn.Flatten keeps the batch axis and merges everything after it.
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.input_dim, self.hidden_size),
        )

        self.decoder = nn.Sequential(
            nn.Linear(self.hidden_size, self.input_dim),
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten ``x`` per sample and project it to ``(batch, hidden_size)``."""
        return self.encoder(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map embeddings ``z`` back to volumes of shape ``(batch, 1, *input_shape)``."""
        out = self.decoder(z)
        # -1 lets the batch size be inferred; the channel axis is fixed to 1.
        return out.view(-1, 1, *self.input_shape)

    def forward(self, x: torch.Tensor):
        """
        Return ``(reconstruction, embedding, None)``.

        The trailing ``None`` keeps the three-element return of the earlier
        VAE-based code; this model has no sampling term to report.
        """
        z = self.encode(x)
        reconstruction = self.decode(z)
        return reconstruction, z, None
|
|
|
|
|
class Brain2vec(nn.Module):
    """
    A wrapper around ShallowLinearAutoencoder, providing a from_pretrained(...)
    method for model loading, mirroring the old usage with AutoencoderKL.

    Args:
        device: Device the wrapped model is moved to ("cpu", "cuda", ...).
        input_shape: Spatial shape of the input volumes. ``None`` (default)
            uses the module-level ``INPUT_SHAPE_AE``.
        hidden_size: Embedding width of the wrapped autoencoder.
    """

    def __init__(self, device: str = "cpu", input_shape=None, hidden_size: int = 1200):
        super().__init__()
        # Resolve the default at call time so the module constant stays the
        # single source of truth for the standard geometry.
        if input_shape is None:
            input_shape = INPUT_SHAPE_AE
        self.model = ShallowLinearAutoencoder(input_shape=input_shape, hidden_size=hidden_size)
        self.to(device)

    def forward(self, x: torch.Tensor):
        """
        Forward pass that returns (reconstruction, embedding, None).
        """
        return self.model(x)

    @classmethod
    def from_pretrained(
        cls,
        checkpoint_path: Optional[str] = None,
        device: str = "cpu",
        input_shape=None,
        hidden_size: int = 1200,
    ) -> "Brain2vec":
        """
        Build a model in eval mode, loading weights if a checkpoint path is provided.

        Args:
            checkpoint_path (Optional[str]): path to a .pth checkpoint holding
                a state dict for this wrapper; ``None`` keeps the freshly
                initialised weights.
            device (str): "cpu", "cuda", etc.
            input_shape: Forwarded to the constructor; must match the
                geometry the checkpoint was trained with.
            hidden_size (int): Forwarded to the constructor; must match the
                checkpoint.

        Returns:
            The constructed model, switched to eval mode.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` is given but does not exist.
        """
        model = cls(device=device, input_shape=input_shape, hidden_size=hidden_size)
        if checkpoint_path is not None:
            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
            # NOTE(security): torch.load unpickles the file; on PyTorch
            # versions where ``weights_only`` does not default to True this
            # can execute arbitrary code. Only load trusted checkpoints, or
            # pass ``weights_only=True`` once the minimum supported PyTorch
            # version allows it.
            state_dict = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(state_dict)
        model.eval()
        return model