edureisMD's picture
first commit
73825ed
import spaces
import tempfile
import os
from pathlib import Path
import SimpleITK as sitk
import numpy as np
import nibabel as nib
from totalsegmentator.python_api import totalsegmentator
import gradio as gr
from segmap import seg_map
import logging
# Module-wide logging: configure the root logger once at import time at DEBUG
# level so the debug messages emitted below (e.g. in map_labels) are shown.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# Bundled example inputs (served from input_examples/); their segmentations
# are read from output_examples/ instead of running inference.
sample_files = [
    "ct1.nii.gz",
    "ct2.nii.gz",
    "ct3.nii.gz",
]
def map_labels(seg_array):
    """Split a label slice into per-class masks for ``gr.AnnotatedImage``.

    Args:
        seg_array: numpy array of label values (presumably integer class ids);
            the value 0 is treated as background and skipped.

    Returns:
        A list of ``(mask, name)`` tuples, one per non-zero value present in
        ``seg_array``, where ``mask`` is the boolean array ``seg_array == value``
        and ``name`` is ``seg_map[value]``. Lookup errors from ``seg_map`` for
        an unknown value propagate unchanged.
    """
    # Compute the distinct labels once and reuse them for logging and iteration.
    classes = np.unique(seg_array)
    logger.debug("unique segs: %d", len(classes))
    return [
        (seg_array == seg_class, seg_map[seg_class])
        for seg_class in classes
        if seg_class != 0
    ]
def sitk_to_numpy(img_sitk, norm=False):
    """Reorient a SimpleITK image to LPS and return its voxel array.

    Args:
        img_sitk: SimpleITK image to convert.
        norm: when True, min-max rescale the intensities to 0..255 and cast the
            result to ``np.uint8`` for display; when False the array is returned
            as produced by ``sitk.GetArrayFromImage``.

    Returns:
        The numpy voxel array of the reoriented image.
    """
    img_sitk = sitk.DICOMOrient(img_sitk, "LPS")
    img_np = sitk.GetArrayFromImage(img_sitk)
    if norm:
        # Work in float64 from the start: subtracting the minimum in the native
        # integer dtype can overflow (e.g. int16 data with a range > 32767).
        img_f = img_np.astype(np.float64)
        min_val, max_val = img_f.min(), img_f.max()
        span = max_val - min_val
        if span > 0:
            img_f -= min_val
            img_f /= span
        else:
            # Constant image: 0/0 would give NaN, which becomes garbage after
            # the uint8 cast. Render it as all zeros instead.
            img_f[...] = 0.0
        img_np = (img_f.clip(0, 1) * 255).astype(np.uint8)
    return img_np
def load_image(path, norm=False):
    """Read the image at *path* with SimpleITK and convert it via ``sitk_to_numpy``.

    ``norm`` is forwarded unchanged to ``sitk_to_numpy``.
    """
    return sitk_to_numpy(sitk.ReadImage(path), norm)
def show_img_seg(img_np, seg_np=None, slice_idx=50):
    """Return one slice of the image plus its label overlay for ``gr.AnnotatedImage``.

    Args:
        img_np: 3-D image array (sliced along its first axis), or a list whose
            last element is such an array (the ``gr.State`` contents). ``None``
            or an empty list means no image is loaded.
        seg_np: label volume with the same layout, or a list whose last element
            is one; ``None`` or an empty list means there is no segmentation yet.
        slice_idx: relative slice position in percent (the UI slider spans 1-100).

    Returns:
        ``None`` when no image is available. Otherwise ``(img_slice, labels)``,
        where ``labels`` is ``map_labels`` applied to the same slice of the
        segmentation, or ``[]`` when there is no segmentation.
    """
    if img_np is None or (isinstance(img_np, list) and not img_np):
        return None
    if isinstance(img_np, list):
        img_np = img_np[-1]

    n_slices = img_np.shape[0]
    # Map the percentage onto a valid index. Without the clamp, slice_idx=100
    # yields index n_slices, one past the last slice (IndexError).
    slice_pos = int(slice_idx * (n_slices / 100))
    slice_pos = min(max(slice_pos, 0), n_slices - 1)
    img_slice = img_np[slice_pos, :, :]

    if seg_np is None or (isinstance(seg_np, list) and not seg_np):
        return img_slice, []
    if isinstance(seg_np, list):
        seg_np = seg_np[-1]
    return img_slice, map_labels(seg_np[slice_pos, :, :])
def load_img_to_state(path, img_state, seg_state):
    """Reset both state lists and, when *path* is set, load that image into img_state.

    The image is loaded with ``norm=True``. Always returns
    ``(None, img_state, seg_state)``: the ``None`` clears the viewer, and the
    same (mutated) list objects are handed back as the new state values.
    """
    for state in (img_state, seg_state):
        state.clear()
    if not path:
        return None, img_state, seg_state
    img_state.append(load_image(path, norm=True))
    return None, img_state, seg_state
def _seg_filename(path):
    """Return ``<name>_seg.nii.gz`` for input *path*, with ``.nii.gz``/``.nii`` stripped from the name."""
    name = Path(path).name
    for suffix in (".nii.gz", ".nii"):
        if name.endswith(suffix):
            name = name[: -len(suffix)]
            break
    return f"{name}_seg.nii.gz"


def save_seg(seg, path):
    """Persist a segmentation for the input at *path* and return the file to offer for download.

    Args:
        seg: SimpleITK label image.
        path: path of the input image the segmentation belongs to.

    Returns:
        For bundled sample inputs, the path of the precomputed segmentation
        in ``output_examples/``; nothing is written. For any other input,
        the path of the newly written ``<name>_seg.nii.gz`` next to the input.
    """
    if Path(path).name in sample_files:
        return os.path.join("output_examples", _seg_filename(path))
    # Never write to ``path`` itself: that would overwrite the uploaded input
    # image and serve the segmentation under the input's own file name.
    out_path = str(Path(path).with_name(_seg_filename(path)))
    sitk.WriteImage(seg, out_path)
    return out_path
@spaces.GPU(duration=150)
def run_inference(path):
    """Run TotalSegmentator (``fast=True``) on the NIfTI file at *path*.

    The nibabel result is written to a temporary ``.nii.gz`` so that it can be
    read back with SimpleITK. The SimpleITK image is returned.
    """
    source_img = nib.load(path)
    result_img = totalsegmentator(source_img, fast=True)
    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch_file = os.path.join(scratch_dir, "totalseg_output.nii.gz")
        nib.save(result_img, scratch_file)
        return sitk.ReadImage(scratch_file)
def inference_wrapper(input_file, img_state, seg_state, slice_slider=50):
    """Segment the uploaded file and produce the outputs of the Generate button.

    Bundled sample files reuse their precomputed segmentation from
    ``output_examples/``. Any other file is passed to ``run_inference``.

    Args:
        input_file: value of the ``gr.File`` component, either a path string
            or an object exposing the path as ``.name``.
        img_state: ``gr.State`` list of image volumes; the last one is displayed.
        seg_state: ``gr.State`` list of label volumes; the last one is displayed.
        slice_slider: relative slice position (percent) to display.

    Returns:
        ``(viewer_value, seg_state, seg_path)``, matching the click handler's
        outputs ``[img_viewer, seg_state, download_seg]``.

    Raises:
        gr.Error: when no file has been uploaded.
    """
    if input_file is None:
        raise gr.Error("Please upload a CT or MR image first.")
    # Accept both plain path strings and file objects carrying the path in .name.
    input_path = getattr(input_file, "name", input_file)
    file_name = Path(input_path).name

    if file_name in sample_files:
        seg_sitk = sitk.ReadImage(
            os.path.join("output_examples", f"{Path(Path(file_name).stem).stem}_seg.nii.gz")
        )
    else:
        seg_sitk = run_inference(input_path)
    seg_path = save_seg(seg_sitk, input_path)

    # Only the latest segmentation is ever displayed, so drop older volumes
    # instead of accumulating one full volume per click.
    seg_state.clear()
    seg_state.append(sitk_to_numpy(seg_sitk))
    if not img_state:
        # Load the image normalized, exactly as load_img_to_state does.
        img_state.append(load_image(input_path, norm=True))
    return show_img_seg(img_state[-1], seg_state[-1], slice_slider), seg_state, seg_path
with gr.Blocks(title="TotalSegmentator") as interface:
    gr.Markdown("# TotalSegmentator: Segmentation of 117 Classes in CT and MR Images")
    gr.Markdown("""
- **GitHub:** https://github.com/wasserth/TotalSegmentator
- **Please Note:** This tool is intended for research purposes only and can segment 117 classes in CT/MRI images
- Supports both CT and MR imaging modalities
- Credit: adapted from `DiGuaQiu/MRSegmentator-Gradio`
""")
    # gr.State lists; the handlers always display the last element.
    img_state = gr.State([])
    seg_state = gr.State([])
    with gr.Accordion(label='Upload CT Scan (nifti file) then click on Generate Segmentation to run TotalSegmentator', open=True):
        with gr.Row():
            with gr.Column():
                file_input = gr.File(
                    type="filepath",
                    label="Upload a CT or MR Image (.nii/.nii.gz)",
                    # ".nii" is listed so the uncompressed files advertised in the
                    # label can be selected; ".gz" is kept, presumably for pickers
                    # that only match the final suffix of ".nii.gz".
                    file_types=[".nii", ".nii.gz", ".gz"],
                )
                gr.Examples(["input_examples/" + example for example in sample_files], file_input)
                with gr.Row():
                    infer_button = gr.Button("Generate Segmentations", variant="primary")
                    clear_button = gr.ClearButton()
            with gr.Column():
                slice_slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
                img_viewer = gr.AnnotatedImage(label="Image Viewer")
                download_seg = gr.File(label="Download Segmentation", interactive=False)
    # A new file resets both states; then immediately render the selected slice
    # so the viewer is not left blank until the slider is moved.
    file_input.change(
        load_img_to_state,
        inputs=[file_input, img_state, seg_state],
        outputs=[img_viewer, img_state, seg_state],
    ).then(
        show_img_seg,
        inputs=[img_state, seg_state, slice_slider],
        outputs=[img_viewer],
    )
    slice_slider.change(show_img_seg, inputs=[img_state, seg_state, slice_slider], outputs=[img_viewer])
    infer_button.click(
        inference_wrapper,
        inputs=[file_input, img_state, seg_state, slice_slider],
        outputs=[img_viewer, seg_state, download_seg],
    )
    clear_button.add([file_input, img_viewer, img_state, seg_state, download_seg])
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server in debug mode.
    interface.queue().launch(debug=True)