Spaces:
Running
on
Zero
Running
on
Zero
import functools
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchio
from huggingface_hub import hf_hub_download
from monai.data import DataLoader, Dataset
from monai.transforms import (
    CenterSpatialCropd,
    Compose,
    CropForegroundd,
    LoadImaged,
    SpatialPadd,
    Spacingd,
)
from nnunet_mednext import create_mednext_encoder_v1
# Debugging GPU environment: report, at import time, which accelerator the
# process can see, then dump nvidia-smi output for the Space logs.
_cuda_available = torch.cuda.is_available()
if not _cuda_available:
    _gpu_report = "No GPU available. Falling back to CPU."
else:
    _gpu_report = f"GPU is available: {torch.cuda.get_device_name(0)}"
print(_gpu_report)
os.system("nvidia-smi")

# Device selection: every model and input tensor below is placed on this device.
device = torch.device("cuda" if _cuda_available else "cpu")
# Model definition
class MedNeXtEncReg(nn.Module):
    """MedNeXt-v1 encoder with a small MLP head that regresses one value per sample.

    The attribute names (``mednextv1``, ``global_avg_pool``, ``regression_fc``) and
    the layer order inside ``regression_fc`` define the state-dict keys, so they
    must stay as-is for the published checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        # Single-channel input; the head below expects the encoder to emit 512 channels.
        self.mednextv1 = create_mednext_encoder_v1(
            num_input_channels=1,
            num_classes=1,
            model_id='B',
            kernel_size=3,
            deep_supervision=True,
        )
        self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        # Dropout(0.0) is a no-op, but it occupies index 2 of the Sequential.
        # Removing it would rename the final layer's keys ("regression_fc.3.*")
        # and break loading of the saved weights.
        head_layers = [
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(64, 1),
        ]
        self.regression_fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Predict one scalar per input volume.

        Args:
            x: 5-D input batch, presumably (batch, 1, D, H, W). TODO confirm
                against the encoder's expectations.

        Returns:
            ``regression_fc`` output with all size-1 dims squeezed: a 0-d tensor
            for a batch of one, shape ``(batch,)`` otherwise.
        """
        features = self.mednextv1(x)
        pooled = self.global_avg_pool(features).flatten(start_dim=1)
        return self.regression_fc(pooled).squeeze()
# Download the model from Hugging Face Hub
@functools.lru_cache(maxsize=1)
def _load_ensemble():
    """Download the five BrainAgeNeXt checkpoints and build eval-mode models once.

    Returns a tuple so the cached value cannot be mutated through a caller's reference.
    """
    checkpoint_paths = [
        hf_hub_download(repo_id="FrancescoLR/BrainAgeNeXt", filename=f"BrainAge_{i}.pth")
        for i in range(1, 6)
    ]
    members = []
    for checkpoint_path in checkpoint_paths:
        model = MedNeXtEncReg().to(device)
        # NOTE(review): torch.load unpickles the file, so this trusts the Hub repo's
        # contents. Consider weights_only=True if the installed torch supports it
        # and the checkpoints are plain state dicts.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model.eval()
        members.append(model)
    return tuple(members)


def initialize_model():
    """Return the five-model BrainAgeNeXt ensemble, in eval mode, on ``device``.

    Checkpoints are downloaded and deserialized on the first call only. Later
    calls reuse the same module objects instead of rebuilding five networks per
    request. A new list is returned each time, so callers can modify the list
    without affecting the cache.
    """
    return list(_load_ensemble())
# Define preprocessing transforms
def prepare_transforms():
    """Build the preprocessing pipeline applied to the ``"image"`` key of a data dict.

    Steps, in order:
    1. Load the file with a leading channel axis.
    2. Resample to 1.0 x 1.0 x 1.0 voxel spacing.
    3. Crop to the foreground (MONAI's default foreground selection).
    4. Pad, then center-crop, to exactly 160 x 192 x 160.
    5. Z-normalize with TorchIO, using only voxels with intensity > 0 as the mask.
    """
    def _positive_mask(volume):
        # Voxels used for the normalization statistics.
        return volume > 0

    target_shape = (160, 192, 160)
    steps = [
        LoadImaged(keys=["image"], ensure_channel_first=True),
        Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0)),
        CropForegroundd(keys=["image"], allow_smaller=True, source_key="image"),
        SpatialPadd(keys=["image"], spatial_size=target_shape),
        CenterSpatialCropd(keys=["image"], roi_size=target_shape),
        torchio.transforms.ZNormalization(masking_method=_positive_mask, keys=["image"]),
    ]
    return Compose(steps)
# Process uploaded MRI scan
def preprocess_mri(mri_path):
    """Load and preprocess a single scan into a batched tensor on ``device``.

    Args:
        mri_path: filesystem path to the NIfTI scan.

    Returns:
        The preprocessed image, collated into a batch of size 1 and moved to
        ``device``. Presumably shaped (1, 1, 160, 192, 160) for a single-channel
        scan; confirm for multi-channel inputs.
    """
    transforms = prepare_transforms()
    dataset = Dataset([{"image": mri_path}], transform=transforms)
    # DataLoader comes from monai.data; without that import this line raised NameError.
    # MONAI's loader collates MetaTensors. num_workers=0 keeps the single-sample
    # load in-process.
    loader = DataLoader(dataset, batch_size=1, num_workers=0)
    batch = next(iter(loader))
    return batch["image"].to(device)
# Predict brain age
def _correct_age_bias(raw_age, actual_age):
    """Apply the app's linear age-bias correction for subjects older than 18.

    NOTE(review): 0.062 and 2.96 look like fitted bias-correction coefficients.
    Their provenance is not visible here; confirm against the BrainAgeNeXt reference.
    """
    if actual_age > 18:
        return raw_age + actual_age * 0.062 - 2.96
    return raw_age


def predict_brain_age(mri_path, actual_age, sex):
    """Estimate brain age from an uploaded MRI and report the brain-age difference.

    Args:
        mri_path: path to the uploaded NIfTI file, or ``None`` when nothing was
            uploaded. A file-like object exposing ``.name`` (some Gradio versions
            pass one for ``gr.File``) is also accepted.
        actual_age: chronological age in years, or ``None`` when the field is empty.
        sex: the selected sex; currently not used by the prediction.

    Returns:
        A 2-tuple of strings, one per Gradio output textbox. On invalid input the
        first element holds the error message and the second is empty. The
        interface therefore always receives exactly two values.
    """
    if mri_path is None:
        return "Error: no MRI file uploaded", ""
    # Older Gradio versions hand gr.File inputs over as temp-file wrappers.
    if not isinstance(mri_path, (str, os.PathLike)):
        mri_path = getattr(mri_path, "name", mri_path)
    if not os.path.isfile(mri_path):
        return "Error: MRI file not found", ""
    if actual_age is None:
        return "Error: please enter an age", ""
    actual_age = float(actual_age)

    models = initialize_model()
    image = preprocess_mri(mri_path)

    with torch.no_grad():
        predictions = [model(image).cpu().numpy() for model in models]
    # The median across the ensemble is robust to a single outlying member.
    predicted_brain_age = float(np.median(np.stack(predictions)))

    predicted_brain_age_corrected = _correct_age_bias(predicted_brain_age, actual_age)
    brain_age_difference = predicted_brain_age_corrected - actual_age
    return (
        f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years",
        f"Brain Age Difference (BAD): {brain_age_difference:.2f} years",
    )
# Gradio UI: the input order must match predict_brain_age(mri_path, actual_age, sex),
# and the two output boxes receive its two returned strings.
_ui_inputs = [
    gr.File(label="Upload MRI (NIfTI .nii.gz)"),
    gr.Number(label="Enter Age"),
    gr.Radio(["Male", "Female"], label="Select Sex"),
]
_ui_outputs = [
    gr.Textbox(label="Predicted Brain Age"),
    gr.Textbox(label="Brain Age Difference (BAD)"),
]
demo = gr.Interface(
    fn=predict_brain_age,
    inputs=_ui_inputs,
    outputs=_ui_outputs,
    title="Brain Age Prediction with MedNeXt",
    description="Upload an MRI scan (.nii.gz), enter your age and sex, and get a brain age prediction.",
    theme="default",
)


def _launch_app():
    """Start the Gradio server, requesting a public share link."""
    demo.launch(share=True)


# Launch the Gradio app only when run as a script, not on import.
if __name__ == "__main__":
    _launch_app()