import gradio as gr
import torch
import numpy as np
import os
import nibabel as nib
import torchio
import torch.nn as nn
import subprocess
import spaces  # Import spaces for GPU decoration
from scipy.ndimage import center_of_mass  # `scipy.ndimage.measurements` is a deprecated alias
from huggingface_hub import hf_hub_download
from monai.transforms import Compose, LoadImaged, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd
from monai.data import Dataset, DataLoader
from nnunet_mednext import create_mednext_encoder_v1

# Model and data directory setup
MODEL_DIR = "/root/.cache/huggingface/hub"
# Marker directory: its existence means the start-up pre-download already succeeded.
# (hf_hub_download keeps the actual files in its own cache layout under MODEL_DIR.)
DATASET_DIR = os.path.join(MODEL_DIR, "BrainAgeNeXt")
REPO_ID = "FrancescoLR/BrainAgeNeXt"
# The ensemble checkpoints are published as BrainAge_1.pth ... BrainAge_5.pth.
NUM_MODELS = 5

# Ensure model directory exists
os.makedirs(MODEL_DIR, exist_ok=True)


def _weight_filenames():
    """Return the checkpoint filenames of the ensemble members in the Hub repo."""
    return [f"BrainAge_{i}.pth" for i in range(1, NUM_MODELS + 1)]


# 🔹 Function to Download Model Weights from Hugging Face
def download_model():
    """Pre-fetch all ensemble checkpoints into MODEL_DIR on first start-up.

    The marker directory DATASET_DIR is created only after every file has been
    fetched, so an interrupted first download is retried on the next start.
    initialize_model() also calls hf_hub_download, so any file that is still
    missing is fetched lazily at prediction time.
    """
    if os.path.exists(DATASET_DIR):
        return
    print("Downloading BrainAgeNeXt model weights...")
    for filename in _weight_filenames():
        hf_hub_download(repo_id=REPO_ID, filename=filename, cache_dir=MODEL_DIR)
    os.makedirs(DATASET_DIR, exist_ok=True)
    print("BrainAgeNeXt model downloaded successfully.")


# 🔹 Define Model
class MedNeXtEncReg(nn.Module):
    """MedNeXt (model_id 'B') encoder + global average pooling + MLP regression head.

    Attribute names must stay as they are: they are the keys of the published
    state_dicts loaded in initialize_model().
    """

    def __init__(self):
        super(MedNeXtEncReg, self).__init__()
        self.mednextv1 = create_mednext_encoder_v1(
            num_input_channels=1,
            num_classes=1,
            model_id='B',
            kernel_size=3,
            deep_supervision=True,
        )
        self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        # Linear(512, ...) means the encoder must emit 512 feature channels.
        self.regression_fc = nn.Sequential(
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        """Return the age estimate(s) for a batch of single-channel 3-D volumes.

        ``squeeze()`` drops singleton dims, so a batch of one yields a 0-d tensor.
        """
        x = self.mednextv1(x)
        x = self.global_avg_pool(x)
        x = torch.flatten(x, start_dim=1)
        age_estimate = self.regression_fc(x)
        return age_estimate.squeeze()


# 🔹 Function to Load Model
def initialize_model():
    """Download (if needed) and load every ensemble member onto ``device`` in eval mode.

    ``device`` is a module-level global assigned further down this file, before
    the app is launched.
    """
    model_paths = [
        hf_hub_download(repo_id=REPO_ID, filename=filename, cache_dir=MODEL_DIR)
        for filename in _weight_filenames()
    ]
    models = []
    for model_path in model_paths:
        model = MedNeXtEncReg().to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        models.append(model)
    return models


# 🔹 Preprocessing Pipeline
def prepare_transforms():
    """Build the load -> 1 mm resample -> crop/pad to 160x192x160 -> z-normalize pipeline."""
    return Compose([
        LoadImaged(keys=["image"], ensure_channel_first=True),
        Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0)),
        CropForegroundd(keys=["image"], allow_smaller=True, source_key="image"),
        SpatialPadd(keys=["image"], spatial_size=(160, 192, 160)),
        CenterSpatialCropd(keys=["image"], roi_size=(160, 192, 160)),
        # Statistics are computed over voxels > 0 only (zero background is excluded).
        torchio.transforms.ZNormalization(masking_method=lambda x: x > 0, keys=["image"]),
    ])


# 🔹 Process MRI File
def preprocess_mri(nifti_path):
    """Run the preprocessing pipeline on one NIfTI file and return a batched tensor on ``device``."""
    transforms = prepare_transforms()
    data_dict = {"image": nifti_path}
    dataset = Dataset([data_dict], transform=transforms)
    dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
    return next(iter(dataloader))["image"].to(device)


# 🔹 Run Brain Age Prediction (Decorated for GPU Execution)
@spaces.GPU(duration=90)
def predict_brain_age(nifti_file, actual_age, sex):
    """Predict brain age with the model ensemble and report the gap to chronological age.

    Args:
        nifti_file: value of the gr.File input -- None when nothing was uploaded,
            otherwise a path string or a file wrapper exposing the path as ``.name``.
        actual_age: chronological age from the gr.Number input (None if cleared).
        sex: value of the gr.Radio input; currently not used by the model.

    Returns:
        A ``(text, html)`` pair for the two Gradio outputs: the predicted brain
        age and the colored brain-age difference. On invalid input, the first
        element carries the error message and the second is empty.
    """
    nifti_path = getattr(nifti_file, "name", nifti_file)
    if not nifti_path or not os.path.exists(nifti_path):
        return "Error: MRI file not found", ""
    if actual_age is None:
        return "Error: please enter the subject's age", ""

    # Load Model
    models = initialize_model()

    # Preprocess MRI
    image = preprocess_mri(nifti_path)

    # Run Predictions
    predictions = []
    with torch.no_grad():
        for model in models:
            predictions.append(model(image).cpu().numpy())

    # Compute Median Brain Age Prediction across the ensemble
    predicted_brain_age = float(np.median(np.stack(predictions)))

    # Apply the age-dependent linear correction (adults only); the coefficients
    # 0.062 and -2.96 are taken as-is -- their derivation is not shown here.
    if actual_age > 18:
        predicted_brain_age_corrected = predicted_brain_age + (actual_age * 0.062) - 2.96
    else:
        predicted_brain_age_corrected = predicted_brain_age

    # Use the same (corrected) value that is displayed, so the two outputs agree.
    brain_age_difference = predicted_brain_age_corrected - actual_age

    # Determine color: Red if positive, Green otherwise
    color = "red" if brain_age_difference > 0 else "green"
    bad_output_html = (
        f"<span style='color: {color};'>"
        f"Brain Age Difference: {brain_age_difference:.2f} years"
        f"</span>"
    )

    # Return formatted outputs
    return f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years", bad_output_html


# 🔹 Gradio Interface Setup
with gr.Blocks() as demo:
    gr.Markdown("""
# 🧠 **BrainAgeNeXt**: Advancing Brain Age Modeling

Upload a preprocessed T1w MRI scan (.nii.gz), enter the age and sex of the subject, and get the brain age prediction.
The following preprocessing steps are required.

1. Skull-stripping using [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) with the `--no-csf` flag for optimal results.
2. N4 bias field correction using [ANTs](https://github.com/ANTsX/ANTs/wiki/N4BiasFieldCorrection).
3. Affine registration to the MNI 1mm isotropic template space.

**BrainAgeNeXt** has been trained and validated using over 11,000 T1w MRI acquired at 1.5, 3, and 7T. A 1mm isotropic resolution is preferred for the input image but not required.
Our [manuscript](https://doi.org/10.1162/imag_a_00487) presents a detailed explanation of **BrainAgeNeXt** and its potential applications.
""")

    with gr.Row():
        with gr.Column(scale=1):
            mri_input = gr.File(label="Upload a T1w MRI (NIfTI .nii.gz)")
            age_input = gr.Number(label="Enter Age", value=50)
            sex_input = gr.Radio(["Male", "Female"], label="Select Sex")
            submit_button = gr.Button("Predict")

        with gr.Column(scale=2):
            brain_age_output = gr.Textbox(label="Predicted Brain Age", interactive=False)
            bad_output = gr.HTML(label="Brain Age Difference")  # Use gr.HTML for colored text

    submit_button.click(
        fn=predict_brain_age,
        inputs=[mri_input, age_input, sex_input],
        outputs=[brain_age_output, bad_output],
    )

    gr.Markdown("""
**Disclaimer:** This is a research tool and is not intended for clinical use.

**If you find this tool useful, please consider citing:**

1. La Rosa, F., Dos Santos Silva, J., Dereskewicz, E., Invernizzi, A., Cahan, N., Galasso, J., ... & Beck, E. S. (2025). BrainAgeNeXt: Advancing Brain Age Modeling for Individuals with Multiple Sclerosis. Imaging Neuroscience. DOI: [10.1162/imag_a_00487](https://doi.org/10.1162/imag_a_00487)
2. Roy, S., Koehler, G., Ulrich, C., Baumgartner, M., Petersen, J., Isensee, F., Jaeger, P.F. & Maier-Hein, K. (2023). MedNeXt: Transformer-driven Scaling of ConvNets for Medical Image Segmentation. International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI).
""")

# 🔹 Debugging GPU Environment
# `device` is read as a global by initialize_model() and preprocess_mri().
if torch.cuda.is_available():
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")
else:
    print("No GPU detected. Falling back to CPU.")
    os.system("nvidia-smi")
    device = torch.device("cpu")

# 🔹 Download Model Weights
download_model()

# 🔹 Run Gradio App
if __name__ == "__main__":
    demo.launch(share=True)