FrancescoLR commited on
Commit
e148d83
·
verified ·
1 Parent(s): 4f6004d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ import os
6
+ import torchio
7
+ import torch.nn as nn
8
+ from huggingface_hub import hf_hub_download
9
+ from monai.transforms import Compose, LoadImaged, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd
10
+ from monai.data import Dataset
11
+ from nnunet_mednext import create_mednext_encoder_v1
12
+
13
+ # Device selection
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ # Model definition
17
+ class MedNeXtEncReg(nn.Module):
18
+ def __init__(self):
19
+ super(MedNeXtEncReg, self).__init__()
20
+ self.mednextv1 = create_mednext_encoder_v1(
21
+ num_input_channels=1, num_classes=1, model_id='B', kernel_size=3, deep_supervision=True
22
+ )
23
+ self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
24
+ self.regression_fc = nn.Sequential(
25
+ nn.Linear(512, 64),
26
+ nn.ReLU(),
27
+ nn.Dropout(0.0),
28
+ nn.Linear(64, 1)
29
+ )
30
+
31
+ def forward(self, x):
32
+ x = self.mednextv1(x)
33
+ x = self.global_avg_pool(x)
34
+ x = torch.flatten(x, start_dim=1)
35
+ age_estimate = self.regression_fc(x)
36
+ return age_estimate.squeeze()
37
+
38
+ # Download the model from Hugging Face Hub
39
+ def initialize_model():
40
+ model_paths = [
41
+ hf_hub_download(repo_id="FrancescoLR/BrainAgeNeXt", filename=f"BrainAge_{i}.pth") for i in range(1, 6)
42
+ ]
43
+
44
+ models = []
45
+ for model_path in model_paths:
46
+ model = MedNeXtEncReg().to(device)
47
+ model.load_state_dict(torch.load(model_path, map_location=device))
48
+ model.eval()
49
+ models.append(model)
50
+ return models
51
+
52
+ # Define preprocessing transforms
53
+ def prepare_transforms():
54
+ return Compose([
55
+ LoadImaged(keys=["image"], ensure_channel_first=True),
56
+ Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0)),
57
+ CropForegroundd(keys=["image"], allow_smaller=True, source_key="image"),
58
+ SpatialPadd(keys=["image"], spatial_size=(160, 192, 160)),
59
+ CenterSpatialCropd(keys=["image"], roi_size=(160, 192, 160)),
60
+ torchio.transforms.ZNormalization(masking_method=lambda x: x > 0, keys=["image"])
61
+ ])
62
+
63
+ # Process uploaded MRI scan
64
+ def preprocess_mri(mri_path):
65
+ transforms = prepare_transforms()
66
+ data_dict = {"image": mri_path}
67
+ dataset = Dataset([data_dict], transform=transforms)
68
+ dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
69
+ return next(iter(dataloader))["image"].to(device)
70
+
71
+ # Predict brain age
72
+ def predict_brain_age(mri_path, actual_age, sex):
73
+ if not os.path.exists(mri_path):
74
+ return "Error: MRI file not found"
75
+
76
+ # Load the model
77
+ models = initialize_model()
78
+
79
+ # Preprocess MRI
80
+ image = preprocess_mri(mri_path)
81
+
82
+ # Run predictions
83
+ predictions = []
84
+ with torch.no_grad():
85
+ for model in models:
86
+ pred = model(image)
87
+ predictions.append(pred.cpu().numpy())
88
+
89
+ # Compute median brain age prediction
90
+ predicted_brain_age = np.median(np.stack(predictions))
91
+
92
+ # Apply correction based on actual age
93
+ predicted_brain_age_corrected = (
94
+ predicted_brain_age + (actual_age * 0.062) - 2.96 if actual_age > 18 else predicted_brain_age
95
+ )
96
+
97
+ brain_age_difference = predicted_brain_age_corrected - actual_age
98
+
99
+ # Output results
100
+ return f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years", \
101
+ f"Brain Age Difference (BAD): {brain_age_difference:.2f} years"
102
+
103
+ # Gradio UI
104
+ iface = gr.Interface(
105
+ fn=predict_brain_age,
106
+ inputs=[
107
+ gr.File(label="Upload MRI (NIfTI .nii.gz)"),
108
+ gr.Number(label="Enter Age"),
109
+ gr.Radio(["Male", "Female"], label="Select Sex")
110
+ ],
111
+ outputs=[
112
+ gr.Textbox(label="Predicted Brain Age"),
113
+ gr.Textbox(label="Brain Age Difference (BAD)")
114
+ ],
115
+ title="Brain Age Prediction with MedNeXt",
116
+ description="Upload an MRI scan (.nii.gz), enter your age and sex, and get a brain age prediction.",
117
+ theme="default"
118
+ )
119
+
120
+ # Launch the Gradio app
121
+ if __name__ == "__main__":
122
+ iface.launch()