BrainAgeNeXt / app.py
FrancescoLR's picture
Update app.py
84246e6 verified
raw
history blame
8 kB
import gradio as gr
import torch
import numpy as np
import os
import nibabel as nib
import torchio
import torch.nn as nn
import subprocess
import spaces # Import spaces for GPU decoration
from scipy.ndimage.measurements import center_of_mass
from huggingface_hub import hf_hub_download
from monai.transforms import Compose, LoadImaged, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd
from monai.data import Dataset, DataLoader
from nnunet_mednext import create_mednext_encoder_v1
# Model and data directory setup
# MODEL_DIR is handed to hf_hub_download as `cache_dir`, so the checkpoints are cached under it.
MODEL_DIR = "/root/.cache/huggingface/hub"
# DATASET_DIR is only ever tested/created by download_model(); it serves as a
# "weights already fetched" marker rather than as the place files are read from.
DATASET_DIR = os.path.join(MODEL_DIR, "BrainAgeNeXt")
# Hugging Face Hub repository that hosts the BrainAge_<i>.pth checkpoints.
REPO_ID = "FrancescoLR/BrainAgeNeXt"
# Ensure model directory exists
os.makedirs(MODEL_DIR, exist_ok=True)
# 🔹 Function to Download Model Weights from Hugging Face
def download_model():
    """Fetch the five BrainAgeNeXt checkpoints into the local Hugging Face cache.

    ``DATASET_DIR`` is used purely as a "download completed" marker; the
    checkpoint files themselves are stored by ``hf_hub_download`` under
    ``MODEL_DIR`` (its ``cache_dir``).

    The marker directory is created only *after* every checkpoint has been
    fetched. Creating it up front would make a failed or interrupted download
    look complete, and every later call would skip the download.
    """
    if os.path.exists(DATASET_DIR):
        # Marker present: a previous call already fetched all checkpoints.
        return

    print("Downloading BrainAgeNeXt model weights...")
    for i in range(1, 6):
        hf_hub_download(repo_id=REPO_ID, filename=f"BrainAge_{i}.pth", cache_dir=MODEL_DIR)

    # Only reached when no download raised, so the marker means "complete".
    os.makedirs(DATASET_DIR, exist_ok=True)
    print("BrainAgeNeXt model downloaded successfully.")
# 🔹 Function to Load Model
def initialize_model():
    """Build the BrainAgeNeXt ensemble and return it as a list of models.

    Resolves the local path of each ``BrainAge_<i>.pth`` checkpoint
    (``i`` = 1..5) through ``hf_hub_download``, then for each path creates a
    ``MedNeXtEncReg`` on the module-level ``device``, loads the checkpoint's
    state dict into it and switches it to evaluation mode.

    Returns:
        list: the five loaded models, in checkpoint order.
    """
    # Resolve every checkpoint path first, before any model is instantiated.
    checkpoint_files = []
    for index in range(1, 6):
        checkpoint_files.append(
            hf_hub_download(repo_id=REPO_ID, filename=f"BrainAge_{index}.pth", cache_dir=MODEL_DIR)
        )

    def _load(checkpoint):
        # One ensemble member: instantiate, restore weights, freeze to eval mode.
        net = MedNeXtEncReg().to(device)
        net.load_state_dict(torch.load(checkpoint, map_location=device))
        net.eval()
        return net

    return [_load(checkpoint) for checkpoint in checkpoint_files]
# 🔹 Define Model
class MedNeXtEncReg(nn.Module):
    """MedNeXt encoder followed by a small fully connected regression head.

    The encoder output is globally average-pooled to a single value per
    channel, flattened, and mapped to one scalar by the head. The head's
    first layer is ``Linear(512, 64)``, so the pooled encoder output must
    carry 512 features.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and the layer order inside the Sequential determine
        # the state_dict keys, so they must stay as they are for the released
        # checkpoints to load (this includes the no-op Dropout(0.0) slot).
        self.mednextv1 = create_mednext_encoder_v1(
            num_input_channels=1,
            num_classes=1,
            model_id='B',
            kernel_size=3,
            deep_supervision=True,
        )
        self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        head_layers = [
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(64, 1),
        ]
        self.regression_fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the age estimate for ``x`` with all size-1 dimensions removed.

        Because of the bare ``squeeze()``, a batch of one yields a 0-d tensor
        and a larger batch yields a 1-d tensor with one value per sample.
        """
        features = self.mednextv1(x)
        pooled = self.global_avg_pool(features)
        flat = pooled.flatten(start_dim=1)
        return self.regression_fc(flat).squeeze()
# 🔹 Preprocessing Pipeline
def prepare_transforms():
    """Assemble the preprocessing pipeline applied to the ``"image"`` entry.

    Steps, in order: load the file with the channel axis first, resample to a
    ``pixdim`` of (1.0, 1.0, 1.0), crop to the foreground, pad and then
    centre-crop to a fixed (160, 192, 160) volume, and finally z-normalise
    using the voxels greater than zero as the mask.

    Returns:
        Compose: the composed transform, to be applied to ``{"image": path}``.
    """
    # Padding followed by a centre crop to the same size forces a fixed volume
    # regardless of whether the cropped foreground is smaller or larger.
    target_shape = (160, 192, 160)

    load = LoadImaged(keys=["image"], ensure_channel_first=True)
    resample = Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0))
    crop_foreground = CropForegroundd(keys=["image"], allow_smaller=True, source_key="image")
    pad = SpatialPadd(keys=["image"], spatial_size=target_shape)
    center_crop = CenterSpatialCropd(keys=["image"], roi_size=target_shape)
    normalize = torchio.transforms.ZNormalization(masking_method=lambda x: x > 0, keys=["image"])

    return Compose([load, resample, crop_foreground, pad, center_crop, normalize])
# 🔹 Process MRI File
def preprocess_mri(nifti_path):
    """Run one image file through the preprocessing pipeline.

    Args:
        nifti_path: path of the image to load (passed to the pipeline as the
            ``"image"`` entry).

    Returns:
        The ``"image"`` entry of the first (and only) batch produced by a
        batch-size-1 ``DataLoader``, moved to the module-level ``device``.
    """
    samples = [{"image": nifti_path}]
    loader = DataLoader(
        Dataset(samples, transform=prepare_transforms()),
        batch_size=1,
        num_workers=0,
    )
    first_batch = next(iter(loader))
    return first_batch["image"].to(device)
# 🔹 Run Brain Age Prediction (Decorated for GPU Execution)
@spaces.GPU(duration=90)
def predict_brain_age(path, age, sex):  # , actual_age):
    """Placeholder click handler: no model is run.

    Args:
        path: value of the file-upload component (unused).
        age: value of the age input; echoed back unchanged.
        sex: value of the sex selector (unused).

    Returns:
        tuple: the constant text ``"Brain Age estimate: 42"`` and ``age``.
    """
    # NOTE(review): hard-coded result; the real inference code is the disabled
    # string literal that follows this function.
    placeholder_message = "Brain Age estimate: 42"
    return placeholder_message, age


# Disabled earlier implementation of predict_brain_age, kept as a bare string
# literal so it never executes. It loads the ensemble, takes the median of the
# per-model predictions and applies an age-based correction for ages above 18.
# NOTE(review): its signature (nifti_file, actual_age) does not match the
# (path, age, sex) inputs wired to the click handler, and brain_age_difference
# is computed from the uncorrected prediction -- confirm both before re-enabling.
"""
def predict_brain_age(nifti_file, actual_age): #sex
if not os.path.exists(nifti_file.name):
return "Error: MRI file not found"
# Load Model
models = initialize_model()
# Preprocess MRI
image = preprocess_mri(nifti_file.name)
# Run Predictions
predictions = []
with torch.no_grad():
for model in models:
pred = model(image)
predictions.append(pred.cpu().numpy())
# Compute Median Brain Age Prediction
predicted_brain_age = np.median(np.stack(predictions))
# Apply Correction Based on Actual Age
predicted_brain_age_corrected = (
predicted_brain_age + (actual_age * 0.062) - 2.96 if actual_age > 18 else predicted_brain_age
)
brain_age_difference = predicted_brain_age - actual_age
# Determine color: Red if positive, Green if negative
color = "red" if brain_age_difference > 0 else "green"
bad_output_html = f"<span style='color:{color}; font-weight:bold;'>Brain Age Difference: {brain_age_difference:.2f} years</span>"
# Return formatted outputs
#return f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years", bad_output_html
return f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years", f"Brain Age Difference: {brain_age_difference:.2f} years"
"""
# 🔹 Gradio Interface Setup
# Layout: an instructions header, a row with the inputs (left) and the two
# text outputs (right), the click wiring, and a disclaimer/citation footer.
with gr.Blocks() as demo:
    gr.Markdown("""
# 🧠 **BrainAgeNeXt**: Advancing Brain Age Modeling
Upload a preprocessed T1w MRI scan (.nii.gz), enter the age and sex of the subject, and get the brain age prediction.
The following preprocessing steps are required.
1. Skull-stripping using [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) with the `--no-csf` flag for optimal results.
2. N4 bias field correction using [ANTs](https://github.com/ANTsX/ANTs/wiki/N4BiasFieldCorrection).
3. Affine registration to the MNI 1mm isotropic template space.
**BrainAgeNeXt** has been trained and validated using over 11,000 T1w MRI acquired at 1.5, 3, and 7T. A 1mm isotropic resolution is preferred for the input image but not required. Our [manuscript](https://doi.org/10.1162/imag_a_00487) presents a detailed explanation of **BrainAgeNeXt** and its potential applications.
Note: this app processes only a single MRI at a time. Please visit our [GitHub repository](https://github.com/FrancescoLR/BrainAgeNeXt/tree/main) to install the code on your machine and run BrainAgeNeXt on large datasets.
""")
    with gr.Row():
        with gr.Column(scale=1):
            # The variable keeps its historical name, but the tool expects a
            # T1w scan (see the instructions above), so the label says T1w.
            flair_input = gr.File(label="Upload a T1w MRI (NIfTI .nii.gz)")
            age_input = gr.Number(label="Enter Age", value=50)
            sex_input = gr.Radio(["Male", "Female"], label="Select Sex")
            submit_button = gr.Button("Predict")
        with gr.Column(scale=2):
            # Plain text boxes; a gr.HTML component would be needed to show
            # the brain age difference as coloured text.
            brain_age_output = gr.Textbox(label="Predicted Brain Age", interactive=False)
            bad_output = gr.Textbox(label="Brain Age Difference", interactive=False)
    # The handler receives (file, age, sex) and fills the two text boxes in order.
    submit_button.click(
        fn=predict_brain_age,
        inputs=[flair_input, age_input, sex_input],
        outputs=[brain_age_output, bad_output]
    )
    gr.Markdown("""
**Disclaimer:** This is a research tool and is not intended for clinical use.
**If you find this tool useful, please consider citing:**
1. La Rosa, F., Dos Santos Silva, J., Dereskewicz, E., Invernizzi, A., Cahan, N., Galasso, J., ... & Beck, E. S. (2025).
BrainAgeNeXt: Advancing Brain Age Modeling for Individuals with Multiple Sclerosis. Imaging Neuroscience.
DOI: [10.1162/imag_a_00487](https://doi.org/10.1162/imag_a_00487)
2. Roy, S., Koehler, G., Ulrich, C., Baumgartner, M., Petersen, J., Isensee, F., Jaeger, P.F. & Maier-Hein, K. (2023).
MedNeXt: Transformer-driven Scaling of ConvNets for Medical Image Segmentation.
International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI).
""")
# 🔹 Debugging GPU Environment
# Choose the compute device once at import time; the model-loading and
# preprocessing functions read this module-level `device` when they are called.
if not torch.cuda.is_available():
    print("No GPU detected. Falling back to CPU.")
    # Diagnostic only: the exit status of nvidia-smi is ignored.
    os.system("nvidia-smi")
    device = torch.device("cpu")
else:
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")

# Download Model Weights
# Runs on every import of this module, not only when executed as a script.
download_model()

# Run Gradio App
if __name__ == "__main__":
    demo.launch(share=True)