FrancescoLR commited on
Commit
ccd4fd7
·
verified ·
1 Parent(s): 72ec283

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -269,8 +269,11 @@ def run_nnunet_predict(nifti_file,hd_bet=False):
269
  image = extract_middle_slices(input_path, input_slice_path, center=center)
270
  labeled_mask = extract_middle_slices(new_output_file, output_slice_path, center=center, label_components=True)
271
 
 
 
 
272
  labeled_mask_path = os.path.join(OUTPUT_DIR, f"{base_filename}_LabeledClusters.nii.gz")
273
- nib.save(nib.Nifti1Image(labeled_mask.astype(np.int16), img.affine), labeled_mask_path)
274
 
275
  # Return paths for the Gradio interface
276
  return new_output_file, input_slice_path, output_slice_path, labeled_mask_path
@@ -302,9 +305,10 @@ with gr.Blocks() as demo:
302
  submit_button = gr.Button("Submit")
303
  with gr.Column(scale=2):
304
  seg_output = gr.File(label="Download the Lesion Segmentation Mask")
 
305
  input_img = gr.Image(label="Input: FLAIR image")
306
  output_img = gr.Image(label="Output: Binary Lesion Mask")
307
- clusters_output = gr.Image(label="Output: Labeled Lesion Mask")
308
 
309
  gr.Markdown("""
310
  **If you find this tool useful, please consider citing:**
 
269
  image = extract_middle_slices(input_path, input_slice_path, center=center)
270
  labeled_mask = extract_middle_slices(new_output_file, output_slice_path, center=center, label_components=True)
271
 
272
+ # Load the binary lesion mask to get its affine
273
+ output_img = nib.load(new_output_file)
274
+
275
  labeled_mask_path = os.path.join(OUTPUT_DIR, f"{base_filename}_LabeledClusters.nii.gz")
276
+ nib.save(nib.Nifti1Image(labeled_mask.astype(np.int16), output_img.affine), labeled_mask_path)
277
 
278
  # Return paths for the Gradio interface
279
  return new_output_file, input_slice_path, output_slice_path, labeled_mask_path
 
305
  submit_button = gr.Button("Submit")
306
  with gr.Column(scale=2):
307
  seg_output = gr.File(label="Download the Lesion Segmentation Mask")
308
+ clusters_output = gr.File(label="Download the Labeled Lesion Segmentation Mask")
309
  input_img = gr.Image(label="Input: FLAIR image")
310
  output_img = gr.Image(label="Output: Binary Lesion Mask")
311
+
312
 
313
  gr.Markdown("""
314
  **If you find this tool useful, please consider citing:**