English
medical
brain-data
mri
jesseab committed
Commit 3ae8863 · 1 Parent(s): ae80538
Files changed (2)
  1. brain2vec.py +2 -2
  2. model.py +94 -0
brain2vec.py CHANGED
@@ -72,7 +72,7 @@ import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

# chosen resolution
- RESOLUTION = 1.5
+ RESOLUTION = 2

# shape of the MNI152 (1mm^3) template
INPUT_SHAPE_1mm = (182, 218, 182)
@@ -616,7 +616,7 @@ def main():
        lr=args.lr,
        aug_p=args.aug_p,
    )
- elif args.command == 'inference':
+ elif args.command == 'infer':
    inference(
        dataset_csv=args.dataset_csv,
        aekl_ckpt=args.aekl_ckpt,
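
Two things follow from this hunk for callers. First, the subcommand rename means invocations of the form "python brain2vec.py inference ..." must become "python brain2vec.py infer ...". Second, the resolution bump from 1.5 mm to 2 mm changes the working grid: resampling the 1 mm MNI152 template at 2 mm roughly halves each axis before padding/cropping to the autoencoder input shape. A short illustrative sketch of that arithmetic (the constants mirror those in model.py below; this is an approximation of what the spacing transform produces, not code from the commit):

# Illustrative only: how the 2 mm grid relates to the 1 mm MNI152 template.
import math

RESOLUTION = 2                     # mm per voxel after resampling (was 1.5)
INPUT_SHAPE_1mm = (182, 218, 182)  # MNI152 template at 1 mm^3

resampled = tuple(math.ceil(d / RESOLUTION) for d in INPUT_SHAPE_1mm)
print(resampled)  # (91, 109, 91), then padded/cropped to INPUT_SHAPE_AE = (80, 96, 80)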
model.py ADDED
@@ -0,0 +1,94 @@
+ # model.py
+ import os
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ from monai.transforms import (
+     Compose,
+     CopyItemsD,
+     LoadImageD,
+     EnsureChannelFirstD,
+     SpacingD,
+     ResizeWithPadOrCropD,
+     ScaleIntensityD,
+ )
+ from generative.networks.nets import AutoencoderKL
+
+ # Constants matching the training configuration
+ RESOLUTION = 2
+ INPUT_SHAPE_AE = (80, 96, 80)
+
+ # The exact transform pipeline applied to each input MRI
+ transforms_fn = Compose([
+     CopyItemsD(keys=['image_path'], names=['image']),
+     LoadImageD(image_only=True, keys=['image']),
+     EnsureChannelFirstD(keys=['image']),
+     SpacingD(pixdim=RESOLUTION, keys=['image']),
+     ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
+     ScaleIntensityD(minv=0, maxv=1, keys=['image']),
+ ])
+
+ def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
+     """
+     Preprocess an MRI with the MONAI transform pipeline above, producing
+     a 5D tensor (batch=1, channel=1, D, H, W) ready for inference.
+     """
+     data_dict = {"image_path": image_path}
+     output_dict = transforms_fn(data_dict)
+     image_tensor = output_dict["image"]       # shape: (1, D, H, W)
+     image_tensor = image_tensor.unsqueeze(0)  # => (batch=1, channel=1, D, H, W)
+     return image_tensor.to(device)
+
+
+ class Brain2vec(AutoencoderKL):
+     """
+     Subclass of MONAI's AutoencoderKL that adds:
+       - from_pretrained(...) for loading a .pth checkpoint
+       - the inherited forward(...), which returns (reconstruction, z_mu, z_sigma)
+
+     Usage:
+         >>> model = Brain2vec.from_pretrained("my_checkpoint.pth", device="cuda")
+         >>> image_tensor = preprocess_mri("/path/to/mri.nii.gz", device="cuda")
+         >>> reconstruction, z_mu, z_sigma = model(image_tensor)
+     """
+
+     @staticmethod
+     def from_pretrained(
+         checkpoint_path: Optional[str] = None,
+         device: str = "cpu",
+     ) -> nn.Module:
+         """
+         Load a pretrained Brain2vec (AutoencoderKL) if checkpoint_path is
+         provided; otherwise return a freshly initialized (untrained) model.
+
+         Args:
+             checkpoint_path (Optional[str]): path to a .pth checkpoint
+             device (str): "cpu", "cuda", "mps", etc.
+
+         Returns:
+             nn.Module: the Brain2vec model on the chosen device
+         """
+         model = Brain2vec(
+             spatial_dims=3,
+             in_channels=1,
+             out_channels=1,
+             latent_channels=1,
+             num_channels=(64, 128, 128, 128),
+             num_res_blocks=2,
+             norm_num_groups=32,
+             norm_eps=1e-06,
+             attention_levels=(False, False, False, False),
+             with_decoder_nonlocal_attn=False,
+             with_encoder_nonlocal_attn=False,
+         )
+
+         if checkpoint_path is not None:
+             if not os.path.exists(checkpoint_path):
+                 raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
+             state_dict = torch.load(checkpoint_path, map_location=device)
+             model.load_state_dict(state_dict)
+
+         model.to(device)
+         model.eval()  # ready for inference
+         return model
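
For reference, a minimal end-to-end sketch of the new module. The checkpoint and MRI paths are placeholders ("autoencoder_final.pth" is hypothetical, not a file shipped by this commit):

# Illustrative usage sketch, not part of the commit
import torch
from model import Brain2vec, preprocess_mri

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Brain2vec.from_pretrained("autoencoder_final.pth", device=device)  # hypothetical path
image = preprocess_mri("/path/to/mri.nii.gz", device=device)               # (1, 1, 80, 96, 80)

with torch.no_grad():
    reconstruction, z_mu, z_sigma = model(image)

print(reconstruction.shape, z_mu.shape)  # z_mu is the latent mean

Since the model is built with latent_channels=1, z_mu is a single-channel 3D latent volume; flattened, it serves as a compact brain embedding for downstream analyses.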