Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,255 Bytes
e148d83 cf79142 e148d83 989a90a e148d83 989a90a e148d83 989a90a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import gradio as gr
import torch
import numpy as np
import pandas as pd
import os
import torchio
import torch.nn as nn
from huggingface_hub import hf_hub_download
from monai.transforms import Compose, LoadImaged, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd
from monai.data import Dataset
from nnunet_mednext import create_mednext_encoder_v1
# Debugging GPU environment: report what torch can see, then dump nvidia-smi output.
print(
    f"GPU is available: {torch.cuda.get_device_name(0)}"
    if torch.cuda.is_available()
    else "No GPU available. Falling back to CPU."
)
os.system("nvidia-smi")

# Device selection: use CUDA whenever torch reports it, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Model definition
class MedNeXtEncReg(nn.Module):
    """MedNeXt encoder followed by a small MLP head producing one age estimate.

    The attribute names (``mednextv1``, ``global_avg_pool``, ``regression_fc``)
    and the layer order inside ``regression_fc`` determine the ``state_dict``
    keys used when checkpoints are loaded, so they must not be renamed or
    reordered.
    """

    def __init__(self):
        super().__init__()
        # Encoder backbone: single input channel, MedNeXt variant "B", kernel_size=3.
        self.mednextv1 = create_mednext_encoder_v1(
            num_input_channels=1,
            num_classes=1,
            model_id="B",
            kernel_size=3,
            deep_supervision=True,
        )
        # Collapse the three spatial dims of the encoder output to 1x1x1.
        self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        # Regression head 512 -> 64 -> 1 (the encoder must therefore emit 512
        # channels). Dropout(0.0) is a no-op, but it keeps the second Linear at
        # index 3 so the state_dict keys line up with the checkpoints.
        self.regression_fc = nn.Sequential(
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        """Return the regressed value(s) for the batch ``x``.

        ``squeeze()`` removes every size-1 dimension, so a batch of one yields
        a 0-d tensor and a batch of N yields shape ``(N,)``.
        """
        features = self.mednextv1(x)
        pooled = torch.flatten(self.global_avg_pool(features), start_dim=1)
        return self.regression_fc(pooled).squeeze()
# Download the model from Hugging Face Hub
# Ensemble cache: filled on the first initialize_model() call so later calls
# reuse the loaded models instead of re-reading every checkpoint.
_cached_models = None


def initialize_model():
    """Return the BrainAgeNeXt ensemble, loading it only on the first call.

    The checkpoints ``BrainAge_1.pth`` .. ``BrainAge_5.pth`` are fetched from the
    ``FrancescoLR/BrainAgeNeXt`` Hub repo with ``hf_hub_download``. Each one is
    loaded into a fresh ``MedNeXtEncReg`` on ``device``, which is then switched to
    eval mode. The loaded models are kept in a module-level cache. Previously
    every call (i.e. every prediction request) rebuilt all five models.

    Returns:
        A new list containing the cached eval-mode models (the model objects
        themselves are shared between calls).
    """
    global _cached_models
    if _cached_models is None:
        model_paths = [
            hf_hub_download(repo_id="FrancescoLR/BrainAgeNeXt", filename=f"BrainAge_{i}.pth")
            for i in range(1, 6)
        ]
        models = []
        for model_path in model_paths:
            model = MedNeXtEncReg().to(device)
            # NOTE(review): torch.load unpickles the file. That is tolerable for this
            # fixed Hub repo, but consider weights_only=True if the installed torch
            # version supports it and the checkpoints are plain state_dicts.
            model.load_state_dict(torch.load(model_path, map_location=device))
            model.eval()
            models.append(model)
        _cached_models = models
    # Return a copy so a caller mutating the list cannot corrupt the cache.
    return list(_cached_models)
# Define preprocessing transforms
def prepare_transforms():
    """Build the preprocessing pipeline for a sample dict with an ``"image"`` entry.

    The steps, in order:

    1. Load the file (channel-first).
    2. Resample to a spacing of 1.0 on each axis (presumably mm; confirm for the inputs).
    3. Crop to the foreground.
    4. Pad, then center-crop, to exactly 160 x 192 x 160.
    5. Z-normalize intensities, using voxels with value > 0 as the mask.
    """
    image_keys = ["image"]
    target_size = (160, 192, 160)
    steps = [
        LoadImaged(keys=image_keys, ensure_channel_first=True),
        Spacingd(keys=image_keys, pixdim=(1.0, 1.0, 1.0)),
        CropForegroundd(keys=image_keys, allow_smaller=True, source_key="image"),
        # Pad first so the center crop always produces exactly target_size.
        SpatialPadd(keys=image_keys, spatial_size=target_size),
        CenterSpatialCropd(keys=image_keys, roi_size=target_size),
        torchio.transforms.ZNormalization(
            masking_method=lambda tensor: tensor > 0, keys=image_keys
        ),
    ]
    return Compose(steps)
# Process uploaded MRI scan
def preprocess_mri(mri_path):
    """Load and preprocess one MRI file into a batched tensor on ``device``.

    Args:
        mri_path: Filesystem path of the image to load.

    Returns:
        The preprocessed ``"image"`` tensor with a leading batch dimension of 1,
        moved to ``device``.
    """
    # DataLoader was used here but never imported at module level (only
    # monai.data.Dataset is), so every call raised NameError. MONAI's DataLoader
    # is imported to match the MONAI Dataset and transforms used below.
    from monai.data import DataLoader

    transforms = prepare_transforms()
    dataset = Dataset([{"image": mri_path}], transform=transforms)
    # batch_size=1 adds the batch axis; num_workers=0 keeps loading in-process.
    dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
    return next(iter(dataloader))["image"].to(device)
# Predict brain age
def predict_brain_age(mri_path, actual_age, sex):
    """Predict brain age for one MRI scan and compare it with the entered age.

    Args:
        mri_path: Path of the uploaded MRI file, or ``None`` if nothing was
            uploaded. A file-like object exposing ``.name`` is also accepted.
            NOTE(review): some Gradio versions pass a tempfile wrapper from
            ``gr.File``; confirm which form the deployed version uses.
        actual_age: Age entered by the user, in years (``None`` if left empty).
        sex: Selected sex. Currently not used by the prediction; kept so the
            signature matches the interface inputs.

    Returns:
        A 2-tuple of strings ``(predicted_age_text, difference_text)``, one per
        output textbox. On invalid input the first element is an error message
        and the second is empty.
    """
    # Normalise file-like uploads to a path; str / PathLike values pass through.
    if mri_path is not None and not isinstance(mri_path, (str, os.PathLike)):
        mri_path = getattr(mri_path, "name", mri_path)
    # The interface declares two outputs, so every return must be a 2-tuple.
    # A bare string here would make Gradio fail when unpacking the result.
    if mri_path is None or not os.path.isfile(mri_path):
        return "Error: MRI file not found", ""
    if actual_age is None:
        return "Error: please enter an age", ""

    models = initialize_model()
    image = preprocess_mri(mri_path)

    # Run every ensemble member; inference needs no gradients.
    with torch.no_grad():
        predictions = [model(image).cpu().numpy() for model in models]

    # Ensemble estimate = median over the members' predictions.
    predicted_brain_age = np.median(np.stack(predictions))

    # Linear correction in the entered age, applied only when actual_age > 18.
    # NOTE(review): presumably a fitted age-bias correction; confirm the source
    # of the 0.062 / 2.96 constants.
    if actual_age > 18:
        predicted_brain_age_corrected = predicted_brain_age + actual_age * 0.062 - 2.96
    else:
        predicted_brain_age_corrected = predicted_brain_age
    brain_age_difference = predicted_brain_age_corrected - actual_age

    return (
        f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years",
        f"Brain Age Difference (BAD): {brain_age_difference:.2f} years",
    )
# Gradio UI
# Input components, passed positionally to predict_brain_age(mri_path, actual_age, sex).
_ui_inputs = [
    gr.File(label="Upload MRI (NIfTI .nii.gz)"),
    gr.Number(label="Enter Age"),
    gr.Radio(["Male", "Female"], label="Select Sex"),
]
# One textbox per string returned by predict_brain_age.
_ui_outputs = [
    gr.Textbox(label="Predicted Brain Age"),
    gr.Textbox(label="Brain Age Difference (BAD)"),
]

demo = gr.Interface(
    fn=predict_brain_age,
    inputs=_ui_inputs,
    outputs=_ui_outputs,
    title="Brain Age Prediction with MedNeXt",
    description="Upload an MRI scan (.nii.gz), enter your age and sex, and get a brain age prediction.",
    theme="default",
)

# Launch the Gradio app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch(share=True)
|