Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,30 +1,47 @@
|
|
1 |
import gradio as gr
|
2 |
-
import subprocess
|
3 |
-
import os
|
4 |
-
import shutil
|
5 |
-
from huggingface_hub import hf_hub_download
|
6 |
import torch
|
7 |
-
import spaces # Import spaces for GPU decoration
|
8 |
import numpy as np
|
9 |
-
import
|
|
|
10 |
import torchio
|
11 |
import torch.nn as nn
|
|
|
|
|
12 |
from huggingface_hub import hf_hub_download
|
13 |
from monai.transforms import Compose, LoadImaged, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd
|
14 |
-
from monai.data import Dataset
|
15 |
from nnunet_mednext import create_mednext_encoder_v1
|
16 |
|
17 |
-
#
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
23 |
|
24 |
-
#
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
-
# Model
|
28 |
class MedNeXtEncReg(nn.Module):
|
29 |
def __init__(self):
|
30 |
super(MedNeXtEncReg, self).__init__()
|
@@ -46,21 +63,7 @@ class MedNeXtEncReg(nn.Module):
|
|
46 |
age_estimate = self.regression_fc(x)
|
47 |
return age_estimate.squeeze()
|
48 |
|
49 |
-
#
|
50 |
-
def initialize_model():
|
51 |
-
model_paths = [
|
52 |
-
hf_hub_download(repo_id="FrancescoLR/BrainAgeNeXt", filename=f"BrainAge_{i}.pth") for i in range(1, 6)
|
53 |
-
]
|
54 |
-
|
55 |
-
models = []
|
56 |
-
for model_path in model_paths:
|
57 |
-
model = MedNeXtEncReg().to(device)
|
58 |
-
model.load_state_dict(torch.load(model_path, map_location=device))
|
59 |
-
model.eval()
|
60 |
-
models.append(model)
|
61 |
-
return models
|
62 |
-
|
63 |
-
# Define preprocessing transforms
|
64 |
def prepare_transforms():
|
65 |
return Compose([
|
66 |
LoadImaged(keys=["image"], ensure_channel_first=True),
|
@@ -71,63 +74,86 @@ def prepare_transforms():
|
|
71 |
torchio.transforms.ZNormalization(masking_method=lambda x: x > 0, keys=["image"])
|
72 |
])
|
73 |
|
74 |
-
# Process
|
75 |
-
def preprocess_mri(
|
76 |
transforms = prepare_transforms()
|
77 |
-
data_dict = {"image":
|
78 |
dataset = Dataset([data_dict], transform=transforms)
|
79 |
dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
|
80 |
return next(iter(dataloader))["image"].to(device)
|
81 |
|
82 |
-
#
|
83 |
-
|
84 |
-
|
|
|
85 |
return "Error: MRI file not found"
|
86 |
|
87 |
-
# Load
|
88 |
models = initialize_model()
|
89 |
|
90 |
# Preprocess MRI
|
91 |
-
image = preprocess_mri(
|
92 |
|
93 |
-
# Run
|
94 |
predictions = []
|
95 |
with torch.no_grad():
|
96 |
for model in models:
|
97 |
pred = model(image)
|
98 |
predictions.append(pred.cpu().numpy())
|
99 |
|
100 |
-
# Compute
|
101 |
predicted_brain_age = np.median(np.stack(predictions))
|
102 |
|
103 |
-
# Apply
|
104 |
predicted_brain_age_corrected = (
|
105 |
predicted_brain_age + (actual_age * 0.062) - 2.96 if actual_age > 18 else predicted_brain_age
|
106 |
)
|
107 |
|
108 |
brain_age_difference = predicted_brain_age_corrected - actual_age
|
109 |
|
110 |
-
# Output
|
111 |
return f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years", \
|
112 |
f"Brain Age Difference (BAD): {brain_age_difference:.2f} years"
|
113 |
|
114 |
-
# Gradio
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
)
|
130 |
-
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
if __name__ == "__main__":
|
133 |
-
demo.launch(share=True)
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
2 |
import torch
|
|
|
3 |
import numpy as np
|
4 |
+
import os
|
5 |
+
import nibabel as nib
|
6 |
import torchio
|
7 |
import torch.nn as nn
|
8 |
+
import subprocess
|
9 |
+
from scipy.ndimage.measurements import center_of_mass
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
from monai.transforms import Compose, LoadImaged, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd
|
12 |
+
from monai.data import Dataset, DataLoader
|
13 |
from nnunet_mednext import create_mednext_encoder_v1
|
14 |
|
15 |
+
# Model and data directory setup
|
16 |
+
MODEL_DIR = "/root/.cache/huggingface/hub"
|
17 |
+
DATASET_DIR = os.path.join(MODEL_DIR, "BrainAgeNeXt")
|
18 |
+
REPO_ID = "FrancescoLR/BrainAgeNeXt"
|
19 |
+
|
20 |
+
# Ensure model directory exists
|
21 |
+
os.makedirs(MODEL_DIR, exist_ok=True)
|
22 |
|
23 |
+
# ๐น Function to Download Model Weights from Hugging Face
|
24 |
+
def download_model():
|
25 |
+
if not os.path.exists(DATASET_DIR):
|
26 |
+
os.makedirs(DATASET_DIR, exist_ok=True)
|
27 |
+
print("Downloading BrainAgeNeXt model weights...")
|
28 |
+
for i in range(1, 6):
|
29 |
+
hf_hub_download(repo_id=REPO_ID, filename=f"BrainAge_{i}.pth", cache_dir=MODEL_DIR)
|
30 |
+
print("โ
BrainAgeNeXt model downloaded successfully.")
|
31 |
+
|
32 |
+
# ๐น Function to Load Model
|
33 |
+
def initialize_model():
|
34 |
+
model_paths = [hf_hub_download(repo_id=REPO_ID, filename=f"BrainAge_{i}.pth", cache_dir=MODEL_DIR) for i in range(1, 6)]
|
35 |
+
|
36 |
+
models = []
|
37 |
+
for model_path in model_paths:
|
38 |
+
model = MedNeXtEncReg().to(device)
|
39 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
40 |
+
model.eval()
|
41 |
+
models.append(model)
|
42 |
+
return models
|
43 |
|
44 |
+
# ๐น Define Model
|
45 |
class MedNeXtEncReg(nn.Module):
|
46 |
def __init__(self):
|
47 |
super(MedNeXtEncReg, self).__init__()
|
|
|
63 |
age_estimate = self.regression_fc(x)
|
64 |
return age_estimate.squeeze()
|
65 |
|
66 |
+
# ๐น Preprocessing Pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
def prepare_transforms():
|
68 |
return Compose([
|
69 |
LoadImaged(keys=["image"], ensure_channel_first=True),
|
|
|
74 |
torchio.transforms.ZNormalization(masking_method=lambda x: x > 0, keys=["image"])
|
75 |
])
|
76 |
|
77 |
+
# ๐น Process MRI File
|
78 |
+
def preprocess_mri(nifti_path):
|
79 |
transforms = prepare_transforms()
|
80 |
+
data_dict = {"image": nifti_path}
|
81 |
dataset = Dataset([data_dict], transform=transforms)
|
82 |
dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
|
83 |
return next(iter(dataloader))["image"].to(device)
|
84 |
|
85 |
+
# ๐น Run Brain Age Prediction (Decorated for GPU Execution)
|
86 |
+
@spaces.GPU(duration=90)
|
87 |
+
def predict_brain_age(nifti_file, actual_age, sex):
|
88 |
+
if not os.path.exists(nifti_file.name):
|
89 |
return "Error: MRI file not found"
|
90 |
|
91 |
+
# Load Model
|
92 |
models = initialize_model()
|
93 |
|
94 |
# Preprocess MRI
|
95 |
+
image = preprocess_mri(nifti_file.name)
|
96 |
|
97 |
+
# Run Predictions
|
98 |
predictions = []
|
99 |
with torch.no_grad():
|
100 |
for model in models:
|
101 |
pred = model(image)
|
102 |
predictions.append(pred.cpu().numpy())
|
103 |
|
104 |
+
# Compute Median Brain Age Prediction
|
105 |
predicted_brain_age = np.median(np.stack(predictions))
|
106 |
|
107 |
+
# Apply Correction Based on Actual Age
|
108 |
predicted_brain_age_corrected = (
|
109 |
predicted_brain_age + (actual_age * 0.062) - 2.96 if actual_age > 18 else predicted_brain_age
|
110 |
)
|
111 |
|
112 |
brain_age_difference = predicted_brain_age_corrected - actual_age
|
113 |
|
114 |
+
# Output Results
|
115 |
return f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years", \
|
116 |
f"Brain Age Difference (BAD): {brain_age_difference:.2f} years"
|
117 |
|
118 |
+
# ๐น Gradio Interface Setup
|
119 |
+
with gr.Blocks() as demo:
|
120 |
+
gr.Markdown("""
|
121 |
+
# ๐ง Brain Age Prediction with MedNeXt
|
122 |
+
Upload an MRI scan (.nii.gz), enter your age and sex, and get a brain age prediction.
|
123 |
+
""")
|
124 |
+
|
125 |
+
with gr.Row():
|
126 |
+
with gr.Column(scale=1):
|
127 |
+
mri_input = gr.File(label="Upload MRI (NIfTI .nii.gz)")
|
128 |
+
age_input = gr.Number(label="Enter Age", value=30)
|
129 |
+
sex_input = gr.Radio(["Male", "Female"], label="Select Sex")
|
130 |
+
submit_button = gr.Button("Predict")
|
131 |
+
|
132 |
+
with gr.Column(scale=2):
|
133 |
+
brain_age_output = gr.Textbox(label="Predicted Brain Age")
|
134 |
+
bad_output = gr.Textbox(label="Brain Age Difference (BAD)")
|
135 |
+
|
136 |
+
submit_button.click(
|
137 |
+
fn=predict_brain_age,
|
138 |
+
inputs=[mri_input, age_input, sex_input],
|
139 |
+
outputs=[brain_age_output, bad_output]
|
140 |
+
)
|
141 |
+
|
142 |
+
gr.Markdown("""
|
143 |
+
**Disclaimer:** This is a research tool and is not intended for clinical use.
|
144 |
+
""")
|
145 |
+
# ๐น Debugging GPU Environment
|
146 |
+
if torch.cuda.is_available():
|
147 |
+
print(f"GPU available: {torch.cuda.get_device_name(0)}")
|
148 |
+
device = torch.device("cuda")
|
149 |
+
else:
|
150 |
+
print("No GPU detected. Falling back to CPU.")
|
151 |
+
os.system("nvidia-smi")
|
152 |
+
device = torch.device("cpu")
|
153 |
+
|
154 |
+
# ๐น Download Model Weights
|
155 |
+
download_model()
|
156 |
+
|
157 |
+
# ๐น Run Gradio App
|
158 |
if __name__ == "__main__":
|
159 |
+
demo.launch(share=True)
|