"""Streamlit app: vessel segmentation of 3D MRA-ToF volumes with DS6 models.

The user uploads a 3D NIfTI volume, picks a pretrained model and an inference
mode, and downloads the binarised segmentation as a ``.nii.gz`` file.
"""

import math
import os
import tempfile
from pathlib import Path

import nibabel as nib
import numpy as np
import streamlit as st
import torch
import torch.nn.functional as F
import torchio as tio
from skimage.filters import threshold_otsu
from transformers import AutoModel

# Labels of the inference modes. The same constants feed the select box and
# the dispatch below so the two can never drift apart.
MODE_FULL_VOLUME = "Full volume inference"
MODE_PATCH_BASED = "Patch-based inference [Default for DS6]"

# Kept at column 0: indented markdown would be rendered as a code block.
SIDEBAR_TEXT = """
This application allows you to upload a 3D NIfTI file (dims: H x W x D, where the final dim is the slice dim in the axial plane), process it through a pre-trained 3D model (from DS6 and other related works), and download the output as a `.nii.gz` file containing the vessel segmentation.

**Instructions**:
- Upload your 3D NIfTI file (`.nii` or `.nii.gz`). It should be a 3D MRA-ToF volume (H x W x D), where the final dimension is the slice dimension in the axial plane.
- Select a model from the dropdown menu.
- Click the "Process" button to generate the vessel segmentation.
"""


def _scale_to_unit_max(tensor):
    """Divide ``tensor`` by its maximum, leaving it untouched if the maximum is 0."""
    peak = tensor.max()
    if peak == 0:
        # An all-zero volume/patch would otherwise become NaN everywhere.
        return tensor
    return tensor / peak


def _first_output(output):
    """Return the first element when the model returns a tuple or list."""
    if isinstance(output, (tuple, list)):
        return output[0]
    return output


def infer_full_vol(tensor, model):
    """Segment a whole volume in a single forward pass.

    Args:
        tensor: Tensor whose last three dimensions are the volume with the
            slice dimension last (the app passes ``[1, 1, H, W, D]``).
        model: Callable returning logits, or a tuple/list whose first element
            holds the logits.

    Returns:
        NumPy array of sigmoid probabilities, squeezed, with the slice
        dimension moved back to the last axis.
    """
    # The model is fed the slice dimension in front of the in-plane ones.
    tensor = torch.movedim(tensor, -1, -3)
    tensor = _scale_to_unit_max(tensor)

    # Pad each spatial dimension symmetrically up to the next multiple of 16.
    sizes = tensor.shape[-3:]
    pad_before, pad_after = [], []
    for size in sizes:
        total = math.ceil(size / 16) * 16 - size
        pad_before.append(total // 2)
        pad_after.append(total - total // 2)

    # F.pad lists the pads starting from the LAST dimension.
    padding = []
    for before, after in zip(reversed(pad_before), reversed(pad_after)):
        padding.extend([before, after])
    tensor = F.pad(tensor, padding)

    with torch.no_grad():
        output = torch.sigmoid(_first_output(model(tensor)))

    # Crop the padding away again.
    crop = [slice(None)] * output.dim()
    for axis, (start, size) in enumerate(zip(pad_before, sizes), start=-3):
        crop[axis] = slice(start, start + size)
    output = output[tuple(crop)]

    output = torch.movedim(output, -3, -1).type(tensor.type())
    return output.squeeze().detach().cpu().numpy()


def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_width=32, stride_depth=16, batch_size=10, num_worker=2):
    """Segment a volume patch by patch and average the overlapping predictions.

    Args:
        tensor: Volume as a 3D ``(H, W, D)``, 4D ``(C, H, W, D)`` or 5D
            ``(1, C, H, W, D)`` tensor/array.
        model: Callable returning logits, or a tuple/list whose first element
            holds the logits.
        patch_size: Edge length of the patches.
        stride_length, stride_width, stride_depth: Strides between patches;
            the overlap handed to the sampler is ``patch_size - stride``.
        batch_size: Number of patches per forward pass.
        num_worker: Number of worker processes of the patch ``DataLoader``.

    Returns:
        NumPy array of aggregated sigmoid probabilities, squeezed.

    Raises:
        ValueError: If a 5D input has a batch dimension other than 1.
    """
    # torchio images are 4D (channels + 3 spatial dims); reduce or expand the
    # input accordingly. 4D inputs pass through unchanged.
    if tensor.ndim == 5:
        if tensor.shape[0] != 1:
            raise ValueError(
                f"Patch-based inference expects a single volume, got a batch of {tensor.shape[0]}."
            )
        tensor = tensor[0]
    elif tensor.ndim == 3:
        tensor = tensor[None]

    test_subject = tio.Subject(img=tio.ScalarImage(tensor=tensor))
    overlap = np.subtract(patch_size, (stride_length, stride_width, stride_depth))

    grid_sampler = tio.inference.GridSampler(test_subject, patch_size, overlap)
    aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode="average")
    patch_loader = torch.utils.data.DataLoader(
        grid_sampler, batch_size=batch_size, shuffle=False, num_workers=num_worker
    )

    with torch.no_grad():
        for patches_batch in patch_loader:
            local_batch = patches_batch['img'][tio.DATA].float()
            local_batch = _scale_to_unit_max(local_batch)
            locations = patches_batch[tio.LOCATION]

            # Slice dimension goes in front of the in-plane ones for the
            # model and is moved back before aggregation.
            local_batch = torch.movedim(local_batch, -1, -3)
            output = _first_output(model(local_batch))
            output = torch.sigmoid(output).detach().cpu()
            output = torch.movedim(output, -3, -1).type(local_batch.type())
            aggregator.add_batch(output, locations)

    return aggregator.get_output_tensor().squeeze().numpy()


# Set page configuration
st.set_page_config(
    page_title="DS6 | Segmenting vessels in 3D MRA-ToF (ideally, 7T)",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)


@st.cache_resource
def load_model(model_name):
    """Load a pretrained model from the Hugging Face Hub in eval mode.

    Defined at module level so the cache entry is shared across reruns.
    Failures propagate as exceptions instead of returning ``None``, so a
    failed load is never stored in the cache.
    """
    # NOTE: trust_remote_code executes code shipped with the model repository;
    # model_name is only ever built from the fixed option list below.
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        use_auth_token=os.environ.get('HF_API_TOKEN'),
    )
    model.eval()
    return model


# Sidebar content
with st.sidebar:
    st.title("Segmenting vessels in the brain from a 3D Magnetic Resonance Angiograph, ideally acquired at 7T | DS6")
    st.markdown(SIDEBAR_TEXT)
    st.markdown("---")
    st.markdown("© 2024 Soumick Chatterjee")

# Main content
st.header("DS6, Deformation-Aware Semi-Supervised Learning: Application to Small Vessel Segmentation with Noisy Training Data")

# File uploader
uploaded_file = st.file_uploader(
    "Please upload a 3D NIfTI file (.nii or .nii.gz)",
    type=["nii", "nii.gz"]
)

# Model selection
model_options = ["SMILEUHURA_DS6_CamSVD_UNetMSS3D_wDeform"]
selected_model = st.selectbox("Select a pretrained model:", model_options)

# Mode selection; pre-select the mode that is labelled as the DS6 default.
mode_options = [MODE_FULL_VOLUME, MODE_PATCH_BASED]
selected_mode = st.selectbox(
    "Select the running mode:",
    mode_options,
    index=mode_options.index(MODE_PATCH_BASED),
)

# Process button
process_button = st.button("Process")

if uploaded_file is not None and process_button:
    try:
        # Write the upload to disk so nibabel can pick the reader from the
        # file extension; the directory is removed as soon as data is loaded.
        file_extension = '.nii.gz' if uploaded_file.name.lower().endswith('.gz') else '.nii'
        with tempfile.TemporaryDirectory() as tmp_dir:
            input_path = Path(tmp_dir) / f"input{file_extension}"
            input_path.write_bytes(uploaded_file.getvalue())

            # mmap=False: the array must stay valid after the file is gone.
            nifti_img = nib.load(str(input_path), mmap=False)
            data = nifti_img.get_fdata()
            affine = nifti_img.affine

        # Convert to PyTorch tensor
        tensor = torch.from_numpy(data).float()

        # Ensure it's 3D
        if tensor.ndim != 3:
            st.error("The uploaded NIfTI file is not a 3D volume. Please upload a valid 3D NIfTI file.")
        else:
            # Display input details
            st.success("File successfully uploaded and read.")
            st.write(f"Input tensor shape: `{tensor.shape}`")
            st.write(f"Selected pretrained model: `{selected_model}`")

            # Add batch and channel dimensions
            tensor = tensor.unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, H, W, D]

            # Construct the model name based on the selected model
            model_name = f"soumickmj/{selected_model}"

            # Load the pre-trained model from Hugging Face
            if os.environ.get('HF_API_TOKEN') is None:
                st.error("Hugging Face API token is not set. Please set the 'HF_API_TOKEN' environment variable.")
                st.stop()

            with st.spinner('Loading the pre-trained model...'):
                try:
                    model = load_model(model_name)
                except Exception as e:
                    st.error(f"Failed to load model: {e}")
                    st.stop()  # Stop the app if the model couldn't be loaded

            # Move model and tensor to CPU (ensure compatibility with Spaces)
            device = torch.device('cpu')
            model = model.to(device)
            tensor = tensor.to(device)

            # Process the tensor through the model
            with st.spinner('Processing the tensor through the model...'):
                if selected_mode == MODE_FULL_VOLUME:
                    st.info("Running full volume inference...")
                    output = infer_full_vol(tensor, model)
                else:
                    st.info("Running patch-based inference [Default for DS6]...")
                    output = infer_patch_based(tensor, model)

            st.success("Processing complete.")
            st.write(f"Output tensor shape: `{output.shape}`")

            # Binarise with Otsu's threshold; fall back to 0.5 when Otsu
            # cannot be computed (deliberate best-effort).
            try:
                thresh = threshold_otsu(output)
                output = output > thresh
            except Exception as error:
                print(error)
                output = output > 0.5  # exception only if input image seems to have just one color 1.0.
            output = output.astype('uint16')

            # Save the output as a NIfTI file and read it back for download;
            # the temporary directory leaves nothing behind on disk.
            output_img = nib.Nifti1Image(output, affine=affine)
            with tempfile.TemporaryDirectory() as tmp_dir:
                output_path = Path(tmp_dir) / 'segmentation_output.nii.gz'
                nib.save(output_img, str(output_path))
                output_data = output_path.read_bytes()

            # Download button for NIfTI file
            st.download_button(
                label="Download Segmentation Output",
                data=output_data,
                file_name='segmentation_output.nii.gz',
                mime='application/gzip'
            )

    except Exception as e:
        st.error(f"An error occurred: {e}")
elif uploaded_file is None:
    st.info("Awaiting file upload...")
elif not process_button:
    st.info("Click the 'Process' button to start processing.")