FrancescoLR commited on
Commit
5dd0ea6
·
1 Parent(s): 812aa8d

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -2,10 +2,9 @@ import gradio as gr
2
  import subprocess
3
  import os
4
  import shutil
5
- import os
6
  from huggingface_hub import hf_hub_download
7
  import torch
8
-
9
 
10
  # Define paths
11
  MODEL_DIR = "./model" # Local directory to store the downloaded model
@@ -26,6 +25,7 @@ def download_model():
26
  print("Dataset004_WML downloaded and extracted.")
27
 
28
  # Function to run nnUNet inference
 
29
  def run_nnunet_predict(nifti_file):
30
  # Prepare directories
31
  os.makedirs(INPUT_DIR, exist_ok=True)
@@ -49,19 +49,20 @@ def run_nnunet_predict(nifti_file):
49
  "-o", OUTPUT_DIR,
50
  "-d", "004", # Dataset ID
51
  "-c", "3d_fullres", # Configuration
52
- "-tr", "nnUNetTrainer_8000epochs",
 
53
  ]
54
- print("Files in /tmp/output:")
55
- print(os.listdir(OUTPUT_DIR))
56
  try:
57
  subprocess.run(command, check=True)
58
  # Get the output file
59
  output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
60
- return output_file
 
 
 
61
  except subprocess.CalledProcessError as e:
62
  return f"Error: {e}"
63
 
64
-
65
  # Gradio Interface
66
  interface = gr.Interface(
67
  fn=run_nnunet_predict,
@@ -84,3 +85,4 @@ download_model()
84
  # Launch the app
85
  if __name__ == "__main__":
86
  interface.launch()
 
 
2
  import subprocess
3
  import os
4
  import shutil
 
5
  from huggingface_hub import hf_hub_download
6
  import torch
7
+ import spaces # Import spaces for GPU decoration
8
 
9
  # Define paths
10
  MODEL_DIR = "./model" # Local directory to store the downloaded model
 
25
  print("Dataset004_WML downloaded and extracted.")
26
 
27
  # Function to run nnUNet inference
28
+ @spaces.GPU # Decorate the function to allocate GPU for its execution
29
  def run_nnunet_predict(nifti_file):
30
  # Prepare directories
31
  os.makedirs(INPUT_DIR, exist_ok=True)
 
49
  "-o", OUTPUT_DIR,
50
  "-d", "004", # Dataset ID
51
  "-c", "3d_fullres", # Configuration
52
+ "-tr", "nnUNetTrainer_8000epochs",
53
+ "--device", "cuda:0" # Explicitly use GPU 0
54
  ]
 
 
55
  try:
56
  subprocess.run(command, check=True)
57
  # Get the output file
58
  output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
59
+ if os.path.exists(output_file):
60
+ return output_file
61
+ else:
62
+ return "Error: Output file not found."
63
  except subprocess.CalledProcessError as e:
64
  return f"Error: {e}"
65
 
 
66
  # Gradio Interface
67
  interface = gr.Interface(
68
  fn=run_nnunet_predict,
 
85
  # Launch the app
86
  if __name__ == "__main__":
87
  interface.launch()
88
+