# Spaces status: Runtime error
import spaces | |
import tempfile | |
import os | |
from pathlib import Path | |
import SimpleITK as sitk | |
import numpy as np | |
import nibabel as nib | |
from totalsegmentator.python_api import totalsegmentator | |
import gradio as gr | |
from segmap import seg_map | |
import logging | |
# Logging configuration: DEBUG level, each record prefixed with
# timestamp, logger name and level.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# File names of the example scans offered in the UI (looked up under
# input_examples/); their segmentations are read from output_examples/
# instead of being computed.
sample_files = ["ct1.nii.gz", "ct2.nii.gz", "ct3.nii.gz"]
def map_labels(seg_array):
    """Split a label array into per-class boolean masks with display names.

    Args:
        seg_array: NumPy array of class ids. The value 0 is skipped
            (treated as background).

    Returns:
        A list of ``(mask, label)`` tuples, one per non-zero class id present
        in ``seg_array``, in ascending id order. ``mask`` is a boolean array
        with the shape of ``seg_array``; ``label`` is the ``seg_map`` entry
        for the id, or ``"class_<id>"`` when ``seg_map`` has no such entry.
    """
    # np.unique returns the sorted distinct values; compute it only once.
    class_ids = np.unique(seg_array)
    logger.debug("unique segs: %d", len(class_ids))

    labels = []
    for class_id in class_ids:
        if class_id == 0:
            continue
        try:
            name = seg_map[class_id]
        except (KeyError, IndexError):
            # An unmapped id should not break the viewer; show a generic name.
            name = f"class_{int(class_id)}"
        labels.append((seg_array == class_id, name))
    return labels
def sitk_to_numpy(img_sitk, norm=False):
    """Convert a SimpleITK image to a NumPy array in LPS orientation.

    Args:
        img_sitk: SimpleITK image to convert.
        norm: When true, min-max rescale the intensities to 0-255 and
            return a ``uint8`` array.

    Returns:
        The array obtained from the reoriented image. With ``norm`` it is
        ``uint8``; otherwise the dtype is whatever SimpleITK produces.
    """
    img_sitk = sitk.DICOMOrient(img_sitk, "LPS")
    img_np = sitk.GetArrayFromImage(img_sitk)
    if not norm:
        return img_np

    # Work in float64: subtracting the minimum in a narrow integer dtype
    # (e.g. int16) can wrap around when the intensity range is wide.
    img_np = img_np.astype(np.float64)
    min_val, max_val = np.min(img_np), np.max(img_np)
    value_range = max_val - min_val
    if value_range > 0:
        img_np = ((img_np - min_val) / value_range).clip(0, 1) * 255
    else:
        # Constant image: dividing by a zero range would yield NaNs.
        img_np = np.zeros_like(img_np)
    return img_np.astype(np.uint8)
def load_image(path, norm=False):
    """Read an image file with SimpleITK and return it as a NumPy array.

    Args:
        path: Location of the image file to read.
        norm: Forwarded to ``sitk_to_numpy``.

    Returns:
        The result of ``sitk_to_numpy`` for the image that was read.
    """
    return sitk_to_numpy(sitk.ReadImage(path), norm)
def show_img_seg(img_np, seg_np=None, slice_idx=50):
    """Pick one slice of the image (and segmentation) for the viewer.

    Args:
        img_np: Image volume, or a list of volumes of which the last is used.
        seg_np: Optional label volume, or a list of which the last is used.
        slice_idx: Relative slice position in percent of the first axis.

    Returns:
        ``None`` when no image is available; otherwise a tuple
        ``(img_slice, annotations)`` where ``annotations`` is the
        ``map_labels`` output for the same slice of the segmentation, or an
        empty list when there is no segmentation.
    """
    if isinstance(img_np, list):
        if not img_np:
            return None
        img_np = img_np[-1]
    if img_np is None:
        return None

    depth = img_np.shape[0]
    slice_pos = int(slice_idx * (depth / 100))
    # slice_idx == 100 maps to index `depth`, one past the end; clamp it.
    slice_pos = max(0, min(slice_pos, depth - 1))
    img_slice = img_np[slice_pos, :, :]

    if isinstance(seg_np, list):
        seg_np = seg_np[-1] if seg_np else None
    if seg_np is None:
        return img_slice, []
    return img_slice, map_labels(seg_np[slice_pos, :, :])
def load_img_to_state(path, img_state, seg_state):
    """Reset both state lists and load the image at ``path``, if any.

    Args:
        path: Image path; a falsy value only resets the states.
        img_state: List holding loaded image volumes; emptied in place.
        seg_state: List holding segmentation volumes; emptied in place.

    Returns:
        ``(None, img_state, seg_state)``. When ``path`` is truthy,
        ``img_state`` contains the normalised image loaded from it.
    """
    for state in (img_state, seg_state):
        state.clear()
    if path:
        img_state.append(load_image(path, norm=True))
    return None, img_state, seg_state
def save_seg(seg, path):
    """Return a downloadable path for the segmentation of ``path``.

    Args:
        seg: SimpleITK segmentation image.
        path: Path of the input image the segmentation belongs to.

    Returns:
        For a bundled sample file, the path of its segmentation under
        ``output_examples`` (nothing is written). Otherwise the path of a
        newly written ``<name>_seg.nii.gz`` next to the input file.
    """
    input_name = Path(path).name
    if input_name in sample_files:
        return os.path.join("output_examples", f"{Path(Path(path).stem).stem}_seg.nii.gz")

    # Strip the NIfTI extension so the output is named after the scan.
    base_name = input_name
    for suffix in (".nii.gz", ".nii"):
        if base_name.lower().endswith(suffix):
            base_name = base_name[: -len(suffix)]
            break

    # Write beside the input rather than onto it: writing to `path` itself
    # would overwrite the uploaded image with the segmentation.
    seg_path = str(Path(path).with_name(f"{base_name}_seg.nii.gz"))
    sitk.WriteImage(seg, seg_path)
    return seg_path
def run_inference(path):
    """Segment the NIfTI file at ``path`` with TotalSegmentator (fast mode).

    The nibabel result is saved to a temporary NIfTI file and read back with
    SimpleITK so the caller receives a SimpleITK image.

    Args:
        path: Path of the input image, loaded with nibabel.

    Returns:
        The segmentation as a SimpleITK image.
    """
    segmentation = totalsegmentator(nib.load(path), fast=True)
    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch_file = os.path.join(scratch_dir, "totalseg_output.nii.gz")
        nib.save(segmentation, scratch_file)
        return sitk.ReadImage(scratch_file)
def inference_wrapper(input_file, img_state, seg_state, slice_slider=50):
    """Produce a segmentation for the selected file and refresh the viewer.

    Args:
        input_file: Path of the selected file as a string (what
            ``gr.File(type="filepath")`` delivers), or an object exposing the
            path as ``.name``.
        img_state: List of loaded image volumes; filled if still empty.
        seg_state: List of segmentation volumes; the new one is appended.
        slice_slider: Relative slice position forwarded to ``show_img_seg``.

    Returns:
        ``(viewer_value, seg_state, seg_path)`` where ``viewer_value`` is the
        ``show_img_seg`` result and ``seg_path`` the file offered for download.

    Raises:
        gr.Error: If no file has been selected.
    """
    if not input_file:
        raise gr.Error("Please upload a NIfTI file before generating segmentations.")

    # The file component hands over a plain path string; calling `.name` on
    # it fails, so only fall back to `.name` for file-like objects.
    if isinstance(input_file, (str, os.PathLike)):
        input_path = os.fspath(input_file)
    else:
        input_path = input_file.name

    file_name = Path(input_path).name
    if file_name in sample_files:
        # Sample scans ship with a precomputed segmentation.
        seg_sitk = sitk.ReadImage(os.path.join("output_examples", f"{Path(Path(file_name).stem).stem}_seg.nii.gz"))
    else:
        seg_sitk = run_inference(input_path)

    # Load the image before anything is written to disk, normalised the same
    # way as in load_img_to_state so the viewer always gets uint8 data.
    if not img_state:
        img_sitk = sitk.ReadImage(input_path)
        img_state.append(sitk_to_numpy(img_sitk, norm=True))

    seg_path = save_seg(seg_sitk, input_path)
    seg_state.append(sitk_to_numpy(seg_sitk))
    return show_img_seg(img_state[-1], seg_state[-1], slice_slider), seg_state, seg_path
# UI definition: upload/example picker and action buttons on the left,
# slice slider, annotated viewer and download link on the right.
with gr.Blocks(title="TotalSegmentator") as interface:
    gr.Markdown("# TotalSegmentator: Segmentation of 117 Classes in CT and MR Images")
    gr.Markdown("""
- **GitHub:** https://github.com/wasserth/TotalSegmentator
- **Please Note:** This tool is intended for research purposes only and can segment 117 classes in CT/MRI images
- Supports both CT and MR imaging modalities
- Credit: adapted from `DiGuaQiu/MRSegmentator-Gradio`
""")

    # Per-session lists of loaded image / segmentation volumes.
    img_state = gr.State([])
    seg_state = gr.State([])

    with gr.Accordion(label='Upload CT Scan (nifti file) then click on Generate Segmentation to run TotalSegmentator', open=True):
        with gr.Row():
            with gr.Column():
                # ".nii" is listed explicitly so uncompressed NIfTI files,
                # which the label advertises, are accepted as well.
                file_input = gr.File(
                    type="filepath",
                    label="Upload a CT or MR Image (.nii/.nii.gz)",
                    file_types=[".nii", ".gz", ".nii.gz"],
                )
                gr.Examples(["input_examples/" + example for example in sample_files], file_input)
                with gr.Row():
                    infer_button = gr.Button("Generate Segmentations", variant="primary")
                    clear_button = gr.ClearButton()
            with gr.Column():
                slice_slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
                img_viewer = gr.AnnotatedImage(label="Image Viewer")
                download_seg = gr.File(label="Download Segmentation", interactive=False)

    # Selecting a file resets the states and loads the new image.
    file_input.change(
        load_img_to_state,
        inputs=[file_input, img_state, seg_state],
        outputs=[img_viewer, img_state, seg_state],
    )
    # Moving the slider redraws the viewer from the current states.
    slice_slider.change(show_img_seg, inputs=[img_state, seg_state, slice_slider], outputs=[img_viewer])
    infer_button.click(
        inference_wrapper,
        inputs=[file_input, img_state, seg_state, slice_slider],
        outputs=[img_viewer, seg_state, download_seg],
    )
    clear_button.add([file_input, img_viewer, img_state, seg_state, download_seg])
if __name__ == "__main__":
    # Enable the request queue, then start the app in debug mode.
    interface.queue().launch(debug=True)