FrancescoLR commited on
Commit
81fd262
·
verified ·
1 Parent(s): 0b88f39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -44,15 +44,18 @@ def extract_middle_slices(nifti_path, output_image_path, slice_size=180):
44
  # Define half the slice size to extract regions around the center of mass
45
  half_size = slice_size // 2
46
 
47
- # Extract slices around the center of mass
48
- def extract_slice(data, center, axis):
49
  slices = [slice(None)] * 3
50
- slices[axis] = slice(center[axis] - half_size, center[axis] + half_size)
51
- return np.take(data, range(center[axis] - half_size, center[axis] + half_size), axis=axis, mode='constant', cval=0)
 
 
 
52
 
53
- axial_slice = extract_slice(data, center, axis=2) # Axial (z-axis)
54
- coronal_slice = extract_slice(data, center, axis=1) # Coronal (y-axis)
55
- sagittal_slice = extract_slice(data, center, axis=0) # Sagittal (x-axis)
56
 
57
  # Create subplots
58
  fig, axes = plt.subplots(1, 3, figsize=(12, 4))
@@ -76,7 +79,6 @@ def extract_middle_slices(nifti_path, output_image_path, slice_size=180):
76
  plt.close()
77
 
78
 
79
-
80
  # Function to run nnUNet inference
81
  @spaces.GPU # Decorate the function to allocate GPU for its execution
82
  def run_nnunet_predict(nifti_file):
 
44
  # Define half the slice size to extract regions around the center of mass
45
  half_size = slice_size // 2
46
 
47
+ # Safely extract slices with boundary checks
48
+ def safe_slice(data, center, axis, half_size):
49
  slices = [slice(None)] * 3
50
+ slices[axis] = slice(
51
+ max(center[axis] - half_size, 0),
52
+ min(center[axis] + half_size, data.shape[axis])
53
+ )
54
+ return data[tuple(slices)]
55
 
56
+ axial_slice = safe_slice(data, center, axis=2, half_size=half_size) # Axial (z-axis)
57
+ coronal_slice = safe_slice(data, center, axis=1, half_size=half_size) # Coronal (y-axis)
58
+ sagittal_slice = safe_slice(data, center, axis=0, half_size=half_size) # Sagittal (x-axis)
59
 
60
  # Create subplots
61
  fig, axes = plt.subplots(1, 3, figsize=(12, 4))
 
79
  plt.close()
80
 
81
 
 
82
  # Function to run nnUNet inference
83
  @spaces.GPU # Decorate the function to allocate GPU for its execution
84
  def run_nnunet_predict(nifti_file):