soumickmj commited on
Commit
bbd86c5
·
1 Parent(s): 23492e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -61,6 +61,8 @@ def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_wid
61
  total_batches = len(patch_loader)
62
  progress_bar = st.progress(0)
63
  for i, patches_batch in enumerate(patch_loader):
 
 
64
  local_batch = patches_batch['img'][tio.DATA].float()
65
  local_batch = local_batch / local_batch.max()
66
  locations = patches_batch[tio.LOCATION]
@@ -76,7 +78,6 @@ def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_wid
76
  aggregator.add_batch(output, locations)
77
 
78
  progress_bar.progress((i + 1) / total_batches)
79
- st.text(f"Processing batch {i + 1} of {total_batches}... ({((i + 1) / total_batches) * 100:.2f}% complete)")
80
  st.text(f"CPU usage: {psutil.cpu_percent()}% | RAM usage: {psutil.virtual_memory().percent}%")
81
 
82
  predicted = aggregator.get_output_tensor().squeeze().numpy()
@@ -124,12 +125,16 @@ selected_mode = st.selectbox("Select the running mode:", mode_options)
124
 
125
  # Parameters for patch-based inference
126
  if selected_mode == "Patch-based inference [Default for DS6]":
127
- patch_size = st.number_input("Patch size:", min_value=1, value=64)
128
- stride_length = st.number_input("Stride length:", min_value=1, value=32)
129
- stride_width = st.number_input("Stride width:", min_value=1, value=32)
130
- stride_depth = st.number_input("Stride depth:", min_value=1, value=16)
131
- batch_size = st.number_input("Batch size:", min_value=1, value=10)
132
- num_worker = st.number_input("Number of workers:", min_value=1, value=2)
 
 
 
 
133
 
134
  # Process button
135
  process_button = st.button("Process")
 
61
  total_batches = len(patch_loader)
62
  progress_bar = st.progress(0)
63
  for i, patches_batch in enumerate(patch_loader):
64
+ st.text(f"Processing batch {i + 1} of {total_batches}... ({((i + 1) / total_batches) * 100:.2f}% complete)")
65
+
66
  local_batch = patches_batch['img'][tio.DATA].float()
67
  local_batch = local_batch / local_batch.max()
68
  locations = patches_batch[tio.LOCATION]
 
78
  aggregator.add_batch(output, locations)
79
 
80
  progress_bar.progress((i + 1) / total_batches)
 
81
  st.text(f"CPU usage: {psutil.cpu_percent()}% | RAM usage: {psutil.virtual_memory().percent}%")
82
 
83
  predicted = aggregator.get_output_tensor().squeeze().numpy()
 
125
 
126
  # Parameters for patch-based inference
127
  if selected_mode == "Patch-based inference [Default for DS6]":
128
+ col1, col2, col3 = st.columns(3)
129
+ with col1:
130
+ patch_size = st.number_input("Patch size:", min_value=1, value=64)
131
+ stride_length = st.number_input("Stride length:", min_value=1, value=32)
132
+ with col2:
133
+ stride_width = st.number_input("Stride width:", min_value=1, value=32)
134
+ stride_depth = st.number_input("Stride depth:", min_value=1, value=16)
135
+ with col3:
136
+ batch_size = st.number_input("Batch size:", min_value=1, value=14)
137
+ num_worker = st.number_input("Number of workers:", min_value=1, value=3)
138
 
139
  # Process button
140
  process_button = st.button("Process")