# BrainAgeNeXt / app.py
# FrancescoLR's picture
# Update app.py
# 877f6af verified
# raw
# history blame
# 5.96 kB
import gradio as gr
import torch
import numpy as np
import os
import nibabel as nib
import torchio
import torch.nn as nn
import subprocess
import spaces # Import spaces for GPU decoration
from scipy.ndimage.measurements import center_of_mass
from huggingface_hub import hf_hub_download
from monai.transforms import Compose, LoadImaged, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd
from monai.data import Dataset, DataLoader
from nnunet_mednext import create_mednext_encoder_v1
# Model and data directory setup
# Hub repository that hosts the BrainAgeNeXt checkpoint files.
REPO_ID = "FrancescoLR/BrainAgeNeXt"
# Cache root handed to the hub download calls as ``cache_dir``.
MODEL_DIR = "/root/.cache/huggingface/hub"
# Project-specific sub-directory underneath the cache root.
DATASET_DIR = os.path.join(MODEL_DIR, "BrainAgeNeXt")

# Ensure model directory exists (a no-op when it is already there).
os.makedirs(MODEL_DIR, exist_ok=True)
# 🔹 Function to Download Model Weights from Hugging Face
def download_model():
    """Fetch the five BrainAgeNeXt checkpoints into the local cache.

    Makes sure ``DATASET_DIR`` exists, then requests ``BrainAge_1.pth``
    through ``BrainAge_5.pth`` from ``REPO_ID`` using ``MODEL_DIR`` as the
    cache directory. Progress messages are printed to stdout.

    The downloads are issued on every call rather than being gated on the
    directory's existence, so a half-finished earlier run cannot leave the
    cache incomplete. NOTE(review): this relies on ``hf_hub_download``
    reusing files already present in ``cache_dir`` - confirm against the
    installed ``huggingface_hub`` version.

    Returns:
        None.
    """
    # ``exist_ok=True`` makes this call idempotent on its own, so a separate
    # ``os.path.exists`` check beforehand adds nothing.
    os.makedirs(DATASET_DIR, exist_ok=True)

    print("Downloading BrainAgeNeXt model weights...")
    for fold in range(1, 6):
        hf_hub_download(repo_id=REPO_ID, filename=f"BrainAge_{fold}.pth", cache_dir=MODEL_DIR)
    print("✅ BrainAgeNeXt model downloaded successfully.")
# 🔹 Function to Load Model
def initialize_model():
    """Build the five-member BrainAgeNeXt ensemble, ready for inference.

    All checkpoint paths (``BrainAge_1.pth`` .. ``BrainAge_5.pth``) are
    resolved through ``hf_hub_download`` first; afterwards one
    ``MedNeXtEncReg`` per checkpoint is created on the module-level
    ``device``, populated from its state dict and put into eval mode.

    Returns:
        list: The loaded models, in checkpoint order 1 to 5.
    """
    checkpoint_files = []
    for fold in range(1, 6):
        checkpoint_files.append(
            hf_hub_download(repo_id=REPO_ID, filename=f"BrainAge_{fold}.pth", cache_dir=MODEL_DIR)
        )

    def _load(checkpoint_file):
        # One ensemble member: construct, move to device, restore weights.
        # NOTE(review): torch.load deserialises the checkpoint file; the
        # files are expected to come only from REPO_ID - confirm.
        network = MedNeXtEncReg().to(device)
        state = torch.load(checkpoint_file, map_location=device)
        network.load_state_dict(state)
        network.eval()
        return network

    return [_load(checkpoint_file) for checkpoint_file in checkpoint_files]
# 🔹 Define Model
class MedNeXtEncReg(nn.Module):
    """MedNeXt encoder topped with a small fully connected regression head.

    Attributes:
        mednextv1: Encoder returned by ``create_mednext_encoder_v1`` for one
            input channel (model id ``'B'``, kernel size 3).
        global_avg_pool: Adaptive 3D average pooling down to ``1x1x1``.
        regression_fc: ``Linear(512, 64) -> ReLU -> Dropout(0.0) ->
            Linear(64, 1)``. The layer positions inside the ``Sequential``
            are part of the state-dict key names, so they must not change.
    """

    def __init__(self):
        super().__init__()
        self.mednextv1 = create_mednext_encoder_v1(
            num_input_channels=1,
            num_classes=1,
            model_id='B',
            kernel_size=3,
            deep_supervision=True,
        )
        self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        head_layers = [
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(64, 1),
        ]
        self.regression_fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the age estimate for ``x``.

        The encoder output is average-pooled, flattened from dimension 1
        onwards and passed through the regression head. The final
        ``squeeze()`` drops every size-1 dimension, so a batch of one
        produces a 0-dim tensor.
        """
        features = self.mednextv1(x)
        pooled = self.global_avg_pool(features)
        flat = pooled.flatten(start_dim=1)
        return self.regression_fc(flat).squeeze()
# 🔹 Preprocessing Pipeline
def prepare_transforms():
    """Assemble the preprocessing chain applied to the ``"image"`` entry.

    Steps, in order: load with the channel dimension first, ``Spacingd`` with
    ``pixdim=(1.0, 1.0, 1.0)``, foreground crop, pad and then centre-crop
    with the same ``(160, 192, 160)`` target size, and finally z-normalise
    using only values greater than zero as the mask.

    Returns:
        Compose: The composed transform.
    """
    target_shape = (160, 192, 160)
    steps = [
        LoadImaged(keys=["image"], ensure_channel_first=True),
        Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0)),
        CropForegroundd(keys=["image"], allow_smaller=True, source_key="image"),
        SpatialPadd(keys=["image"], spatial_size=target_shape),
        CenterSpatialCropd(keys=["image"], roi_size=target_shape),
        torchio.transforms.ZNormalization(
            masking_method=lambda voxels: voxels > 0,
            keys=["image"],
        ),
    ]
    return Compose(steps)
# 🔹 Process MRI File
def preprocess_mri(nifti_path):
    """Load one NIfTI file through the preprocessing pipeline.

    Args:
        nifti_path: Path of the image to load; used as the ``"image"`` entry
            of a single-sample dataset.

    Returns:
        The ``"image"`` entry of the first (and only) batch, moved to the
        module-level ``device``.
    """
    sample = {"image": nifti_path}
    loader = DataLoader(
        Dataset([sample], transform=prepare_transforms()),
        batch_size=1,
        num_workers=0,
    )
    first_batch = next(iter(loader))
    return first_batch["image"].to(device)
# 🔹 Run Brain Age Prediction (Decorated for GPU Execution)
@spaces.GPU(duration=90)
def predict_brain_age(nifti_file, actual_age, sex):
    """Predict brain age for an uploaded MRI and format the result strings.

    Args:
        nifti_file: Upload coming from the Gradio ``File`` component. Either
            an object exposing the path as ``.name`` or a plain path string
            is accepted; ``None`` (nothing uploaded) is reported as an error.
        actual_age: Chronological age in years, used for the age-dependent
            correction and for the brain age difference.
        sex: Selected sex. Accepted to keep the handler signature aligned
            with the UI inputs; it is not used in the computation.

    Returns:
        tuple[str, str]: Text for the "Predicted Brain Age" box and text for
        the "Brain Age Difference (BAD)" box. On invalid input the first
        element carries the error message and the second is empty, so both
        wired outputs always receive a value.
    """
    # Tolerate both upload representations and a missing upload.
    mri_path = getattr(nifti_file, "name", nifti_file)
    if not mri_path or not os.path.exists(mri_path):
        return "Error: MRI file not found", ""
    if actual_age is None:
        return "Error: age is required", ""

    # Load Model
    models = initialize_model()
    # Preprocess MRI
    image = preprocess_mri(mri_path)

    # Run Predictions
    with torch.no_grad():
        predictions = [model(image).cpu().numpy() for model in models]

    # Compute Median Brain Age Prediction across the ensemble
    predicted_brain_age = np.median(np.stack(predictions))

    # Apply Correction Based on Actual Age (only for subjects older than 18)
    if actual_age > 18:
        predicted_brain_age_corrected = predicted_brain_age + (actual_age * 0.062) - 2.96
    else:
        predicted_brain_age_corrected = predicted_brain_age
    brain_age_difference = predicted_brain_age_corrected - actual_age

    # Output Results
    return (
        f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years",
        f"Brain Age Difference (BAD): {brain_age_difference:.2f} years",
    )
# 🔹 Gradio Interface Setup
# Layout: an intro Markdown block, a row with the inputs (left column) and
# the two result textboxes (right column), and a disclaimer underneath.
# The Markdown text is kept flush-left inside the string literals so that no
# leading whitespace ends up in the rendered text.
with gr.Blocks() as demo:
    gr.Markdown("""
# 🧠 Brain Age Prediction with BrainAgeNeXt
Upload a preprocessed T1w MRI scan (.nii.gz), enter the age and sex, and get a brain age prediction.
The following preprocessing are required.
1. Skull-stripping using [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) with the `--no-csf` flag for optimal results.
2. N4 bias field correction using [ANTs](https://github.com/ANTsX/ANTs/wiki/N4BiasFieldCorrection).
3. Affine registration to the MNI 1mm isotropic template space.
""")
    with gr.Row():
        with gr.Column(scale=1):
            mri_input = gr.File(label="Upload MRI (NIfTI .nii.gz)")
            age_input = gr.Number(label="Enter Age", value=30)
            sex_input = gr.Radio(["Male", "Female"], label="Select Sex")
            submit_button = gr.Button("Predict")
        with gr.Column(scale=2):
            brain_age_output = gr.Textbox(label="Predicted Brain Age")
            bad_output = gr.Textbox(label="Brain Age Difference (BAD)")
    # The handler's return values fill the two textboxes in this order.
    submit_button.click(
        fn=predict_brain_age,
        inputs=[mri_input, age_input, sex_input],
        outputs=[brain_age_output, bad_output],
    )
    gr.Markdown("""
**Disclaimer:** This is a research tool and is not intended for clinical use.
""")
# 🔹 Debugging GPU Environment
# Choose the compute device once at import time and report the outcome.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU detected. Falling back to CPU.")
    # Dump the driver view to the log to help diagnose the missing GPU.
    os.system("nvidia-smi")
# 🔹 Download Model Weights
# Runs at import time, so the checkpoints are requested before the UI starts.
download_model()

# 🔹 Run Gradio App
# Only launch the server when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=True)