File size: 3,964 Bytes
bb04d63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# model.py
import math
import os
from typing import Optional

import torch
import torch.nn as nn
from monai.transforms import (
    Compose,
    CopyItemsD,
    LoadImageD,
    EnsureChannelFirstD,
    SpacingD,
    ResizeWithPadOrCropD,
    ScaleIntensityD,
)

# Configuration constants shared by the preprocessing pipeline and the model.
# Target voxel spacing handed to SpacingD as `pixdim` (presumably mm — confirm).
RESOLUTION = 2
# Spatial size every volume is padded/cropped to; Brain2vec also uses this
# as the autoencoder's input_shape, so the two must stay in sync.
INPUT_SHAPE_AE = (80, 96, 80)

# Define the exact transform pipeline for input MRI.
# Input is a dict with an "image_path" entry; every step after the copy
# operates on the "image" entry.
transforms_fn = Compose([
    # Duplicate the path under "image"; the "image_path" entry is kept as-is.
    CopyItemsD(keys=['image_path'], names=['image']),
    # Load the file referenced by "image"; image_only=True keeps just the image data.
    LoadImageD(image_only=True, keys=['image']),
    EnsureChannelFirstD(keys=['image']),
    SpacingD(pixdim=RESOLUTION, keys=['image']),
    # Pad (numpy-style 'minimum' pad mode) or crop to the fixed model input size.
    ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
    # Rescale intensities into the range [0, 1].
    ScaleIntensityD(minv=0, maxv=1, keys=['image']),
])

def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
    """
    Load one MRI volume and prepare it as a batched model input.

    The path is run through the module-level ``transforms_fn`` pipeline,
    then a leading batch axis of size 1 is added and the result is moved
    to ``device``.

    Args:
        image_path: File path of the MRI, stored under the ``"image_path"``
            key of the sample dict fed to the pipeline.
        device: Destination device for the tensor ("cpu", "cuda", ...).

    Returns:
        Tensor of shape ``(1, C, *INPUT_SHAPE_AE)``; C is expected to be 1
        for a single-channel MRI (not enforced here — confirm for your data).
    """
    sample = {"image_path": image_path}
    volume = transforms_fn(sample)["image"]  # channel-first volume: (C, D, H, W)
    # Prepend the batch axis, then transfer to the requested device.
    return volume.unsqueeze(0).to(device)


class ShallowLinearAutoencoder(nn.Module):
    """
    A purely linear autoencoder with one hidden layer.

    - Flatten input into a vector
    - Linear encoder (no activation)
    - Linear decoder (no activation)
    - Reshape output to original volume shape

    Args:
        input_shape: Spatial shape of one single-channel input volume, e.g.
            ``(D, H, W)``. Any non-empty sequence of positive ints is
            accepted; the flattened feature count is the product of its
            entries.
        hidden_size: Length of the latent embedding vector.

    Raises:
        ValueError: If ``input_shape`` is empty.
    """

    def __init__(self, input_shape=(80, 96, 80), hidden_size=1200):
        super().__init__()
        input_shape = tuple(input_shape)
        if not input_shape:
            raise ValueError("input_shape must contain at least one dimension")
        # Stored as a tuple so list inputs behave identically and the shape
        # cannot be mutated after construction.
        self.input_shape = input_shape
        # Product over all spatial dims (previously hard-wired to exactly 3).
        self.input_dim = math.prod(input_shape)
        self.hidden_size = hidden_size

        # Encoder (no activation for PCA-like behavior).
        # NOTE: keep both Sequential layouts unchanged — the state_dict keys
        # ("encoder.1.*", "decoder.0.*") must match existing checkpoints.
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.input_dim, self.hidden_size),
        )

        # Decoder (no activation)
        self.decoder = nn.Sequential(
            nn.Linear(self.hidden_size, self.input_dim),
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map a batch of volumes to latent vectors.

        ``nn.Flatten`` keeps dim 0 and flattens the rest, so each sample of
        ``x`` must hold exactly ``input_dim`` values, e.g. shape
        ``(N, 1, *input_shape)``. Returns a tensor of shape
        ``(N, hidden_size)``.
        """
        return self.encoder(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Map latent vectors of shape ``(N, hidden_size)`` back to volumes.

        Returns a tensor of shape ``(N, 1, *input_shape)``.
        """
        out = self.decoder(z)
        # The Linear output is freshly allocated and contiguous, so view() is safe.
        return out.view(-1, 1, *self.input_shape)

    def forward(self, x: torch.Tensor):
        """
        Return (reconstruction, embedding, None) to keep a similar API
        to the old VAE-based code, though there's no σ for sampling.
        """
        z = self.encode(x)
        reconstruction = self.decode(z)
        return reconstruction, z, None


class Brain2vec(nn.Module):
    """
    A wrapper around the ShallowLinearAutoencoder, providing a from_pretrained(...)
    method for model loading, mirroring the old usage with AutoencoderKL.

    The wrapped network lives in ``self.model``, so every key in this module's
    state_dict carries a ``"model."`` prefix.
    """
    def __init__(self, device: str = "cpu"):
        super().__init__()
        # Instantiate the shallow linear model at the preprocessing pipeline's
        # fixed spatial size (INPUT_SHAPE_AE).
        self.model = ShallowLinearAutoencoder(input_shape=INPUT_SHAPE_AE, hidden_size=1200)
        self.to(device)

    def forward(self, x: torch.Tensor):
        """
        Forward pass that returns (reconstruction, embedding, None).
        """
        return self.model(x)

    @classmethod
    def from_pretrained(
        cls,
        checkpoint_path: Optional[str] = None,
        device: str = "cpu"
    ) -> "Brain2vec":
        """
        Build a model and optionally load pretrained weights into it.

        A classmethod so subclasses get an instance of their own type.

        Args:
            checkpoint_path (Optional[str]): path to a .pth state_dict
                checkpoint; when None, the freshly initialized model is returned.
            device (str): "cpu", "cuda", etc. Both the model and the loaded
                tensors are placed here.

        Returns:
            The model in eval mode.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` is given but does not exist.
            RuntimeError: From ``load_state_dict`` if the checkpoint's keys or
                shapes do not match this model.
        """
        model = cls(device=device)
        if checkpoint_path is not None:
            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
            # torch.load is pickle-based; weights_only=True (torch >= 1.13)
            # restricts unpickling to tensors and plain containers, so a
            # tampered checkpoint cannot run arbitrary code. A state_dict
            # needs nothing more.
            state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
            # NOTE(review): keys are expected with the "model." prefix; a
            # state_dict saved from a bare ShallowLinearAutoencoder would fail
            # strict loading here — confirm how checkpoints are produced.
            model.load_state_dict(state_dict)
        model.eval()
        return model