FrancescoLR commited on
Commit
2cac7f2
·
1 Parent(s): b01fffe

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -15
app.py CHANGED
@@ -3,11 +3,29 @@ import subprocess
3
  import os
4
  import shutil
5
 
6
- # Paths
 
 
 
 
7
  INPUT_DIR = "/tmp/input"
8
  OUTPUT_DIR = "/tmp/output"
9
- MODEL_DIR = "./model"
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def run_nnunet_predict(nifti_file):
12
  # Prepare directories
13
  os.makedirs(INPUT_DIR, exist_ok=True)
@@ -15,14 +33,15 @@ def run_nnunet_predict(nifti_file):
15
 
16
  # Save the uploaded file to the input directory
17
  input_path = os.path.join(INPUT_DIR, "image.nii.gz")
18
- shutil.copy(nifti_file.name, input_path)
 
19
 
20
  # Set environment variables for nnUNet
21
- os.environ["nnUNet_raw"] = MODEL_DIR
22
- os.environ["nnUNet_preprocessed"] = MODEL_DIR
23
- os.environ["nnUNet_results"] = MODEL_DIR
24
 
25
- # Construct the nnUNetv2_predict command
26
  command = [
27
  "nnUNetv2_predict",
28
  "-i", INPUT_DIR,
@@ -31,8 +50,6 @@ def run_nnunet_predict(nifti_file):
31
  "-c", "3d_fullres", # Configuration
32
  "-tr", "nnUNetTrainer_8000epochs", # Trainer name
33
  ]
34
-
35
- # Run the command
36
  try:
37
  subprocess.run(command, check=True)
38
  # Get the output file
@@ -46,14 +63,13 @@ interface = gr.Interface(
46
  fn=run_nnunet_predict,
47
  inputs=gr.File(label="Upload FLAIR Image (.nii.gz)"),
48
  outputs=gr.File(label="Download Segmentation Mask"),
49
- title="FLAMeS: FLAIR Lesion Analysis in Multiple Sclerosis",
50
- description=(
51
- "Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of MS lesions. "
52
- "This model uses nnUNetv2 for inference with ensemble predictions."
53
- ),
54
  )
55
 
 
 
 
56
  # Launch the app
57
  if __name__ == "__main__":
58
  interface.launch()
59
-
 
3
  import os
4
  import shutil
5
 
6
+ export nnUNet_results="./model"
7
+
8
+ # Define paths
9
+ MODEL_DIR = "./model" # Local directory to store the downloaded model
10
+ DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML") # Directory for Dataset004_WML
11
  INPUT_DIR = "/tmp/input"
12
  OUTPUT_DIR = "/tmp/output"
 
13
 
14
+ # Hugging Face Model Repository
15
+ REPO_ID = "FrancescoLR/FLAMeS-model" # Replace with your actual model repository ID
16
+
17
+ # Function to download the model files from Hugging Face Model Hub
18
+ def download_model():
19
+ if not os.path.exists(DATASET_DIR):
20
+ os.makedirs(DATASET_DIR, exist_ok=True)
21
+ print("Downloading Dataset004_WML...")
22
+ hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)
23
+ # Unzip the dataset into the correct location
24
+ subprocess.run(["unzip", "-o", os.path.join(MODEL_DIR, "Dataset004_WML.zip"), "-d", DATASET_DIR])
25
+ os.remove(os.path.join(MODEL_DIR, "Dataset004_WML.zip"))
26
+ print("Dataset004_WML downloaded and extracted.")
27
+
28
+ # Function to run nnUNet inference
29
  def run_nnunet_predict(nifti_file):
30
  # Prepare directories
31
  os.makedirs(INPUT_DIR, exist_ok=True)
 
33
 
34
  # Save the uploaded file to the input directory
35
  input_path = os.path.join(INPUT_DIR, "image.nii.gz")
36
+ with open(input_path, "wb") as f:
37
+ f.write(nifti_file.read())
38
 
39
  # Set environment variables for nnUNet
40
+ os.environ["nnUNet_raw"] = DATASET_DIR
41
+ os.environ["nnUNet_preprocessed"] = DATASET_DIR
42
+ os.environ["nnUNet_results"] = DATASET_DIR
43
 
44
+ # Construct and run the nnUNetv2_predict command
45
  command = [
46
  "nnUNetv2_predict",
47
  "-i", INPUT_DIR,
 
50
  "-c", "3d_fullres", # Configuration
51
  "-tr", "nnUNetTrainer_8000epochs", # Trainer name
52
  ]
 
 
53
  try:
54
  subprocess.run(command, check=True)
55
  # Get the output file
 
63
  fn=run_nnunet_predict,
64
  inputs=gr.File(label="Upload FLAIR Image (.nii.gz)"),
65
  outputs=gr.File(label="Download Segmentation Mask"),
66
+ title="FLAMeS: Multiple Sclerosis Lesion Segmentation",
67
+ description="Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of MS lesions."
 
 
 
68
  )
69
 
70
+ # Download model files before launching the app
71
+ download_model()
72
+
73
  # Launch the app
74
  if __name__ == "__main__":
75
  interface.launch()