Spaces:
Runtime error
Runtime error
import streamlit as st | |
import json | |
import numpy as np | |
import nibabel as nib | |
import torch | |
import scipy.io | |
from io import BytesIO | |
from transformers import AutoModel | |
import os | |
import tempfile | |
from pathlib import Path | |
import pandas as pd | |
from skimage.filters import threshold_otsu | |
def infer_full_vol(tensor, model):
    """Run *model* on a whole volume and return per-voxel sigmoid probabilities.

    Args:
        tensor: Input volume as a torch tensor whose last axis is depth. In this
            app the caller passes ``[1, 1, H, W, D]`` (batch and channel axes
            added with ``unsqueeze``). The last axis is moved to position -3
            before the forward pass, so the model receives ``[..., D, H, W]``.
        model: Callable (e.g. a ``torch.nn.Module``) mapping the input to
            logits. If it returns a tuple or list, only the first element is
            used.

    Returns:
        numpy.ndarray of sigmoid probabilities with the axes moved back to the
        input's ``[..., H, W, D]`` order and cast to the input tensor's type.
    """
    tensor = torch.movedim(tensor, -1, -3)
    # Normalise intensities by the volume maximum. Skip the division for an
    # all-zero volume, where it would turn every voxel into NaN.
    max_val = tensor.max()
    if max_val != 0:
        tensor = tensor / max_val
    with torch.no_grad():
        output = model(tensor)
    # Some models return several outputs; the first one is the prediction used here.
    if isinstance(output, (tuple, list)):
        output = output[0]
    output = torch.sigmoid(output)
    output = torch.movedim(output, -3, -1).type(tensor.type())
    return output.detach().cpu().numpy()
# Browser-tab title/icon and default layout for the whole app.
_PAGE_CONFIG = {
    "page_title": "DS6 | Segmenting vessels in 3D MRA-ToF (ideally, 7T)",
    "page_icon": "🧠",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
# Sidebar content: app title, usage instructions and credits.
with st.sidebar:
    st.title("Segmenting vessels in the brain from a 3D Magnetic Resonance Angiogram, ideally acquired at 7T | DS6")
    # Built from explicit lines so paragraph breaks and the bullet list render
    # the same regardless of source indentation.
    st.markdown(
        "This application allows you to upload a 3D NIfTI file (dims: H x W x D), "
        "process it through a pre-trained 3D model (from DS6 and other related works), "
        "and download the output as a `.nii.gz` file containing the vessel segmentation.\n"
        "\n"
        "**Instructions**:\n"
        "- Upload your 3D NIfTI file (`.nii` or `.nii.gz`). It should be a 3D Magnetic "
        "Resonance Angiogram (ideally a Time-of-Flight scan acquired at 7T) with "
        "dimensions H x W x D.\n"
        "- Select a pretrained model from the dropdown menu.\n"
        '- Click the "Process" button to generate the vessel segmentation.\n'
    )
    st.markdown("---")
    st.markdown("© 2024 Soumick Chatterjee")
# Main content: page header and the input widgets.
st.header("From a 3D MRA-ToF volume (3D: HxWxD) to a 3D vessel segmentation...")
# File uploader for the input volume.
uploaded_file = st.file_uploader(
    "Please upload a 3D NIfTI file (.nii or .nii.gz)",
    type=["nii", "nii.gz"],
)
# Pretrained model selection.
model_options = ["SMILEUHURA_DS6_CamSVD_UNetMSS3D_wDeform"]
selected_model = st.selectbox("Select a pretrained model:", model_options)
# Processing only starts once this button is clicked.
process_button = st.button("Process")
def load_model(model_name):
    """Load a pretrained model from the Hugging Face Hub in evaluation mode.

    The access token is read from the ``HF_API_TOKEN`` environment variable.
    Failures are reported with ``st.error`` and signalled by returning ``None``
    instead of raising, so the caller can stop the app cleanly.
    """
    hf_token = os.environ.get('HF_API_TOKEN')
    if hf_token is None:
        st.error("Hugging Face API token is not set. Please set the 'HF_API_TOKEN' environment variable.")
        return None
    try:
        # NOTE: trust_remote_code=True executes Python code shipped with the
        # model repository; only point this at repositories you trust.
        # NOTE(review): newer transformers releases deprecate `use_auth_token`
        # in favour of `token`; switch once the installed version allows it.
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_auth_token=hf_token,
        )
        model.eval()
        return model
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None


def _read_uploaded_nifti(uploaded):
    """Return ``(data, affine)`` for an uploaded NIfTI file.

    The upload is written into a private temporary directory so nibabel can
    pick the format from the file extension. It is then read fully into memory
    (``mmap=False``) so nothing references the file after the directory is
    removed.
    """
    suffix = ''.join(Path(uploaded.name).suffixes)
    with tempfile.TemporaryDirectory() as tmp_dir:
        input_path = Path(tmp_dir) / f"input{suffix}"
        # getvalue() returns the whole upload regardless of the stream position.
        input_path.write_bytes(uploaded.getvalue())
        nifti_img = nib.load(str(input_path), mmap=False)
        data = nifti_img.get_fdata()
        affine = nifti_img.affine
    return data, affine


def _binarize(prob_map):
    """Threshold a probability map into a uint16 mask using Otsu's method.

    Falls back to a fixed 0.5 cut-off if Otsu thresholding raises (presumably
    when the map holds a single value).
    """
    try:
        thresh = threshold_otsu(prob_map)
    except Exception as error:
        st.warning(f"Otsu thresholding failed ({error}); using a fixed threshold of 0.5 instead.")
        thresh = 0.5
    return (prob_map > thresh).astype('uint16')


def _to_nifti_gz_bytes(volume, affine):
    """Serialise *volume* with *affine* as a gzipped NIfTI-1 file and return its bytes."""
    # A temporary directory (rather than NamedTemporaryFile(delete=False))
    # guarantees the file is removed once its bytes have been read.
    with tempfile.TemporaryDirectory() as tmp_dir:
        output_path = Path(tmp_dir) / 'segmentation_output.nii.gz'
        nib.save(nib.Nifti1Image(volume, affine=affine), str(output_path))
        return output_path.read_bytes()


if uploaded_file is not None and process_button:
    try:
        data, affine = _read_uploaded_nifti(uploaded_file)
        tensor = torch.from_numpy(data).float()
        if tensor.ndim != 3:
            st.error("The uploaded NIfTI file is not a 3D volume. Please upload a valid 3D NIfTI file.")
        else:
            st.success("File successfully uploaded and read.")
            st.write(f"Input tensor shape: `{tensor.shape}`")
            st.write(f"Selected pretrained model: `{selected_model}`")
            # Add batch and channel dimensions -> shape [1, 1, H, W, D].
            tensor = tensor.unsqueeze(0).unsqueeze(0)
            model_name = f"soumickmj/{selected_model}"
            with st.spinner('Loading the pre-trained model...'):
                model = load_model(model_name)
            if model is None:
                st.stop()  # load_model has already shown the error
            # Run on CPU for compatibility with Spaces hardware.
            device = torch.device('cpu')
            model = model.to(device)
            tensor = tensor.to(device)
            with st.spinner('Processing the tensor through the model...'):
                output = infer_full_vol(tensor, model)
            st.success("Processing complete.")
            # Drop the batch and channel axes so the saved NIfTI has the same
            # H x W x D grid as the input. Raises if either axis is not size 1,
            # e.g. a multi-channel output.
            output = np.squeeze(output, axis=(0, 1))
            st.write(f"Output tensor shape: `{output.shape}`")
            mask = _binarize(output)
            output_data = _to_nifti_gz_bytes(mask, affine)
            st.download_button(
                label="Download Segmentation Output",
                data=output_data,
                file_name='segmentation_output.nii.gz',
                mime='application/gzip',
            )
    except Exception as e:
        st.error(f"An error occurred: {e}")
elif uploaded_file is None:
    st.info("Awaiting file upload...")
else:
    st.info("Click the 'Process' button to start processing.")