English
medical
brain-data
mri
#!/usr/bin/env python3

"""
inference_brain2vec.py

Loads a pretrained Brain2vec VAE (AutoencoderKL) model and performs inference
on one or more MRI images, generating reconstructions and latent parameters 
(z_mu, z_sigma).

Example usage:

    # 1) Multiple file paths
    python inference_brain2vec.py \
        --checkpoint_path /path/to/autoencoder_checkpoint.pth \
        --input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
        --output_dir ./vae_inference_outputs \
        --embeddings_filename all_z_mu.npy

    # 2) Use a CSV containing image paths
    python inference_brain2vec.py \
        --checkpoint_path /path/to/autoencoder_checkpoint.pth \
        --csv_input /path/to/images.csv \
        --output_dir ./vae_inference_outputs \
        --embeddings_filename all_z_mu.npy
"""

import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from typing import Optional
from monai.transforms import (
    Compose,
    CopyItemsD,
    LoadImageD,
    EnsureChannelFirstD,
    SpacingD,
    ResizeWithPadOrCropD,
    ScaleIntensityD,
)
from generative.networks.nets import AutoencoderKL
import pandas as pd


# Target voxel spacing (mm) and the fixed input shape expected by the autoencoder
RESOLUTION = 2
INPUT_SHAPE_AE = (80, 96, 80)

transforms_fn = Compose([
    CopyItemsD(keys={'image_path'}, names=['image']),  # copy the path into an 'image' key
    LoadImageD(image_only=True, keys=['image']),  # read the volume from disk
    EnsureChannelFirstD(keys=['image']),  # -> (C, D, H, W)
    SpacingD(pixdim=RESOLUTION, keys=['image']),  # resample to 2 mm isotropic spacing
    ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),  # pad/crop to fixed shape
    ScaleIntensityD(minv=0, maxv=1, keys=['image']),  # min-max scale intensities to [0, 1]
])


def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
    """
    Preprocess an MRI using MONAI transforms to produce
    a 5D tensor (batch=1, channel=1, D, H, W) for inference.

    Args:
        image_path (str): Path to the MRI (e.g. .nii.gz).
        device (str): Device to place the tensor on.

    Returns:
        torch.Tensor: Shape (1, 1, D, H, W).
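
    Example (illustrative; the path is a placeholder):
        >>> x = preprocess_mri("/path/to/mri.nii.gz")
        >>> tuple(x.shape)  # fixed by INPUT_SHAPE_AE
        (1, 1, 80, 96, 80)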
    """
    data_dict = {"image_path": image_path}
    output_dict = transforms_fn(data_dict)
    image_tensor = output_dict["image"]  # shape: (1, D, H, W)
    image_tensor = image_tensor.unsqueeze(0)  # => (1, 1, D, H, W)
    return image_tensor.to(device)


class Brain2vec(AutoencoderKL):
    """
    Subclass of MONAI's AutoencoderKL that adds:
      - a from_pretrained(...) helper for loading a .pth checkpoint
      - the inherited forward(...), which returns (reconstruction, z_mu, z_sigma)

    Usage:
      >>> model = Brain2vec.from_pretrained("my_checkpoint.pth", device="cuda")
      >>> image_tensor = preprocess_mri("/path/to/mri.nii.gz", device="cuda")
      >>> reconstruction, z_mu, z_sigma = model.forward(image_tensor)
    """

    @staticmethod
    def from_pretrained(
        checkpoint_path: Optional[str] = None,
        device: str = "cpu"
    ) -> nn.Module:
        """
        Load a pretrained Brain2vec (AutoencoderKL) if a checkpoint_path is provided.
        Otherwise, return a freshly initialized (untrained) model.

        Args:
            checkpoint_path (Optional[str]): Path to a .pth checkpoint file.
            device (str): "cpu", "cuda", "mps", etc.

        Returns:
            nn.Module: The loaded Brain2vec model on the chosen device.
        """
        model = Brain2vec(
            spatial_dims=3,
            in_channels=1,
            out_channels=1,
            latent_channels=1,
            num_channels=(64, 128, 128, 128),
            num_res_blocks=2,
            norm_num_groups=32,
            norm_eps=1e-06,
            attention_levels=(False, False, False, False),
            with_decoder_nonlocal_attn=False,
            with_encoder_nonlocal_attn=False,
        )

        if checkpoint_path is not None:
            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
            state_dict = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(state_dict)

        model.to(device)
        model.eval()
        return model
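

# Illustrative sketch (not part of the original script): a small helper that
# wraps preprocessing and encoding to return a flattened z_mu embedding for a
# single MRI, e.g. for downstream analysis. The function name is hypothetical
# and assumes a batch of one.
def embedding_from_mri(model: nn.Module, image_path: str, device: str = "cpu") -> np.ndarray:
    """Return the flattened z_mu latent vector for one MRI (batch of one)."""
    image_tensor = preprocess_mri(image_path, device=device)
    with torch.no_grad():
        _, z_mu, _ = model.forward(image_tensor)
    return z_mu.detach().cpu().numpy().reshape(-1)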


def main() -> None:
    """
    Main function to parse command-line arguments and run inference
    with a pretrained Brain2vec model.
    """
    parser = argparse.ArgumentParser(
        description="Inference script for a Brain2vec (VAE) model."
    )
    parser.add_argument(
        "--checkpoint_path", type=str, required=True,
        help="Path to the .pth checkpoint of the pretrained Brain2vec model."
    )
    parser.add_argument(
        "--output_dir", type=str, default="./vae_inference_outputs",
        help="Directory to save reconstructions and latent parameters."
    )
    # Two ways to supply images: multiple file paths or a CSV
    parser.add_argument(
        "--input_images", type=str, nargs="*",
        help="One or more MRI file paths (e.g. .nii.gz)."
    )
    parser.add_argument(
        "--csv_input", type=str,
        help="Path to a CSV file with an 'image_path' column."
    )
    parser.add_argument(
        "--embeddings_filename",
        type=str,
        required=True,
        help="Filename (in output_dir) to save the stacked z_mu embeddings (e.g. 'all_z_mu.npy')."
    )
    parser.add_argument(
        "--save_recons",
        action="store_true",
        help="If set, saves each reconstruction as .npy. Default is not to save."
    )

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Pick the device automatically: prefer CUDA when available
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the pretrained model onto that device
    model = Brain2vec.from_pretrained(
        checkpoint_path=args.checkpoint_path,
        device=device
    )

    # Gather image paths
    if args.csv_input:
        df = pd.read_csv(args.csv_input)
        if "image_path" not in df.columns:
            raise ValueError("CSV must contain a column named 'image_path'.")
        image_paths = df["image_path"].tolist()
    else:
        if not args.input_images:
            raise ValueError("Must provide either --csv_input or --input_images.")
        image_paths = args.input_images

    # Lists for stacking latent parameters later
    all_z_mu = []
    all_z_sigma = []

    # Inference on each image
    for i, img_path in enumerate(image_paths):
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        print(f"[INFO] Processing image {i}: {img_path}")
        img_tensor = preprocess_mri(img_path, device=device)

        with torch.no_grad():
            recon, z_mu, z_sigma = model.forward(img_tensor)

        # Convert to NumPy
        recon_np = recon.detach().cpu().numpy()  # shape: (1, 1, D, H, W)
        z_mu_np = z_mu.detach().cpu().numpy()    # shape: (1, latent_channels, ...)
        z_sigma_np = z_sigma.detach().cpu().numpy()

        # Optionally save each reconstruction (per image) as .npy
        if args.save_recons:
            recon_path = os.path.join(args.output_dir, f"reconstruction_{i}.npy")
            np.save(recon_path, recon_np)
            print(f"[INFO] Saved reconstruction to {recon_path}")

        # Store latent parameters for combined saving after the loop
        all_z_mu.append(z_mu_np)
        all_z_sigma.append(z_sigma_np)

    # Combine latent parameters from all images and save
    stacked_mu = np.concatenate(all_z_mu, axis=0)       # e.g., shape (N, latent_channels, ...)
    stacked_sigma = np.concatenate(all_z_sigma, axis=0) # e.g., shape (N, latent_channels, ...)

    mu_filename = args.embeddings_filename
    if not mu_filename.lower().endswith(".npy"):
        mu_filename += ".npy"

    mu_path = os.path.join(args.output_dir, mu_filename)
    sigma_path = os.path.join(args.output_dir, "all_z_sigma.npy")

    np.save(mu_path, stacked_mu)
    np.save(sigma_path, stacked_sigma)

    print(f"[INFO] Saved z_mu of shape {stacked_mu.shape} to {mu_path}")
    print(f"[INFO] Saved z_sigma of shape {stacked_sigma.shape} to {sigma_path}")


if __name__ == "__main__":
    main()