FrancescoLR commited on
Commit
6c748cb
·
1 Parent(s): 2bce9bd

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -3
app.py CHANGED
@@ -31,6 +31,10 @@ def run_nnunet_predict(nifti_file):
31
  os.makedirs(INPUT_DIR, exist_ok=True)
32
  os.makedirs(OUTPUT_DIR, exist_ok=True)
33
 
 
 
 
 
34
  # Save the uploaded file to the input directory
35
  input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
36
  os.rename(nifti_file.name, input_path) # Move the uploaded file to the expected input location
@@ -50,19 +54,23 @@ def run_nnunet_predict(nifti_file):
50
  "-d", "004", # Dataset ID
51
  "-c", "3d_fullres", # Configuration
52
  "-tr", "nnUNetTrainer_8000epochs",
53
- "-device", "cuda" # Explicitly use GPU 0
54
  ]
55
  try:
56
  subprocess.run(command, check=True)
57
- # Get the output file
 
58
  output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
 
59
  if os.path.exists(output_file):
60
- return output_file
 
61
  else:
62
  return "Error: Output file not found."
63
  except subprocess.CalledProcessError as e:
64
  return f"Error: {e}"
65
 
 
66
  # Gradio Interface
67
  interface = gr.Interface(
68
  fn=run_nnunet_predict,
 
31
  os.makedirs(INPUT_DIR, exist_ok=True)
32
  os.makedirs(OUTPUT_DIR, exist_ok=True)
33
 
34
+ # Extract the original filename without the extension
35
+ original_filename = os.path.basename(nifti_file.name)
36
+ base_filename = original_filename.replace(".nii.gz", "").replace("_0000", "")
37
+
38
  # Save the uploaded file to the input directory
39
  input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
40
  os.rename(nifti_file.name, input_path) # Move the uploaded file to the expected input location
 
54
  "-d", "004", # Dataset ID
55
  "-c", "3d_fullres", # Configuration
56
  "-tr", "nnUNetTrainer_8000epochs",
57
+ "-device", "cuda" # Explicitly use GPU
58
  ]
59
  try:
60
  subprocess.run(command, check=True)
61
+
62
+ # Rename the output file to match the original input filename
63
  output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
64
+ new_output_file = os.path.join(OUTPUT_DIR, f"{base_filename}_LesionMask.nii.gz")
65
  if os.path.exists(output_file):
66
+ os.rename(output_file, new_output_file)
67
+ return new_output_file
68
  else:
69
  return "Error: Output file not found."
70
  except subprocess.CalledProcessError as e:
71
  return f"Error: {e}"
72
 
73
+
74
  # Gradio Interface
75
  interface = gr.Interface(
76
  fn=run_nnunet_predict,