import streamlit as st
import json
import math
import numpy as np
import nibabel as nib
import torch
import torch.nn.functional as F
import scipy.io
from io import BytesIO
from transformers import AutoModel
import os
import tempfile
from pathlib import Path
import pandas as pd
from skimage.filters import threshold_otsu
def infer_full_vol(tensor, model):
    # Move the slice dimension (last) in front of the in-plane dimensions and
    # normalise intensities to [0, 1].
    tensor = torch.movedim(tensor, -1, -3)
    tensor = tensor / tensor.max()

    # Pad every spatial dimension up to the next multiple of 16, splitting the
    # padding as evenly as possible between the two sides of each axis.
    sizes = tensor.shape[-3:]
    new_sizes = [math.ceil(s / 16) * 16 for s in sizes]
    total_pads = [new_size - s for s, new_size in zip(sizes, new_sizes)]
    pad_before = [pad // 2 for pad in total_pads]
    pad_after = [pad - pad_before[i] for i, pad in enumerate(total_pads)]
    padding = []
    for i in reversed(range(len(pad_before))):  # F.pad expects the last dim first
        padding.extend([pad_before[i], pad_after[i]])
    tensor = F.pad(tensor, padding)

    # Run the model and turn its logits into probabilities.
    with torch.no_grad():
        output = model(tensor)
        if isinstance(output, (tuple, list)):
            output = output[0]
        output = torch.sigmoid(output)

    # Crop the padding away again and restore the original axis order.
    slices = [slice(None)] * output.dim()
    for i in range(len(pad_before)):
        dim = -3 + i
        start = pad_before[i]
        size = sizes[i]
        end = start + size
        slices[dim] = slice(start, end)
    output = output[tuple(slices)]
    output = torch.movedim(output, -3, -1).type(tensor.type())
    return output.squeeze().detach().cpu().numpy()
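# A minimal usage sketch for infer_full_vol (illustrative only; the spatial
# extents below are made up, but the layout matches how the app calls it):
#   vol = torch.rand(1, 1, 256, 256, 64)   # [batch, channel, H, W, D]
#   probs = infer_full_vol(vol, model)     # numpy array of shape (256, 256, 64), values in [0, 1]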
# Set page configuration
st.set_page_config(
    page_title="DS6 | Segmenting vessels in 3D MRA-ToF (ideally, 7T)",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Sidebar content
with st.sidebar:
    st.title("Segmenting vessels in the brain from a 3D Magnetic Resonance Angiograph, ideally acquired at 7T | DS6")
    st.markdown("""
    This application allows you to upload a 3D NIfTI file (dims: H x W x D, where the final dim is the slice dim in the axial plane), process it through a pre-trained 3D model (from DS6 and other related works), and download the output as a `.nii.gz` file containing the vessel segmentation.

    **Instructions**:
    - Upload your 3D NIfTI file (`.nii` or `.nii.gz`). It should be a 3D MRA-ToF volume of the brain, with the axial slice dimension last.
    - Select a pretrained model from the dropdown menu.
    - Click the "Process" button to generate the vessel segmentation.
    """)
    st.markdown("---")
    st.markdown("© 2024 Soumick Chatterjee")
# Main content
st.header("DS6, Deformation-Aware Semi-Supervised Learning: Application to Small Vessel Segmentation with Noisy Training Data")

# File uploader
uploaded_file = st.file_uploader(
    "Please upload a 3D NIfTI file (.nii or .nii.gz)",
    type=["nii", "nii.gz"]
)

# Model selection
model_options = ["SMILEUHURA_DS6_CamSVD_UNetMSS3D_wDeform"]
selected_model = st.selectbox("Select a pretrained model:", model_options)

# Process button
process_button = st.button("Process")
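# Main processing pipeline (overview): read the uploaded NIfTI into a tensor,
# load the selected pretrained model from the Hugging Face Hub, run full-volume
# inference, binarise the probability map, and offer the result for download.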
if uploaded_file is not None and process_button:
    try:
        # Save the uploaded file to a temporary file
        file_extension = ''.join(Path(uploaded_file.name).suffixes)
        with tempfile.NamedTemporaryFile(suffix=file_extension) as tmp_file:
            tmp_file.write(uploaded_file.read())
            tmp_file.flush()

            # Load the NIfTI file from the temporary file
            nifti_img = nib.load(tmp_file.name)
            data = nifti_img.get_fdata()

        # Convert to PyTorch tensor
        tensor = torch.from_numpy(data).float()

        # Ensure it's 3D
        if tensor.ndim != 3:
            st.error("The uploaded NIfTI file is not a 3D volume. Please upload a valid 3D NIfTI file.")
        else:
            # Display input details
            st.success("File successfully uploaded and read.")
            st.write(f"Input tensor shape: `{tensor.shape}`")
            st.write(f"Selected pretrained model: `{selected_model}`")

            # Add batch and channel dimensions
            tensor = tensor.unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, H, W, D]

            # Construct the model name based on the selected model
            model_name = f"soumickmj/{selected_model}"
            # Load the pre-trained model from Hugging Face
            def load_model(model_name):
                hf_token = os.environ.get('HF_API_TOKEN')
                if hf_token is None:
                    st.error("Hugging Face API token is not set. Please set the 'HF_API_TOKEN' environment variable.")
                    return None
                try:
                    model = AutoModel.from_pretrained(
                        model_name,
                        trust_remote_code=True,
                        use_auth_token=hf_token
                    )
                    model.eval()
                    return model
                except Exception as e:
                    st.error(f"Failed to load model: {e}")
                    return None
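            # Note: trust_remote_code=True lets transformers import the custom model
            # class shipped with the soumickmj repositories, and the token authenticates
            # the download (newer transformers versions accept `token=` in place of the
            # deprecated `use_auth_token=`).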
            with st.spinner('Loading the pre-trained model...'):
                model = load_model(model_name)
                if model is None:
                    st.stop()  # Stop the app if the model couldn't be loaded

            # Move model and tensor to CPU (ensure compatibility with Spaces)
            device = torch.device('cpu')
            model = model.to(device)
            tensor = tensor.to(device)

            # Process the tensor through the model
            with st.spinner('Processing the tensor through the model...'):
                output = infer_full_vol(tensor, model)

            st.success("Processing complete.")
            st.write(f"Output tensor shape: `{output.shape}`")

            # Binarise the probability map with Otsu's threshold; fall back to 0.5
            # if Otsu fails (it raises only when the output has a single intensity value).
            try:
                thresh = threshold_otsu(output)
                output = output > thresh
            except Exception as error:
                print(error)
                output = output > 0.5
            output = output.astype('uint16')

            # Save the output as a NIfTI file
            output_img = nib.Nifti1Image(output, affine=nifti_img.affine)
            output_path = tempfile.NamedTemporaryFile(suffix='.nii.gz', delete=False).name
            nib.save(output_img, output_path)

            # Read the saved file for download
            with open(output_path, "rb") as f:
                output_data = f.read()

            # Download button for NIfTI file
            st.download_button(
                label="Download Segmentation Output",
                data=output_data,
                file_name='segmentation_output.nii.gz',
                mime='application/gzip'
            )
    except Exception as e:
        st.error(f"An error occurred: {e}")
elif uploaded_file is None:
    st.info("Awaiting file upload...")
elif not process_button:
    st.info("Click the 'Process' button to start processing.")