Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,038 Bytes
e148d83 cf79142 e148d83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import functools
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchio
from huggingface_hub import hf_hub_download
from monai.data import Dataset
from monai.transforms import Compose, LoadImaged, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd
from nnunet_mednext import create_mednext_encoder_v1
# Device selection: use CUDA when PyTorch can see a GPU, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Model definition
class MedNeXtEncReg(nn.Module):
    """MedNeXt encoder followed by a small MLP head that regresses one scalar (brain age).

    The attribute names ``mednextv1``, ``global_avg_pool`` and ``regression_fc``,
    and the layer order inside ``regression_fc``, define the ``state_dict`` keys
    the pretrained checkpoints are loaded against — do not rename or reorder them.
    """

    def __init__(self):
        super().__init__()
        # Single input channel, MedNeXt variant 'B', kernel size 3.
        self.mednextv1 = create_mednext_encoder_v1(
            num_input_channels=1,
            num_classes=1,
            model_id='B',
            kernel_size=3,
            deep_supervision=True,
        )
        # Pools every spatial axis down to size 1, leaving one value per channel.
        self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        # NOTE(review): 512 must equal the encoder's output channel count — assumed for
        # model 'B'; confirm against nnunet_mednext.
        self.regression_fc = nn.Sequential(
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        """Encode ``x``, average-pool the features and regress a single value per sample.

        ``x`` is presumably a (batch, 1, D, H, W) volume — TODO confirm. The result is
        squeezed, so every size-1 dimension is dropped: a batch of one yields a 0-d tensor.
        """
        pooled = self.global_avg_pool(self.mednextv1(x))
        features = pooled.flatten(start_dim=1)
        return self.regression_fc(features).squeeze()
# Download the model from Hugging Face Hub
# Cached so the ensemble is built once per process instead of on every request;
# maxsize=1 keeps at most one ensemble (five full networks) resident in memory.
@functools.lru_cache(maxsize=1)
def initialize_model(repo_id="FrancescoLR/BrainAgeNeXt", num_models=5):
    """Download the BrainAgeNeXt checkpoints and return the model ensemble in eval mode.

    Checkpoints ``BrainAge_1.pth`` … ``BrainAge_<num_models>.pth`` are fetched from
    the Hugging Face Hub repository ``repo_id``, each loaded into a fresh
    ``MedNeXtEncReg`` on ``device``.

    The result is memoized: repeated calls with the same arguments return the same
    list of already-loaded models. Callers must not mutate that list or switch the
    models back to training mode.

    Args:
        repo_id: Hub repository that holds the checkpoint files.
        num_models: Number of checkpoints to load, numbered from 1.

    Returns:
        List of ``MedNeXtEncReg`` modules on ``device``, each in eval mode.
    """
    model_paths = [
        hf_hub_download(repo_id=repo_id, filename=f"BrainAge_{i}.pth")
        for i in range(1, num_models + 1)
    ]
    models = []
    for model_path in model_paths:
        model = MedNeXtEncReg().to(device)
        # NOTE(review): torch.load unpickles the file; acceptable only because the
        # checkpoints come from a fixed, trusted repository. Prefer weights_only=True
        # once the deployed torch version is confirmed to support it.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        models.append(model)
    return models
# Define preprocessing transforms
def prepare_transforms():
    """Return the ``Compose`` pipeline that turns a scan path into a model-ready tensor.

    The pipeline operates on a dict whose ``"image"`` entry is the file path and,
    in order:

    1. loads the image with the channel axis first;
    2. resamples to a voxel spacing of 1.0 along every axis;
    3. crops to the foreground of the image itself (``allow_smaller=True``);
    4. pads up to 160 x 192 x 160;
    5. center-crops to exactly 160 x 192 x 160;
    6. z-normalizes intensities with torchio, masking on voxels > 0.
    """
    key = "image"
    target_shape = (160, 192, 160)
    steps = [
        LoadImaged(keys=[key], ensure_channel_first=True),
        Spacingd(keys=[key], pixdim=(1.0, 1.0, 1.0)),
        CropForegroundd(keys=[key], allow_smaller=True, source_key=key),
        SpatialPadd(keys=[key], spatial_size=target_shape),
        CenterSpatialCropd(keys=[key], roi_size=target_shape),
        # NOTE(review): a torchio transform inside a MONAI Compose receives the MONAI
        # sample dict and selects the entry via ``keys``; confirm it hands back the
        # dict form that the Dataset/DataLoader consumer expects.
        torchio.transforms.ZNormalization(masking_method=lambda x: x > 0, keys=[key]),
    ]
    return Compose(steps)
# Process uploaded MRI scan
def preprocess_mri(mri_path):
    """Load and preprocess one MRI file and return it as a batch of one on ``device``.

    Args:
        mri_path: Path to an image file readable by MONAI's ``LoadImaged``.

    Returns:
        The ``"image"`` entry of the collated batch (leading batch dimension of
        size 1), moved to the module-level ``device``.
    """
    # DataLoader is not imported at module level (only Dataset is), so bring it in
    # here; MONAI's DataLoader collates MONAI sample dicts by default.
    from monai.data import DataLoader

    transforms = prepare_transforms()
    dataset = Dataset([{"image": mri_path}], transform=transforms)
    dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
    batch = next(iter(dataloader))
    return batch["image"].to(device)
# Predict brain age
def predict_brain_age(mri_path, actual_age, sex):
    """Predict brain age from an uploaded MRI and report the brain-age difference.

    Args:
        mri_path: The value produced by ``gr.File``: a path string, an object exposing
            the uploaded temp-file path as ``.name`` (older Gradio versions), or
            ``None`` when nothing was uploaded.
        actual_age: Chronological age in years from ``gr.Number``; ``None`` if empty.
        sex: Selected sex from the radio input. Currently not used.

    Returns:
        A 2-tuple of strings, one per output textbox: the corrected predicted brain
        age and the brain-age difference (BAD). On invalid input the first element
        is an error message and the second is empty, so the interface always
        receives the two values it declares.
    """
    # Unwrap file-like upload objects; plain str / PathLike paths are used as-is.
    if mri_path is not None and not isinstance(mri_path, (str, os.PathLike)):
        mri_path = getattr(mri_path, "name", mri_path)
    if mri_path is None or not os.path.exists(mri_path):
        return "Error: MRI file not found", ""
    if actual_age is None:
        return "Error: please enter your age", ""

    models = initialize_model()
    image = preprocess_mri(mri_path)

    with torch.no_grad():
        predictions = [model(image).cpu().numpy() for model in models]

    # Median across the ensemble, which limits the influence of one outlying model.
    predicted_brain_age = np.median(np.stack(predictions))

    # Linear age-dependent correction, applied only when actual_age > 18.
    # NOTE(review): the 0.062 / -2.96 coefficients are hard-coded; their origin is not documented here.
    if actual_age > 18:
        predicted_brain_age_corrected = predicted_brain_age + actual_age * 0.062 - 2.96
    else:
        predicted_brain_age_corrected = predicted_brain_age
    brain_age_difference = predicted_brain_age_corrected - actual_age

    return (
        f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years",
        f"Brain Age Difference (BAD): {brain_age_difference:.2f} years",
    )
# Gradio UI: one file upload plus age and sex inputs, two text outputs.
# The order of `inputs` must match predict_brain_age's parameters
# (mri_path, actual_age, sex); `outputs` has one textbox per returned string.
iface = gr.Interface(
    fn=predict_brain_age,
    inputs=[
        gr.File(label="Upload MRI (NIfTI .nii.gz)"),
        gr.Number(label="Enter Age"),
        # NOTE(review): sex is collected, but predict_brain_age does not use it.
        gr.Radio(["Male", "Female"], label="Select Sex")
    ],
    outputs=[
        gr.Textbox(label="Predicted Brain Age"),
        gr.Textbox(label="Brain Age Difference (BAD)")
    ],
    title="Brain Age Prediction with MedNeXt",
    description="Upload an MRI scan (.nii.gz), enter your age and sex, and get a brain age prediction.",
    theme="default"
)
# Launch the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch()
|