Spaces:
Runtime error
Runtime error
patching added for trying
Browse files- app.py +45 -8
- requirements.txt +2 -4
app.py
CHANGED
@@ -1,18 +1,15 @@
|
|
1 |
import streamlit as st
|
2 |
-
import json
|
3 |
import math
|
4 |
import numpy as np
|
5 |
import nibabel as nib
|
6 |
import torch
|
7 |
import torch.nn.functional as F
|
8 |
-
import scipy.io
|
9 |
-
from io import BytesIO
|
10 |
from transformers import AutoModel
|
11 |
import os
|
12 |
import tempfile
|
13 |
from pathlib import Path
|
14 |
-
import pandas as pd
|
15 |
from skimage.filters import threshold_otsu
|
|
|
16 |
|
17 |
def infer_full_vol(tensor, model):
|
18 |
tensor = torch.movedim(tensor, -1, -3)
|
@@ -46,6 +43,37 @@ def infer_full_vol(tensor, model):
|
|
46 |
output = torch.movedim(output, -3, -1).type(tensor.type())
|
47 |
return output.squeeze().detach().cpu().numpy()
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
# Set page configuration
|
50 |
st.set_page_config(
|
51 |
page_title="DS6 | Segmenting vessels in 3D MRA-ToF (ideally, 7T)",
|
@@ -62,7 +90,7 @@ with st.sidebar:
|
|
62 |
|
63 |
**Instructions**:
|
64 |
- Upload your 3D NIfTI file (`.nii` or `.nii.gz`). It should be a single-slice cardiac long-axis dynamic CINE scan, where the first dimension represents time.
|
65 |
-
- Select a
|
66 |
- Click the "Process" button to generate the latent factors.
|
67 |
""")
|
68 |
st.markdown("---")
|
@@ -77,10 +105,14 @@ uploaded_file = st.file_uploader(
|
|
77 |
type=["nii", "nii.gz"]
|
78 |
)
|
79 |
|
80 |
-
#
|
81 |
model_options = ["SMILEUHURA_DS6_CamSVD_UNetMSS3D_wDeform"]
|
82 |
selected_model = st.selectbox("Select a pretrained model:", model_options)
|
83 |
|
|
|
|
|
|
|
|
|
84 |
# Process button
|
85 |
process_button = st.button("Process")
|
86 |
|
@@ -111,7 +143,7 @@ if uploaded_file is not None and process_button:
|
|
111 |
# Add batch and channel dimensions
|
112 |
tensor = tensor.unsqueeze(0).unsqueeze(0) # Shape: [1, 1, D, H, W]
|
113 |
|
114 |
-
# Construct the model name based on the selected
|
115 |
model_name = f"soumickmj/{selected_model}"
|
116 |
|
117 |
# Load the pre-trained model from Hugging Face
|
@@ -145,7 +177,12 @@ if uploaded_file is not None and process_button:
|
|
145 |
|
146 |
# Process the tensor through the model
|
147 |
with st.spinner('Processing the tensor through the model...'):
|
148 |
-
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
st.success("Processing complete.")
|
151 |
st.write(f"Output tensor shape: `{output.shape}`")
|
|
|
1 |
import streamlit as st
|
|
|
2 |
import math
|
3 |
import numpy as np
|
4 |
import nibabel as nib
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
|
|
|
|
7 |
from transformers import AutoModel
|
8 |
import os
|
9 |
import tempfile
|
10 |
from pathlib import Path
|
|
|
11 |
from skimage.filters import threshold_otsu
|
12 |
+
import torchio as tio
|
13 |
|
14 |
def infer_full_vol(tensor, model):
|
15 |
tensor = torch.movedim(tensor, -1, -3)
|
|
|
43 |
output = torch.movedim(output, -3, -1).type(tensor.type())
|
44 |
return output.squeeze().detach().cpu().numpy()
|
45 |
|
46 |
+
def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_width=32, stride_depth=16, batch_size=10, num_worker=2):
|
47 |
+
test_subject = tio.Subject(img = tio.ScalarImage(tensor=tensor))
|
48 |
+
overlap = np.subtract(patch_size, (stride_length, stride_width, stride_depth))
|
49 |
+
|
50 |
+
with torch.no_grad():
|
51 |
+
grid_sampler = tio.inference.GridSampler(
|
52 |
+
test_subject,
|
53 |
+
patch_size,
|
54 |
+
overlap,
|
55 |
+
)
|
56 |
+
aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode="average")
|
57 |
+
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=batch_size, shuffle=False, num_workers=num_worker)
|
58 |
+
for _, patches_batch in enumerate(patch_loader):
|
59 |
+
local_batch = patches_batch['img'][tio.DATA].float()
|
60 |
+
local_batch = local_batch / local_batch.max()
|
61 |
+
locations = patches_batch[tio.LOCATION]
|
62 |
+
|
63 |
+
local_batch = torch.movedim(local_batch, -1, -3)
|
64 |
+
|
65 |
+
output = model(local_batch)
|
66 |
+
if type(output) is tuple or type(output) is list:
|
67 |
+
output = output[0]
|
68 |
+
output = torch.sigmoid(output).detach().cpu()
|
69 |
+
|
70 |
+
output = torch.movedim(output, -3, -1).type(local_batch.type())
|
71 |
+
aggregator.add_batch(output, locations)
|
72 |
+
|
73 |
+
predicted = aggregator.get_output_tensor().squeeze().numpy()
|
74 |
+
|
75 |
+
return predicted
|
76 |
+
|
77 |
# Set page configuration
|
78 |
st.set_page_config(
|
79 |
page_title="DS6 | Segmenting vessels in 3D MRA-ToF (ideally, 7T)",
|
|
|
90 |
|
91 |
**Instructions**:
|
92 |
- Upload your 3D NIfTI file (`.nii` or `.nii.gz`). It should be a single-slice cardiac long-axis dynamic CINE scan, where the first dimension represents time.
|
93 |
+
- Select a model from the dropdown menu.
|
94 |
- Click the "Process" button to generate the latent factors.
|
95 |
""")
|
96 |
st.markdown("---")
|
|
|
105 |
type=["nii", "nii.gz"]
|
106 |
)
|
107 |
|
108 |
+
# Model selection
|
109 |
model_options = ["SMILEUHURA_DS6_CamSVD_UNetMSS3D_wDeform"]
|
110 |
selected_model = st.selectbox("Select a pretrained model:", model_options)
|
111 |
|
112 |
+
# Mode selection
|
113 |
+
mode_options = ["Full volume inference", "Patch-based inference [Default for DS6]"]
|
114 |
+
selected_mode = st.selectbox("Select the running mode:", mode_options)
|
115 |
+
|
116 |
# Process button
|
117 |
process_button = st.button("Process")
|
118 |
|
|
|
143 |
# Add batch and channel dimensions
|
144 |
tensor = tensor.unsqueeze(0).unsqueeze(0) # Shape: [1, 1, D, H, W]
|
145 |
|
146 |
+
# Construct the model name based on the selected model
|
147 |
model_name = f"soumickmj/{selected_model}"
|
148 |
|
149 |
# Load the pre-trained model from Hugging Face
|
|
|
177 |
|
178 |
# Process the tensor through the model
|
179 |
with st.spinner('Processing the tensor through the model...'):
|
180 |
+
if selected_mode == "full volume inference":
|
181 |
+
st.info("Running full volume inference...")
|
182 |
+
output = infer_full_vol(tensor, model)
|
183 |
+
else:
|
184 |
+
st.info("Running patch-based inference [Default for DS6]...")
|
185 |
+
output = infer_patch_based(tensor, model)
|
186 |
|
187 |
st.success("Processing complete.")
|
188 |
st.write(f"Output tensor shape: `{output.shape}`")
|
requirements.txt
CHANGED
@@ -1,7 +1,5 @@
|
|
|
|
1 |
nibabel
|
2 |
torch
|
3 |
-
pytorch_lightning
|
4 |
-
scipy
|
5 |
transformers
|
6 |
-
|
7 |
-
scikit-image
|
|
|
1 |
+
scikit-image
|
2 |
nibabel
|
3 |
torch
|
|
|
|
|
4 |
transformers
|
5 |
+
torchio
|
|