# Hugging Face Spaces status badge: "Running on Zero"
import gradio as gr | |
import torch | |
import numpy as np | |
import os | |
import nibabel as nib | |
import torchio | |
import torch.nn as nn | |
import subprocess | |
import spaces # Import spaces for GPU decoration | |
from scipy.ndimage.measurements import center_of_mass | |
from huggingface_hub import hf_hub_download | |
from monai.transforms import Compose, LoadImaged, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd | |
from monai.data import Dataset, DataLoader | |
from nnunet_mednext import create_mednext_encoder_v1 | |
# Model and data directory setup
REPO_ID = "FrancescoLR/BrainAgeNeXt"  # Hub repository id passed to hf_hub_download
MODEL_DIR = "/root/.cache/huggingface/hub"  # used as cache_dir for every download
DATASET_DIR = os.path.join(MODEL_DIR, "BrainAgeNeXt")  # existence is checked by download_model()

# Ensure model directory exists
os.makedirs(MODEL_DIR, exist_ok=True)
# 🔹 Function to Download Model Weights from Hugging Face
def download_model():
    """Pre-fetch the five BrainAgeNeXt checkpoints into the local Hub cache.

    Nothing is done when ``DATASET_DIR`` already exists. Otherwise the
    directory is created and ``BrainAge_1.pth`` .. ``BrainAge_5.pth`` are
    requested from ``REPO_ID`` with ``MODEL_DIR`` as the cache directory.
    """
    if os.path.exists(DATASET_DIR):
        return

    os.makedirs(DATASET_DIR, exist_ok=True)
    print("Downloading BrainAgeNeXt model weights...")
    for fold in range(1, 6):
        hf_hub_download(repo_id=REPO_ID, filename=f"BrainAge_{fold}.pth", cache_dir=MODEL_DIR)
    print("BrainAgeNeXt model downloaded successfully.")
# 🔹 Function to Load Model
def initialize_model():
    """Build the five-member BrainAgeNeXt ensemble, ready for inference.

    Returns:
        A list of ``MedNeXtEncReg`` instances in checkpoint order
        (``BrainAge_1.pth`` .. ``BrainAge_5.pth``), each moved to ``device``
        and switched to eval mode.
    """
    # Resolve every checkpoint path first (served from the cache when present).
    checkpoint_paths = [
        hf_hub_download(repo_id=REPO_ID, filename=f"BrainAge_{fold}.pth", cache_dir=MODEL_DIR)
        for fold in range(1, 6)
    ]

    def _load(checkpoint_path):
        network = MedNeXtEncReg().to(device)
        # NOTE(review): torch.load unpickles the file; consider weights_only=True
        # if the installed torch version supports it.
        network.load_state_dict(torch.load(checkpoint_path, map_location=device))
        network.eval()
        return network

    return [_load(checkpoint_path) for checkpoint_path in checkpoint_paths]
# 🔹 Define Model
class MedNeXtEncReg(nn.Module):
    """MedNeXt encoder followed by a small fully connected regression head.

    The head maps 512 pooled features to one scalar per sample, so the encoder
    is presumably expected to emit 512 channels (TODO confirm against
    ``create_mednext_encoder_v1``).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and layer order determine the state_dict keys that
        # the released checkpoints are loaded into - keep them as they are.
        self.mednextv1 = create_mednext_encoder_v1(
            num_input_channels=1,
            num_classes=1,
            model_id='B',
            kernel_size=3,
            deep_supervision=True,
        )
        self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        head_layers = [
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.0),  # p=0.0; kept so the final Linear stays at index 3
            nn.Linear(64, 1),
        ]
        self.regression_fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the regression output for ``x`` with all size-1 dims squeezed."""
        features = self.mednextv1(x)
        pooled = self.global_avg_pool(features)
        flat = pooled.flatten(start_dim=1)
        return self.regression_fc(flat).squeeze()
# 🔹 Preprocessing Pipeline
def prepare_transforms():
    """Return the preprocessing pipeline applied to the ``"image"`` entry.

    Steps, in order: load with the channel axis first, respace with
    ``pixdim=(1.0, 1.0, 1.0)``, crop to the foreground of the image itself,
    pad and then center-crop to ``(160, 192, 160)``, and finally z-normalize
    using the voxels greater than zero as the mask.
    """
    load = LoadImaged(keys=["image"], ensure_channel_first=True)
    respace = Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0))
    crop_foreground = CropForegroundd(keys=["image"], allow_smaller=True, source_key="image")
    # Pad first, then crop, so the output has the target size whether the
    # cropped volume was smaller or larger than it.
    pad = SpatialPadd(keys=["image"], spatial_size=(160, 192, 160))
    center_crop = CenterSpatialCropd(keys=["image"], roi_size=(160, 192, 160))
    normalize = torchio.transforms.ZNormalization(masking_method=lambda x: x > 0, keys=["image"])

    return Compose([load, respace, crop_foreground, pad, center_crop, normalize])
# 🔹 Process MRI File
def preprocess_mri(nifti_path):
    """Load and preprocess one NIfTI file and return it as a batch on ``device``.

    Args:
        nifti_path: Path of the image file handed to the loading transform.

    Returns:
        The ``"image"`` entry of the first (and only) batch, moved to ``device``.
    """
    sample = {"image": nifti_path}
    single_item_dataset = Dataset([sample], transform=prepare_transforms())
    # A one-element DataLoader is used only to add the batch dimension.
    loader = DataLoader(single_item_dataset, batch_size=1, num_workers=0)
    first_batch = next(iter(loader))
    return first_batch["image"].to(device)
# 🔹 Run Brain Age Prediction (intended for GPU execution; NOTE(review): no GPU decorator such as @spaces.GPU is actually applied to this function)
def predict_brain_age(nifti_file, actual_age, sex):
    """Predict brain age for an uploaded MRI and format the two UI outputs.

    Args:
        nifti_file: Upload from the file component. Either an object exposing
            the path as ``.name`` or a plain path string; ``None`` when
            nothing was uploaded.
        actual_age: Chronological age in years; ``None`` when the number
            field is empty.
        sex: Accepted because the UI passes it; not used in the computation.

    Returns:
        A ``(text, html)`` pair: the predicted brain age line and the
        color-coded Brain Age Difference (BAD) snippet. On invalid input the
        first element is an error message and the second is an empty string,
        so both wired outputs always receive a value.
    """
    # Accept both file-like uploads (.name) and plain path strings.
    nifti_path = getattr(nifti_file, "name", nifti_file)
    if not nifti_path or not os.path.exists(nifti_path):
        return "Error: MRI file not found", ""
    if actual_age is None:
        return "Error: please enter the subject's age", ""

    # Load Model
    models = initialize_model()
    # Preprocess MRI
    image = preprocess_mri(nifti_path)

    # Run Predictions (one per ensemble member)
    with torch.no_grad():
        predictions = [model(image).cpu().numpy() for model in models]

    # Compute Median Brain Age Prediction
    predicted_brain_age = float(np.median(np.stack(predictions)))
    return _format_brain_age_outputs(predicted_brain_age, actual_age)


def _format_brain_age_outputs(predicted_brain_age, actual_age):
    """Apply the adult age correction and render the text and HTML outputs."""
    # The linear correction is applied only to subjects older than 18.
    if actual_age > 18:
        corrected_brain_age = predicted_brain_age + (actual_age * 0.062) - 2.96
    else:
        corrected_brain_age = predicted_brain_age

    # BAD is derived from the same corrected value that is displayed, so the
    # two outputs are consistent with each other.
    brain_age_difference = corrected_brain_age - actual_age

    # Determine color: Red if positive, Green otherwise
    color = "red" if brain_age_difference > 0 else "green"
    bad_output_html = (
        f"<span style='color:{color}; font-weight:bold;'>"
        f"Brain Age Difference (BAD): {brain_age_difference:.2f} years</span>"
    )
    return f"Predicted Brain Age: {corrected_brain_age:.2f} years", bad_output_html
# 🔹 Gradio Interface Setup
# Layout: intro text, then a row with the inputs (left) and the results
# (right), then the disclaimer/citations. The Markdown text is kept at column 0
# and paragraphs are separated by blank lines so that headings, lists and
# paragraphs render as separate elements.
with gr.Blocks() as demo:
    gr.Markdown("""
# 🧠 **BrainAgeNeXt**: Advancing Brain Age Modeling

Upload a preprocessed T1w MRI scan (.nii.gz), enter the age and sex of the subject, and get the brain age prediction.

The following preprocessing steps are required.

1. Skull-stripping using [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) with the `--no-csf` flag for optimal results.
2. N4 bias field correction using [ANTs](https://github.com/ANTsX/ANTs/wiki/N4BiasFieldCorrection).
3. Affine registration to the MNI 1mm isotropic template space.

**BrainAgeNeXt** has been trained and validated using over 11,000 T1w MRI acquired at 1.5, 3, and 7T. A 1mm isotropic resolution is preferred for the input image but not required. Our [manuscript](https://doi.org/10.1162/imag_a_00487) presents a detailed explanation of **BrainAgeNeXt** and its potential applications.
""")
    with gr.Row():
        # Left column: user inputs
        with gr.Column(scale=1):
            mri_input = gr.File(label="Upload a T1w MRI (NIfTI .nii.gz)")
            age_input = gr.Number(label="Enter Age", value=50)
            sex_input = gr.Radio(["Male", "Female"], label="Select Sex")
            submit_button = gr.Button("Predict")
        # Right column: results
        with gr.Column(scale=2):
            brain_age_output = gr.Textbox(label="Predicted Brain Age", interactive=False)
            bad_output = gr.HTML(label="Brain Age Difference")  # Use gr.HTML for colored text
    # The handler must return one value per component listed in `outputs`.
    submit_button.click(
        fn=predict_brain_age,
        inputs=[mri_input, age_input, sex_input],
        outputs=[brain_age_output, bad_output]
    )
    gr.Markdown("""
**Disclaimer:** This is a research tool and is not intended for clinical use.

**If you find this tool useful, please consider citing:**

1. La Rosa, F., Dos Santos Silva, J., Dereskewicz, E., Invernizzi, A., Cahan, N., Galasso, J., ... & Beck, E. S. (2025).
BrainAgeNeXt: Advancing Brain Age Modeling for Individuals with Multiple Sclerosis. Imaging Neuroscience.
DOI: [10.1162/imag_a_00487](https://doi.org/10.1162/imag_a_00487)
2. Roy, S., Koehler, G., Ulrich, C., Baumgartner, M., Petersen, J., Isensee, F., Jaeger, P.F. & Maier-Hein, K. (2023).
MedNeXt: Transformer-driven Scaling of ConvNets for Medical Image Segmentation.
International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI).
""")
# 🔹 Debugging GPU Environment
# Select the torch device used by the model-loading and preprocessing code.
if not torch.cuda.is_available():
    print("No GPU detected. Falling back to CPU.")
    # Emit whatever nvidia-smi reports (or the shell's error) to the logs to
    # help diagnose why CUDA is unavailable.
    os.system("nvidia-smi")
    device = torch.device("cpu")
else:
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")
# 🔹 Download Model Weights
# Runs at import time so the checkpoints are cached before the first request.
download_model()

# 🔹 Run Gradio App
# Only start the server when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=True)