Spaces:
Runtime error
Runtime error
prova full vol
Browse files- app.py +21 -5
- requirements.txt +2 -1
app.py
CHANGED
@@ -10,6 +10,18 @@ import os
|
|
10 |
import tempfile
|
11 |
from pathlib import Path
|
12 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
# Set page configuration
|
15 |
st.set_page_config(
|
@@ -110,14 +122,18 @@ if uploaded_file is not None and process_button:
|
|
110 |
|
111 |
# Process the tensor through the model
|
112 |
with st.spinner('Processing the tensor through the model...'):
|
113 |
-
|
114 |
-
output = model(tensor)
|
115 |
-
if isinstance(output, tuple):
|
116 |
-
output = output[0]
|
117 |
-
output = output.squeeze(0)
|
118 |
|
119 |
st.success("Processing complete.")
|
120 |
st.write(f"Output tensor shape: `{output.shape}`")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
# Convert output to NumPy array
|
123 |
output_np = output.detach().cpu().numpy()
|
|
|
10 |
import tempfile
|
11 |
from pathlib import Path
|
12 |
import pandas as pd
|
13 |
+
from skimage.filters import threshold_otsu
|
14 |
+
|
15 |
+
def infer_full_vol(tensor, model):
|
16 |
+
tensor = torch.movedim(tensor, -1, -3)
|
17 |
+
tensor = tensor / tensor.max()
|
18 |
+
with torch.no_grad():
|
19 |
+
output = model(tensor)
|
20 |
+
if type(output) is tuple or type(output) is list:
|
21 |
+
output = output[0]
|
22 |
+
output = torch.sigmoid(output)
|
23 |
+
output = torch.movedim(output, -3, -1).type(tensor.type())
|
24 |
+
return output.squeeze(0)
|
25 |
|
26 |
# Set page configuration
|
27 |
st.set_page_config(
|
|
|
122 |
|
123 |
# Process the tensor through the model
|
124 |
with st.spinner('Processing the tensor through the model...'):
|
125 |
+
output = infer_full_vol(tensor, model)
|
|
|
|
|
|
|
|
|
126 |
|
127 |
st.success("Processing complete.")
|
128 |
st.write(f"Output tensor shape: `{output.shape}`")
|
129 |
+
|
130 |
+
try:
|
131 |
+
thresh = threshold_otsu(output)
|
132 |
+
output = output > thresh
|
133 |
+
except Exception as error:
|
134 |
+
print(error)
|
135 |
+
output = output > 0.5 # exception only if input image seems to have just one color 1.0.
|
136 |
+
output = output.astype('uint16')
|
137 |
|
138 |
# Convert output to NumPy array
|
139 |
output_np = output.detach().cpu().numpy()
|
requirements.txt
CHANGED
@@ -3,4 +3,5 @@ torch
|
|
3 |
pytorch_lightning
|
4 |
scipy
|
5 |
transformers
|
6 |
-
torchvision
|
|
|
|
3 |
pytorch_lightning
|
4 |
scipy
|
5 |
transformers
|
6 |
+
torchvision
|
7 |
+
scikit-image
|