Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,229 Bytes
e148d83 73bc6a0 e148d83 73bc6a0 e133e90 73bc6a0 e148d83 73bc6a0 cf79142 e148d83 73bc6a0 989a90a 73bc6a0 5456f4c 73bc6a0 e148d83 73bc6a0 e148d83 73bc6a0 e148d83 73bc6a0 e148d83 73bc6a0 e148d83 73bc6a0 e148d83 73bc6a0 e148d83 73bc6a0 e148d83 73bc6a0 e148d83 73bc6a0 e148d83 73bc6a0 e148d83 893aaeb 9d8a150 7c9c0b1 e148d83 9d8a150 6c1f573 e148d83 73bc6a0 4d5997b 777c129 877f6af 893aaeb 877f6af 9d8a150 877f6af 777c129 2b53bfd 73bc6a0 777c129 30e4a53 73bc6a0 9d8a150 6c1f573 73bc6a0 7d444b3 5456f4c 73bc6a0 e148d83 73bc6a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import gradio as gr
import torch
import numpy as np
import os
import nibabel as nib
import torchio
import torch.nn as nn
import subprocess
import spaces # Import spaces for GPU decoration
from scipy.ndimage.measurements import center_of_mass
from huggingface_hub import hf_hub_download
from monai.transforms import Compose, LoadImaged, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd
from monai.data import Dataset, DataLoader
from nnunet_mednext import create_mednext_encoder_v1
# Hugging Face Hub repository that hosts the BrainAgeNeXt checkpoints.
REPO_ID = "FrancescoLR/BrainAgeNeXt"

# Local Hugging Face cache root; checkpoints are downloaded under this path.
MODEL_DIR = "/root/.cache/huggingface/hub"
# Marker directory checked by download_model() to decide whether to fetch.
DATASET_DIR = os.path.join(MODEL_DIR, "BrainAgeNeXt")

# Create the cache root up front (no-op when it already exists).
os.makedirs(MODEL_DIR, exist_ok=True)
# Download the model weights from Hugging Face
def download_model():
    """Pre-fetch the five BrainAgeNeXt fold checkpoints into ``MODEL_DIR``.

    ``DATASET_DIR`` is used as a "download finished" marker: when it already
    exists the function returns immediately without touching the network.
    The marker is created only *after* every ``hf_hub_download`` call has
    returned. A download that fails part-way therefore leaves no marker and
    is retried on the next call, instead of being skipped forever.
    """
    if os.path.exists(DATASET_DIR):
        return
    print("Downloading BrainAgeNeXt model weights...")
    for fold in range(1, 6):
        hf_hub_download(repo_id=REPO_ID, filename=f"BrainAge_{fold}.pth", cache_dir=MODEL_DIR)
    # Mark completion only once all checkpoints were fetched successfully.
    os.makedirs(DATASET_DIR, exist_ok=True)
    print("BrainAgeNeXt model downloaded successfully.")
# Load the ensemble of trained models
def initialize_model():
    """Load every BrainAgeNeXt fold checkpoint and return the models in eval mode.

    Each checkpoint path is resolved through ``hf_hub_download`` with
    ``cache_dir=MODEL_DIR``. Each checkpoint is loaded as a state dict into a
    fresh ``MedNeXtEncReg`` placed on the module-level ``device``.

    Models are deliberately not cached between calls. The caller runs under
    ``spaces.GPU``, and state created inside that call may not outlive it
    (NOTE(review): confirm against the ZeroGPU docs before adding a cache).

    Returns:
        list[MedNeXtEncReg]: one model per checkpoint ``BrainAge_1.pth`` ...
        ``BrainAge_5.pth``, in that order, each switched to ``eval()`` mode.
    """
    model_paths = [
        hf_hub_download(repo_id=REPO_ID, filename=f"BrainAge_{fold}.pth", cache_dir=MODEL_DIR)
        for fold in range(1, 6)
    ]
    models = []
    for model_path in model_paths:
        model = MedNeXtEncReg().to(device)
        # torch.load unpickles the file. weights_only=True restricts it to
        # tensors and plain containers, which is all a state dict needs, so a
        # tampered checkpoint cannot run arbitrary code.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        models.append(model)
    return models
# Model definition
class MedNeXtEncReg(nn.Module):
    """MedNeXt encoder plus a small MLP head that regresses a single value (brain age).

    The submodule names and the head layout (``regression_fc`` indices
    0..3) define the state-dict keys. The released checkpoints are presumably
    saved from exactly this layout. The no-op ``nn.Dropout(0.0)`` must
    therefore stay: removing it would renumber the final ``nn.Linear`` from
    index 3 to 2 and break ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        self.mednextv1 = create_mednext_encoder_v1(
            num_input_channels=1,
            num_classes=1,
            model_id='B',
            kernel_size=3,
            deep_supervision=True,
        )
        self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        head_layers = [
            nn.Linear(512, 64),  # the encoder must therefore emit 512 channels
            nn.ReLU(),
            nn.Dropout(0.0),     # no-op, kept only for state-dict compatibility
            nn.Linear(64, 1),
        ]
        self.regression_fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Encode ``x``, global-average-pool it and regress one value per sample.

        ``x`` is presumably a (batch, 1, D, H, W) volume batch, given
        ``num_input_channels=1``. TODO: confirm.

        Returns every singleton dimension squeezed away. That is a 0-d tensor
        for a batch of one, and a (batch,) tensor otherwise.
        """
        features = self.mednextv1(x)
        pooled = torch.flatten(self.global_avg_pool(features), start_dim=1)
        return self.regression_fc(pooled).squeeze()
# Preprocessing pipeline
def prepare_transforms():
    """Return the preprocessing pipeline applied to ``{"image": <path>}`` dicts.

    Steps, in order:

    1. ``LoadImaged``: read the volume with the channel axis first.
    2. ``Spacingd``: resample to a (1.0, 1.0, 1.0) voxel spacing.
    3. ``CropForegroundd``: crop to the image's own foreground
       (``source_key="image"``, ``allow_smaller=True``).
    4. ``SpatialPadd`` then ``CenterSpatialCropd``: pad and centre-crop to a
       fixed 160 x 192 x 160 grid.
    5. TorchIO ``ZNormalization``: z-normalise intensities, using voxels
       with value > 0 as the mask.
    """
    target_size = (160, 192, 160)

    def positive_voxels(volume):
        # Mask used by ZNormalization: strictly positive voxels only.
        return volume > 0

    steps = [
        LoadImaged(keys=["image"], ensure_channel_first=True),
        Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0)),
        CropForegroundd(keys=["image"], allow_smaller=True, source_key="image"),
        SpatialPadd(keys=["image"], spatial_size=target_size),
        CenterSpatialCropd(keys=["image"], roi_size=target_size),
        torchio.transforms.ZNormalization(masking_method=positive_voxels, keys=["image"]),
    ]
    return Compose(steps)
# Load and preprocess a single MRI file
def preprocess_mri(nifti_path):
    """Run one NIfTI file through ``prepare_transforms()`` and return it as a batch.

    A one-item MONAI ``Dataset``/``DataLoader`` (``batch_size=1``,
    ``num_workers=0``) is used so the sample is collated with a leading batch
    dimension. The result is presumably (1, 1, 160, 192, 160); TODO: confirm.

    Args:
        nifti_path: filesystem path of the NIfTI volume.

    Returns:
        The collated ``"image"`` tensor moved to the module-level ``device``.
    """
    pipeline = prepare_transforms()
    single_item = Dataset([{"image": nifti_path}], transform=pipeline)
    loader = DataLoader(single_item, batch_size=1, num_workers=0)
    first_batch = next(iter(loader))
    return first_batch["image"].to(device)
# Brain age prediction helpers
def _age_bias_correction(predicted_age, actual_age):
    """Apply the linear age-bias correction used for subjects older than 18.

    For ``actual_age > 18`` the prediction is shifted by
    ``0.062 * actual_age - 2.96``; otherwise it is returned unchanged.
    """
    if actual_age > 18:
        return predicted_age + actual_age * 0.062 - 2.96
    return predicted_age


def _difference_html(difference):
    """Render the brain-age difference as bold HTML: red when positive, green otherwise."""
    color = "red" if difference > 0 else "green"
    return f"<span style='color:{color}; font-weight:bold;'>Brain Age Difference: {difference:.2f} years</span>"


# Run the prediction (decorated so it executes on a ZeroGPU slot, up to 90 s)
@spaces.GPU(duration=90)
def predict_brain_age(nifti_file, actual_age, sex):
    """Predict brain age from an uploaded T1w NIfTI with the 5-fold ensemble.

    Args:
        nifti_file: value of the ``gr.File`` input. Depending on the Gradio
            version this is a tempfile-like object with ``.name`` or a plain
            path string; both are accepted. ``None`` when nothing was uploaded.
        actual_age: chronological age from ``gr.Number`` (``None`` if cleared).
        sex: value of the sex radio button; currently unused by the model.

    Returns:
        tuple[str, str]: text for the "Predicted Brain Age" textbox and HTML
        for the "Brain Age Difference" panel. Error cases also return two
        values, because the click handler is wired to two outputs.
    """
    if nifti_file is None:
        return "Error: MRI file not found", ""
    nifti_path = getattr(nifti_file, "name", nifti_file)
    if not os.path.exists(nifti_path):
        return "Error: MRI file not found", ""
    if actual_age is None:
        return "Error: please enter the subject's age", ""

    models = initialize_model()
    image = preprocess_mri(nifti_path)

    with torch.no_grad():
        predictions = [model(image).cpu().numpy() for model in models]
    # Ensemble by taking the median of the per-fold predictions.
    predicted_brain_age = float(np.median(np.stack(predictions)))

    predicted_brain_age_corrected = _age_bias_correction(predicted_brain_age, actual_age)
    # Report the difference for the same (corrected) age that is displayed,
    # so the two outputs are consistent with each other.
    brain_age_difference = predicted_brain_age_corrected - actual_age

    return (
        f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years",
        _difference_html(brain_age_difference),
    )
# Gradio interface setup
with gr.Blocks() as demo:
    # Content lines start at column 0 so Markdown never treats them as
    # indented code; blank lines separate paragraphs and lists.
    gr.Markdown("""
# 🧠 **BrainAgeNeXt**: Advancing Brain Age Modeling

Upload a preprocessed T1w MRI scan (.nii.gz), enter the subject's age and sex, and get the brain age prediction.

The following preprocessing steps are required:

1. Skull-stripping using [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) with the `--no-csf` flag for optimal results.
2. N4 bias field correction using [ANTs](https://github.com/ANTsX/ANTs/wiki/N4BiasFieldCorrection).
3. Affine registration to the MNI 1mm isotropic template space.

**BrainAgeNeXt** has been trained and validated using over 11,000 T1w MRI scans acquired at 1.5, 3, and 7T. A 1mm isotropic resolution is preferred for the input image but not required. Our [manuscript](https://doi.org/10.1162/imag_a_00487) presents a detailed explanation of **BrainAgeNeXt** and its potential applications.
""")
    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            mri_input = gr.File(label="Upload a T1w MRI (NIfTI .nii.gz)")
            age_input = gr.Number(label="Enter Age", value=50)
            sex_input = gr.Radio(["Male", "Female"], label="Select Sex")
            submit_button = gr.Button("Predict")
        # Right column: outputs.
        with gr.Column(scale=2):
            brain_age_output = gr.Textbox(label="Predicted Brain Age", interactive=False)
            bad_output = gr.HTML(label="Brain Age Difference")  # gr.HTML renders the colored text
    # predict_brain_age returns exactly one value per output listed here.
    submit_button.click(
        fn=predict_brain_age,
        inputs=[mri_input, age_input, sex_input],
        outputs=[brain_age_output, bad_output],
    )
    gr.Markdown("""
**Disclaimer:** This is a research tool and is not intended for clinical use.

**If you find this tool useful, please consider citing:**

1. La Rosa, F., Dos Santos Silva, J., Dereskewicz, E., Invernizzi, A., Cahan, N., Galasso, J., ... & Beck, E. S. (2025).
   BrainAgeNeXt: Advancing Brain Age Modeling for Individuals with Multiple Sclerosis. Imaging Neuroscience.
   DOI: [10.1162/imag_a_00487](https://doi.org/10.1162/imag_a_00487)
2. Roy, S., Koehler, G., Ulrich, C., Baumgartner, M., Petersen, J., Isensee, F., Jaeger, P.F. & Maier-Hein, K. (2023).
   MedNeXt: Transformer-driven Scaling of ConvNets for Medical Image Segmentation.
   International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI).
""")
# Select the compute device (print diagnostics when no GPU is visible)
if torch.cuda.is_available():
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")
else:
    print("No GPU detected. Falling back to CPU.")
    # Best-effort diagnostics: run nvidia-smi without a shell. A missing
    # binary is reported but must not stop the app from starting on CPU.
    try:
        subprocess.run(["nvidia-smi"], check=False)
    except OSError:
        print("nvidia-smi is not available.")
    device = torch.device("cpu")
# Make sure the checkpoints are cached before the UI starts serving requests.
download_model()

# Start the web app only when this file is executed directly.
if __name__ == "__main__":
    launch_options = {"share": True}
    demo.launch(**launch_options)