soumickmj commited on
Commit
a18760d
1 Parent(s): 4c12a40

prova full vol

Browse files
Files changed (2) hide show
  1. app.py +21 -5
  2. requirements.txt +2 -1
app.py CHANGED
@@ -10,6 +10,18 @@ import os
10
  import tempfile
11
  from pathlib import Path
12
  import pandas as pd
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # Set page configuration
15
  st.set_page_config(
@@ -110,14 +122,18 @@ if uploaded_file is not None and process_button:
110
 
111
  # Process the tensor through the model
112
  with st.spinner('Processing the tensor through the model...'):
113
- with torch.no_grad():
114
- output = model(tensor)
115
- if isinstance(output, tuple):
116
- output = output[0]
117
- output = output.squeeze(0)
118
 
119
  st.success("Processing complete.")
120
  st.write(f"Output tensor shape: `{output.shape}`")
 
 
 
 
 
 
 
 
121
 
122
  # Convert output to NumPy array
123
  output_np = output.detach().cpu().numpy()
 
10
  import tempfile
11
  from pathlib import Path
12
  import pandas as pd
13
+ from skimage.filters import threshold_otsu
14
+
15
+ def infer_full_vol(tensor, model):
16
+ tensor = torch.movedim(tensor, -1, -3)
17
+ tensor = tensor / tensor.max()
18
+ with torch.no_grad():
19
+ output = model(tensor)
20
+ if type(output) is tuple or type(output) is list:
21
+ output = output[0]
22
+ output = torch.sigmoid(output)
23
+ output = torch.movedim(output, -3, -1).type(tensor.type())
24
+ return output.squeeze(0)
25
 
26
  # Set page configuration
27
  st.set_page_config(
 
122
 
123
  # Process the tensor through the model
124
  with st.spinner('Processing the tensor through the model...'):
125
+ output = infer_full_vol(tensor, model)
 
 
 
 
126
 
127
  st.success("Processing complete.")
128
  st.write(f"Output tensor shape: `{output.shape}`")
129
+
130
+ try:
131
+ thresh = threshold_otsu(output)
132
+ output = output > thresh
133
+ except Exception as error:
134
+ print(error)
135
+ output = output > 0.5 # exception only if input image seems to have just one color 1.0.
136
+ output = output.astype('uint16')
137
 
138
  # Convert output to NumPy array
139
  output_np = output.detach().cpu().numpy()
requirements.txt CHANGED
@@ -3,4 +3,5 @@ torch
3
  pytorch_lightning
4
  scipy
5
  transformers
6
- torchvision
 
 
3
  pytorch_lightning
4
  scipy
5
  transformers
6
+ torchvision
7
+ scikit-image