File size: 7,898 Bytes
e148d83
 
 
73bc6a0
 
e148d83
 
73bc6a0
e133e90
73bc6a0
e148d83
 
73bc6a0
cf79142
e148d83
73bc6a0
 
 
 
 
 
 
989a90a
73bc6a0
 
 
 
 
 
 
5456f4c
73bc6a0
 
 
 
 
 
 
 
 
 
 
 
e148d83
73bc6a0
e148d83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73bc6a0
e148d83
 
 
 
 
 
 
 
 
 
73bc6a0
 
e148d83
73bc6a0
e148d83
 
 
 
73bc6a0
9d02ecc
aec25e1
 
57cee88
aec25e1
8c39a9e
73bc6a0
e148d83
 
73bc6a0
e148d83
 
 
73bc6a0
e148d83
73bc6a0
e148d83
 
 
 
 
 
73bc6a0
e148d83
 
73bc6a0
e148d83
 
 
 
893aaeb
9d8a150
 
7c9c0b1
e148d83
9d8a150
6d14476
 
aec25e1
e148d83
73bc6a0
 
8f99d89
 
 
 
 
 
 
 
 
 
 
 
 
143dce3
4fd8f2c
 
69a0449
 
 
4fd8f2c
d892c6e
 
aef731a
 
d892c6e
aef731a
 
420b558
b83e952
6515de6
57cee88
420b558
912d400
73bc6a0
 
7d444b3
 
 
 
 
5456f4c
 
 
73bc6a0
 
 
 
 
 
 
 
 
 
6515de6
73bc6a0
 
6515de6
9e2e95c
6515de6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import gradio as gr
import torch
import numpy as np
import os
import nibabel as nib
import torchio
import torch.nn as nn
import subprocess
import spaces  # Import spaces for GPU decoration
from scipy.ndimage.measurements import center_of_mass
from huggingface_hub import hf_hub_download
from monai.transforms import Compose, LoadImaged, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd
from monai.data import Dataset, DataLoader
from nnunet_mednext import create_mednext_encoder_v1

# 🔹 Model and data directory setup
REPO_ID = "FrancescoLR/BrainAgeNeXt"  # Hugging Face repo that hosts the checkpoints
MODEL_DIR = "/root/.cache/huggingface/hub"  # cache_dir handed to hf_hub_download
DATASET_DIR = os.path.join(MODEL_DIR, "BrainAgeNeXt")  # existence is checked by download_model()

# Make sure the cache root exists before anything tries to write into it
# (exist_ok makes this a no-op on subsequent starts).
os.makedirs(MODEL_DIR, exist_ok=True)

# 🔹 Function to Download Model Weights from Hugging Face
def download_model():
    """Fetch the five BrainAgeNeXt checkpoints (BrainAge_1.pth .. BrainAge_5.pth).

    Files are stored by ``hf_hub_download`` under ``MODEL_DIR`` (its
    ``cache_dir``). ``DATASET_DIR`` acts only as a "download completed" marker:
    when it exists, the download is skipped.

    The marker is created only after every checkpoint has been fetched. An
    interrupted or failed download therefore leaves no marker and is retried
    on the next start, instead of being skipped forever.

    Raises:
        Whatever ``hf_hub_download`` raises on network/repository errors; the
        marker is not created in that case.
    """
    if os.path.exists(DATASET_DIR):
        return

    print("Downloading BrainAgeNeXt model weights...")
    for i in range(1, 6):
        hf_hub_download(repo_id=REPO_ID, filename=f"BrainAge_{i}.pth", cache_dir=MODEL_DIR)

    # Only now is it safe to record that all weights are present.
    os.makedirs(DATASET_DIR, exist_ok=True)
    print("BrainAgeNeXt model downloaded successfully.")

# 🔹 Function to Load Model
def initialize_model():
    """Build the BrainAgeNeXt ensemble, one network per checkpoint.

    All five checkpoints (BrainAge_1.pth .. BrainAge_5.pth) are first resolved
    through ``hf_hub_download`` with ``MODEL_DIR`` as cache. Each one is then
    loaded into a fresh ``MedNeXtEncReg`` placed on the module-level
    ``device`` and put into eval mode.

    Returns:
        list of ``MedNeXtEncReg`` instances, in checkpoint order 1..5.
    """
    checkpoint_files = [
        hf_hub_download(repo_id=REPO_ID, filename=f"BrainAge_{idx}.pth", cache_dir=MODEL_DIR)
        for idx in range(1, 6)
    ]

    def _load(path):
        # Instantiate on the target device, restore weights, switch to inference mode.
        net = MedNeXtEncReg().to(device)
        # NOTE(review): torch.load unpickles the file; these checkpoints come from
        # the Hub, so consider weights_only=True (torch >= 1.13) once confirmed
        # that the files are plain state_dicts.
        net.load_state_dict(torch.load(path, map_location=device))
        net.eval()
        return net

    return [_load(path) for path in checkpoint_files]

# 🔹 Define Model
class MedNeXtEncReg(nn.Module):
    """MedNeXt encoder followed by a small MLP head that regresses a single value.

    Attribute names and the layer order inside ``regression_fc`` determine the
    state_dict keys expected by the released checkpoints, so keep them as-is.
    """

    def __init__(self):
        super().__init__()
        # Single input channel, MedNeXt configuration 'B', kernel_size 3.
        self.mednextv1 = create_mednext_encoder_v1(
            num_input_channels=1,
            num_classes=1,
            model_id='B',
            kernel_size=3,
            deep_supervision=True,
        )
        # Reduce every spatial dimension of the encoder output to size 1.
        self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        # 512 -> 64 -> 1 head. Dropout(0.0) does nothing at runtime but occupies
        # index 2, so the final Linear stays at index 3 as in the checkpoints.
        self.regression_fc = nn.Sequential(
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        """Encode ``x``, pool, flatten from dim 1, regress, and drop size-1 dims."""
        pooled = self.global_avg_pool(self.mednextv1(x))
        flat = torch.flatten(pooled, start_dim=1)
        return self.regression_fc(flat).squeeze()

# 🔹 Preprocessing Pipeline
def prepare_transforms():
    """Return the dictionary-based preprocessing pipeline for the "image" key.

    Steps, in order:
    1. Load the file with channel-first layout.
    2. Resample to pixdim (1.0, 1.0, 1.0).
    3. Crop to the foreground.
    4. Pad, then center-crop to exactly 160x192x160.
    5. Z-normalize intensities using voxels > 0 as the mask.
    """
    key = "image"
    target_shape = (160, 192, 160)

    steps = [
        LoadImaged(keys=[key], ensure_channel_first=True),
        Spacingd(keys=[key], pixdim=(1.0, 1.0, 1.0)),
        CropForegroundd(keys=[key], allow_smaller=True, source_key=key),
        # Pad before cropping so the center crop always yields target_shape.
        SpatialPadd(keys=[key], spatial_size=target_shape),
        CenterSpatialCropd(keys=[key], roi_size=target_shape),
        # Mean/std are computed only over positive voxels (presumably the
        # skull-stripped brain; background assumed to be 0).
        torchio.transforms.ZNormalization(masking_method=lambda x: x > 0, keys=[key]),
    ]
    return Compose(steps)

# 🔹 Process MRI File
def preprocess_mri(nifti_path):
    """Preprocess one NIfTI file and return it as a batch of one on ``device``.

    Args:
        nifti_path: filesystem path of the image to load.

    Returns:
        The "image" entry of the single batch produced by a batch_size=1
        DataLoader, moved to the module-level ``device``.
    """
    single_item = Dataset([{"image": nifti_path}], transform=prepare_transforms())
    loader = DataLoader(single_item, batch_size=1, num_workers=0)
    first_batch = next(iter(loader))
    return first_batch["image"].to(device)

# 🔹 Run Brain Age Prediction (Decorated for GPU Execution)
@spaces.GPU(duration=90)
def predict_brain_age(nifti_file=None, actual_age=None):
    """Placeholder handler for the "Predict" button; returns a fixed estimate.

    The model-based implementation is currently disabled. This stub lets the
    UI be exercised end to end without running the networks.

    Args:
        nifti_file: uploaded MRI from the file component (unused by the stub).
        actual_age: chronological age from the number component (unused).
            Both default to None so the handler still runs when no input
            components are wired to it.

    Returns:
        A ``(brain_age_text, difference_text)`` pair. Gradio requires exactly
        one returned value per registered output component, and the Predict
        button has two (predicted age and age difference).
    """
    return "Brain Age estimate: 42", "Brain Age Difference: not available (placeholder)"

"""
def predict_brain_age(nifti_file, actual_age): #sex
    if not os.path.exists(nifti_file.name):
        return "Error: MRI file not found"

    # Load Model
    models = initialize_model()

    # Preprocess MRI
    image = preprocess_mri(nifti_file.name)

    # Run Predictions
    predictions = []
    with torch.no_grad():
        for model in models:
            pred = model(image)
            predictions.append(pred.cpu().numpy())

    # Compute Median Brain Age Prediction
    predicted_brain_age = np.median(np.stack(predictions))

    # Apply Correction Based on Actual Age
    predicted_brain_age_corrected = (
        predicted_brain_age + (actual_age * 0.062) - 2.96 if actual_age > 18 else predicted_brain_age
    )

    brain_age_difference = predicted_brain_age - actual_age
    # Determine color: Red if positive, Green if negative
    color = "red" if brain_age_difference > 0 else "green"
    bad_output_html = f"<span style='color:{color}; font-weight:bold;'>Brain Age Difference: {brain_age_difference:.2f} years</span>"

    # Return formatted outputs
    #return f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years", bad_output_html
    return f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years", f"Brain Age Difference: {brain_age_difference:.2f} years"
"""

# 🔹 Gradio Interface Setup
with gr.Blocks() as demo:
    gr.Markdown("""
    # 🧠 **BrainAgeNeXt**: Advancing Brain Age Modeling
    Upload a preprocessed T1w MRI scan (.nii.gz), enter the age and sex of the subject, and get the brain age prediction.

    The following preprocessing steps are required:
    1. Skull-stripping using [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) with the `--no-csf` flag for optimal results.
    2. N4 bias field correction using [ANTs](https://github.com/ANTsX/ANTs/wiki/N4BiasFieldCorrection).
    3. Affine registration to the MNI 1mm isotropic template space.

    **BrainAgeNeXt** has been trained and validated using over 11,000 T1w MRI scans acquired at 1.5, 3, and 7T. A 1mm isotropic resolution is preferred for the input image but not required. Our [manuscript](https://doi.org/10.1162/imag_a_00487) presents a detailed explanation of **BrainAgeNeXt** and its potential applications.

    Note: this app processes only one MRI at a time. Please visit our [GitHub repository](https://github.com/FrancescoLR/BrainAgeNeXt/tree/main) to install the code on your machine and run BrainAgeNeXt on large datasets.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            mri_input = gr.File(label="Upload a T1w MRI (NIfTI .nii.gz)")
            age_input = gr.Number(label="Enter Age", value=50)
            # NOTE(review): collected in the UI but not passed to the handler,
            # whose signature has no sex parameter.
            sex_input = gr.Radio(["Male", "Female"], label="Select Sex")
            submit_button = gr.Button("Predict")

        with gr.Column(scale=2):
            brain_age_output = gr.Textbox(label="Predicted Brain Age", interactive=False)
            bad_output = gr.Textbox(label="Brain Age Difference", interactive=False)

    # Gradio passes the input components' values positionally to
    # predict_brain_age(nifti_file, actual_age). It expects one returned value
    # per output component, in the order listed here.
    submit_button.click(
        fn=predict_brain_age,
        inputs=[mri_input, age_input],
        outputs=[brain_age_output, bad_output],
    )

    gr.Markdown("""
    **Disclaimer:** This is a research tool and is not intended for clinical use.

    **If you find this tool useful, please consider citing:**
    1. La Rosa, F., Dos Santos Silva, J., Dereskewicz, E., Invernizzi, A., Cahan, N., Galasso, J., ... & Beck, E. S. (2025).
       BrainAgeNeXt: Advancing Brain Age Modeling for Individuals with Multiple Sclerosis. Imaging Neuroscience.
       DOI: [10.1162/imag_a_00487](https://doi.org/10.1162/imag_a_00487)
    2. Roy, S., Koehler, G., Ulrich, C., Baumgartner, M., Petersen, J., Isensee, F., Jaeger, P.F. & Maier-Hein, K. (2023).
       MedNeXt: Transformer-driven Scaling of ConvNets for Medical Image Segmentation.
       International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI).
    """)
# 🔹 Debugging GPU Environment
# Selects the module-level ``device``. initialize_model() and preprocess_mri()
# read it when they are called, so it only needs to exist before first use.
if torch.cuda.is_available():
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")
else:
    print("No GPU detected. Falling back to CPU.")
    # Diagnostic only: dump driver/GPU status if the tool is installed. Use an
    # argument list (no shell), and keep going if the binary is missing or
    # cannot be executed.
    try:
        subprocess.run(["nvidia-smi"], check=False)
    except OSError:
        print("nvidia-smi could not be executed; skipping GPU diagnostics.")
    device = torch.device("cpu")
    
# 🔹 Download Model Weights at import time (skipped when download_model()
# finds its DATASET_DIR marker already present)
download_model()

# 🔹 Launch the Gradio app only when executed as a script
if __name__ == "__main__":
    demo.launch(share=True)