FrancescoLR commited on
Commit
73bc6a0
ยท
verified ยท
1 Parent(s): d509cdf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -62
app.py CHANGED
@@ -1,30 +1,47 @@
1
  import gradio as gr
2
- import subprocess
3
- import os
4
- import shutil
5
- from huggingface_hub import hf_hub_download
6
  import torch
7
- import spaces # Import spaces for GPU decoration
8
  import numpy as np
9
- import pandas as pd
 
10
  import torchio
11
  import torch.nn as nn
 
 
12
  from huggingface_hub import hf_hub_download
13
  from monai.transforms import Compose, LoadImaged, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd
14
- from monai.data import Dataset
15
  from nnunet_mednext import create_mednext_encoder_v1
16
 
17
- # Debugging GPU environment
18
- if torch.cuda.is_available():
19
- print(f"GPU is available: {torch.cuda.get_device_name(0)}")
20
- else:
21
- print("No GPU available. Falling back to CPU.")
22
- os.system("nvidia-smi")
 
23
 
24
- # Device selection
25
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- # Model definition
28
  class MedNeXtEncReg(nn.Module):
29
  def __init__(self):
30
  super(MedNeXtEncReg, self).__init__()
@@ -46,21 +63,7 @@ class MedNeXtEncReg(nn.Module):
46
  age_estimate = self.regression_fc(x)
47
  return age_estimate.squeeze()
48
 
49
- # Download the model from Hugging Face Hub
50
- def initialize_model():
51
- model_paths = [
52
- hf_hub_download(repo_id="FrancescoLR/BrainAgeNeXt", filename=f"BrainAge_{i}.pth") for i in range(1, 6)
53
- ]
54
-
55
- models = []
56
- for model_path in model_paths:
57
- model = MedNeXtEncReg().to(device)
58
- model.load_state_dict(torch.load(model_path, map_location=device))
59
- model.eval()
60
- models.append(model)
61
- return models
62
-
63
- # Define preprocessing transforms
64
  def prepare_transforms():
65
  return Compose([
66
  LoadImaged(keys=["image"], ensure_channel_first=True),
@@ -71,63 +74,86 @@ def prepare_transforms():
71
  torchio.transforms.ZNormalization(masking_method=lambda x: x > 0, keys=["image"])
72
  ])
73
 
74
- # Process uploaded MRI scan
75
- def preprocess_mri(mri_path):
76
  transforms = prepare_transforms()
77
- data_dict = {"image": mri_path}
78
  dataset = Dataset([data_dict], transform=transforms)
79
  dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
80
  return next(iter(dataloader))["image"].to(device)
81
 
82
- # Predict brain age
83
- def predict_brain_age(mri_path, actual_age, sex):
84
- if not os.path.exists(mri_path):
 
85
  return "Error: MRI file not found"
86
 
87
- # Load the model
88
  models = initialize_model()
89
 
90
  # Preprocess MRI
91
- image = preprocess_mri(mri_path)
92
 
93
- # Run predictions
94
  predictions = []
95
  with torch.no_grad():
96
  for model in models:
97
  pred = model(image)
98
  predictions.append(pred.cpu().numpy())
99
 
100
- # Compute median brain age prediction
101
  predicted_brain_age = np.median(np.stack(predictions))
102
 
103
- # Apply correction based on actual age
104
  predicted_brain_age_corrected = (
105
  predicted_brain_age + (actual_age * 0.062) - 2.96 if actual_age > 18 else predicted_brain_age
106
  )
107
 
108
  brain_age_difference = predicted_brain_age_corrected - actual_age
109
 
110
- # Output results
111
  return f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years", \
112
  f"Brain Age Difference (BAD): {brain_age_difference:.2f} years"
113
 
114
- # Gradio UI
115
- demo = gr.Interface(
116
- fn=predict_brain_age,
117
- inputs=[
118
- gr.File(label="Upload MRI (NIfTI .nii.gz)"),
119
- gr.Number(label="Enter Age"),
120
- gr.Radio(["Male", "Female"], label="Select Sex")
121
- ],
122
- outputs=[
123
- gr.Textbox(label="Predicted Brain Age"),
124
- gr.Textbox(label="Brain Age Difference (BAD)")
125
- ],
126
- title="Brain Age Prediction with MedNeXt",
127
- description="Upload an MRI scan (.nii.gz), enter your age and sex, and get a brain age prediction.",
128
- theme="default"
129
- )
130
-
131
- # Launch the Gradio app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  if __name__ == "__main__":
133
- demo.launch(share=True)
 
1
  import gradio as gr
 
 
 
 
2
  import torch
 
3
  import numpy as np
4
+ import os
5
+ import nibabel as nib
6
  import torchio
7
  import torch.nn as nn
8
+ import subprocess
9
+ from scipy.ndimage.measurements import center_of_mass
10
  from huggingface_hub import hf_hub_download
11
  from monai.transforms import Compose, LoadImaged, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd
12
+ from monai.data import Dataset, DataLoader
13
  from nnunet_mednext import create_mednext_encoder_v1
14
 
15
+ # Model and data directory setup
16
+ MODEL_DIR = "/root/.cache/huggingface/hub"
17
+ DATASET_DIR = os.path.join(MODEL_DIR, "BrainAgeNeXt")
18
+ REPO_ID = "FrancescoLR/BrainAgeNeXt"
19
+
20
+ # Ensure model directory exists
21
+ os.makedirs(MODEL_DIR, exist_ok=True)
22
 
23
+ # ๐Ÿ”น Function to Download Model Weights from Hugging Face
24
+ def download_model():
25
+ if not os.path.exists(DATASET_DIR):
26
+ os.makedirs(DATASET_DIR, exist_ok=True)
27
+ print("Downloading BrainAgeNeXt model weights...")
28
+ for i in range(1, 6):
29
+ hf_hub_download(repo_id=REPO_ID, filename=f"BrainAge_{i}.pth", cache_dir=MODEL_DIR)
30
+ print("โœ… BrainAgeNeXt model downloaded successfully.")
31
+
32
+ # ๐Ÿ”น Function to Load Model
33
+ def initialize_model():
34
+ model_paths = [hf_hub_download(repo_id=REPO_ID, filename=f"BrainAge_{i}.pth", cache_dir=MODEL_DIR) for i in range(1, 6)]
35
+
36
+ models = []
37
+ for model_path in model_paths:
38
+ model = MedNeXtEncReg().to(device)
39
+ model.load_state_dict(torch.load(model_path, map_location=device))
40
+ model.eval()
41
+ models.append(model)
42
+ return models
43
 
44
+ # ๐Ÿ”น Define Model
45
  class MedNeXtEncReg(nn.Module):
46
  def __init__(self):
47
  super(MedNeXtEncReg, self).__init__()
 
63
  age_estimate = self.regression_fc(x)
64
  return age_estimate.squeeze()
65
 
66
+ # ๐Ÿ”น Preprocessing Pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def prepare_transforms():
68
  return Compose([
69
  LoadImaged(keys=["image"], ensure_channel_first=True),
 
74
  torchio.transforms.ZNormalization(masking_method=lambda x: x > 0, keys=["image"])
75
  ])
76
 
77
+ # ๐Ÿ”น Process MRI File
78
+ def preprocess_mri(nifti_path):
79
  transforms = prepare_transforms()
80
+ data_dict = {"image": nifti_path}
81
  dataset = Dataset([data_dict], transform=transforms)
82
  dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
83
  return next(iter(dataloader))["image"].to(device)
84
 
85
+ # ๐Ÿ”น Run Brain Age Prediction (Decorated for GPU Execution)
86
+ @spaces.GPU(duration=90)
87
+ def predict_brain_age(nifti_file, actual_age, sex):
88
+ if not os.path.exists(nifti_file.name):
89
  return "Error: MRI file not found"
90
 
91
+ # Load Model
92
  models = initialize_model()
93
 
94
  # Preprocess MRI
95
+ image = preprocess_mri(nifti_file.name)
96
 
97
+ # Run Predictions
98
  predictions = []
99
  with torch.no_grad():
100
  for model in models:
101
  pred = model(image)
102
  predictions.append(pred.cpu().numpy())
103
 
104
+ # Compute Median Brain Age Prediction
105
  predicted_brain_age = np.median(np.stack(predictions))
106
 
107
+ # Apply Correction Based on Actual Age
108
  predicted_brain_age_corrected = (
109
  predicted_brain_age + (actual_age * 0.062) - 2.96 if actual_age > 18 else predicted_brain_age
110
  )
111
 
112
  brain_age_difference = predicted_brain_age_corrected - actual_age
113
 
114
+ # Output Results
115
  return f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years", \
116
  f"Brain Age Difference (BAD): {brain_age_difference:.2f} years"
117
 
118
+ # ๐Ÿ”น Gradio Interface Setup
119
+ with gr.Blocks() as demo:
120
+ gr.Markdown("""
121
+ # ๐Ÿง  Brain Age Prediction with MedNeXt
122
+ Upload an MRI scan (.nii.gz), enter your age and sex, and get a brain age prediction.
123
+ """)
124
+
125
+ with gr.Row():
126
+ with gr.Column(scale=1):
127
+ mri_input = gr.File(label="Upload MRI (NIfTI .nii.gz)")
128
+ age_input = gr.Number(label="Enter Age", value=30)
129
+ sex_input = gr.Radio(["Male", "Female"], label="Select Sex")
130
+ submit_button = gr.Button("Predict")
131
+
132
+ with gr.Column(scale=2):
133
+ brain_age_output = gr.Textbox(label="Predicted Brain Age")
134
+ bad_output = gr.Textbox(label="Brain Age Difference (BAD)")
135
+
136
+ submit_button.click(
137
+ fn=predict_brain_age,
138
+ inputs=[mri_input, age_input, sex_input],
139
+ outputs=[brain_age_output, bad_output]
140
+ )
141
+
142
+ gr.Markdown("""
143
+ **Disclaimer:** This is a research tool and is not intended for clinical use.
144
+ """)
145
+ # ๐Ÿ”น Debugging GPU Environment
146
+ if torch.cuda.is_available():
147
+ print(f"GPU available: {torch.cuda.get_device_name(0)}")
148
+ device = torch.device("cuda")
149
+ else:
150
+ print("No GPU detected. Falling back to CPU.")
151
+ os.system("nvidia-smi")
152
+ device = torch.device("cpu")
153
+
154
+ # ๐Ÿ”น Download Model Weights
155
+ download_model()
156
+
157
+ # ๐Ÿ”น Run Gradio App
158
  if __name__ == "__main__":
159
+ demo.launch(share=True)