File size: 4,319 Bytes
5e7b2d7
 
 
 
 
 
 
 
 
 
 
 
 
e148d83
 
 
 
 
 
 
 
 
 
81ef1c7
e148d83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import functools
import subprocess
import sys

import pkg_resources

# Startup diagnostics: dump the active environment into the logs.
# NOTE(review): pkg_resources is deprecated upstream; importlib.metadata is the
# stdlib replacement if this is ever reworked.
installed_packages = [distribution.key for distribution in pkg_resources.working_set]
print("Installed packages:", installed_packages)

# Same information in pip's tabular format, produced by the interpreter that is
# running this script so it reflects the environment actually in use.
pip_list_command = [sys.executable, "-m", "pip", "list"]
subprocess.run(pip_list_command)



import gradio as gr
import torch
import numpy as np
import pandas as pd
import os
import torchio
import torch.nn as nn
from huggingface_hub import hf_hub_download
from monai.transforms import Compose, LoadImaged, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd
from monai.data import Dataset
from mednextv1 import create_mednext_encoder_v1

# Use CUDA when PyTorch can see a GPU; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model definition
class MedNeXtEncReg(nn.Module):
    """MedNeXt encoder followed by a small MLP head that regresses one value (the age estimate).

    The attribute names (``mednextv1``, ``global_avg_pool``, ``regression_fc``)
    and the layer order inside ``regression_fc`` must not change: they are the
    ``state_dict`` keys that the released checkpoints are loaded into.
    """

    def __init__(self):
        super().__init__()
        # Single input channel; MedNeXt encoder with model_id 'B', kernel_size 3.
        self.mednextv1 = create_mednext_encoder_v1(
            num_input_channels=1,
            num_classes=1,
            model_id='B',
            kernel_size=3,
            deep_supervision=True,
        )
        # Average the encoder feature map down to a single voxel per channel.
        self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        # 512 -> 64 -> 1 head (assumes the encoder emits 512 channels).
        # Dropout(0.0) is a no-op, but it occupies index 2 of the Sequential;
        # removing it would shift the checkpoint keys of the last Linear.
        hidden_width = 64
        self.regression_fc = nn.Sequential(
            nn.Linear(512, hidden_width),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(hidden_width, 1),
        )

    def forward(self, x):
        """Return the regressed value(s) for the input batch ``x``.

        Assumes ``x`` is a 5D volume batch (batch, 1, D, H, W) — TODO confirm
        against the encoder. The output is ``squeeze()``d, so a batch of one
        yields a 0-dim tensor.
        """
        feature_map = self.global_avg_pool(self.mednextv1(x))
        pooled = feature_map.flatten(start_dim=1)
        return self.regression_fc(pooled).squeeze()

# Download the ensemble checkpoints from the Hugging Face Hub and build the models.
@functools.lru_cache(maxsize=None)
def _load_ensemble():
    """Download the five BrainAgeNeXt checkpoints and build one eval-mode model per file.

    Wrapped in ``lru_cache`` so this happens once per process instead of on
    every prediction. If it raises (e.g. a failed download), nothing is cached
    and the next call retries.
    """
    model_paths = [
        hf_hub_download(repo_id="FrancescoLR/BrainAgeNeXt", filename=f"BrainAge_{i}.pth")
        for i in range(1, 6)
    ]
    models = []
    for model_path in model_paths:
        model = MedNeXtEncReg().to(device)
        # NOTE(review): torch.load unpickles the file. The source is a fixed Hub
        # repo, but consider weights_only=True if the installed torch supports it.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        models.append(model)
    # A tuple, so the cached value itself cannot be mutated by callers.
    return tuple(models)


def initialize_model():
    """Return the BrainAgeNeXt ensemble: five models in eval mode on ``device``.

    The models are loaded on the first call and reused afterwards. Each call
    returns a new list, but the model objects inside it are shared between calls.
    """
    return list(_load_ensemble())

# Define preprocessing transforms
def prepare_transforms():
    """Build the dictionary-based preprocessing pipeline for an uploaded scan.

    The pipeline runs these steps in order on ``data["image"]``:
    1. Load the file with a leading channel axis.
    2. Resample to (1.0, 1.0, 1.0) pixdim.
    3. Crop to the foreground, with MONAI's default foreground selection.
    4. Pad, then center-crop, to exactly 160 x 192 x 160.
    5. Z-normalize intensities using only voxels greater than zero.
    """
    target_shape = (160, 192, 160)
    image_keys = ["image"]

    def _positive_voxels(volume):
        # ZNormalization computes its statistics only where this mask is True.
        return volume > 0

    pipeline = [
        LoadImaged(keys=image_keys, ensure_channel_first=True),
        Spacingd(keys=image_keys, pixdim=(1.0, 1.0, 1.0)),
        CropForegroundd(keys=image_keys, allow_smaller=True, source_key="image"),
        SpatialPadd(keys=image_keys, spatial_size=target_shape),
        CenterSpatialCropd(keys=image_keys, roi_size=target_shape),
        torchio.transforms.ZNormalization(masking_method=_positive_voxels, keys=image_keys),
    ]
    return Compose(pipeline)

# Process uploaded MRI scan
def preprocess_mri(mri_path):
    """Load and preprocess one scan and return it as a batch of one on ``device``.

    Args:
        mri_path: Path of the image file, read by MONAI's ``LoadImaged``.

    Returns:
        The preprocessed image as a tensor with a leading batch axis of size 1,
        moved to ``device``.
    """
    transforms = prepare_transforms()
    dataset = Dataset([{"image": mri_path}], transform=transforms)
    # Only one sample is processed, so index the dataset (which applies the
    # transforms) and add the batch axis directly. This avoids going through a
    # DataLoader, which this module never imports.
    sample = dataset[0]["image"]
    return torch.as_tensor(sample).unsqueeze(0).to(device)

# Predict brain age
def predict_brain_age(mri_path, actual_age, sex):
    """Predict brain age from an uploaded scan and compare it with chronological age.

    Args:
        mri_path: The uploaded file as delivered by ``gr.File``. This is either a
            path string or a file-like object exposing its path as ``.name``.
            It is ``None`` when nothing was uploaded.
        actual_age: Chronological age in years (from ``gr.Number``); may be ``None``.
        sex: Selected sex from the UI. Currently not used by the prediction.

    Returns:
        A 2-tuple of strings, one per output textbox: the predicted (corrected)
        brain age and the brain age difference. On invalid input, the first
        element is an error message and the second is empty, so the number of
        outputs always matches the interface.
    """
    # Depending on the Gradio version, gr.File passes a path string or a
    # tempfile-like wrapper; normalise the wrapper to its path.
    if mri_path is not None and not isinstance(mri_path, (str, os.PathLike)):
        mri_path = getattr(mri_path, "name", mri_path)
    if mri_path is None or not os.path.exists(mri_path):
        return "Error: MRI file not found", ""
    if actual_age is None:
        return "Error: please enter an age", ""

    models = initialize_model()
    image = preprocess_mri(mri_path)

    # One prediction per ensemble member.
    with torch.no_grad():
        predictions = [model(image).cpu().numpy() for model in models]

    # Aggregate the ensemble with the median.
    predicted_brain_age = np.median(np.stack(predictions))

    # Age-dependent correction, applied only above 18 years.
    # NOTE(review): 0.062 and 2.96 look like fitted bias-correction
    # coefficients — confirm their source before changing them.
    if actual_age > 18:
        predicted_brain_age_corrected = predicted_brain_age + actual_age * 0.062 - 2.96
    else:
        predicted_brain_age_corrected = predicted_brain_age

    brain_age_difference = predicted_brain_age_corrected - actual_age

    return (
        f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years",
        f"Brain Age Difference (BAD): {brain_age_difference:.2f} years",
    )

# Gradio UI components, created up front and wired into the Interface below.
mri_input = gr.File(label="Upload MRI (NIfTI .nii.gz)")
age_input = gr.Number(label="Enter Age")
sex_input = gr.Radio(["Male", "Female"], label="Select Sex")

brain_age_output = gr.Textbox(label="Predicted Brain Age")
bad_output = gr.Textbox(label="Brain Age Difference (BAD)")

# Input order must match predict_brain_age(mri_path, actual_age, sex); the
# output textboxes receive its return values in order.
iface = gr.Interface(
    fn=predict_brain_age,
    inputs=[mri_input, age_input, sex_input],
    outputs=[brain_age_output, bad_output],
    title="Brain Age Prediction with MedNeXt",
    description="Upload an MRI scan (.nii.gz), enter your age and sex, and get a brain age prediction.",
    theme="default",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()