import os |
|
import sys |
|
|
|
this_file_dir = os.path.dirname(os.path.abspath(__file__)) |
|
sys.path.append(os.path.join(this_file_dir, "../ct_seg")) |
|
import json |
|
import warnings |
|
import PIL |
|
from PIL import Image |
|
from typing import Any, Callable, Dict, List, Optional, Tuple |
|
|
|
import monai |
|
import cv2 |
|
import math |
|
import gradio as gr |
|
import torch |
|
import argparse |
|
import imageio |
|
import numpy as np |
|
import scipy |
|
|
|
from torchvision import transforms |
|
from models import dinov2_vitl_transunet |
|
from class_dict import class_dict, dataset_class |
|
from transforms import _MEAN, _STD |
|
from monai import transforms as monai_transforms |
|
from scipy.ndimage import label |
|
|
|
id2label = {v: k for k, v in class_dict.items()} |
|
np.random.seed(0) |
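# The fixed seed above makes the random overlay colors below reproducible across runs.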
|
id2color = {k: list(np.random.choice(range(256), size=3)) for k,v in id2label.items()} |
|
|
|
|
|
def clean_mask(X): |
|
""" |
|
    Cleans the mask by keeping only the largest connected component for each foreground label of interest (1, 2, and 10).
|
|
|
Parameters: |
|
        X (numpy.ndarray): Volumetric mask of shape [N, 1, W, H] (or [N, W, H]) with 0 for background and integer class labels.
|
|
|
Returns: |
|
numpy.ndarray: Cleaned volumetric mask with the same shape as X. |
|
""" |
|
|
|
if X.ndim == 4: |
|
volume = X[:, 0, :, :] |
|
else: |
|
volume = X |
|
|
|
for label_value in [1, 2, 10]: |
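        # Keep only the largest 3D connected component for this label and clear stray islands.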
|
|
|
mask = (volume == label_value) |
|
if not np.any(mask): |
|
continue |
|
|
|
|
|
structure = np.ones((3, 3, 3), dtype=int) |
|
|
|
|
|
labeled_mask, num_features = label(mask, structure=structure) |
|
if num_features == 0: |
|
continue |
|
|
|
|
|
component_sizes = np.bincount(labeled_mask.ravel()) |
|
component_sizes[0] = 0 |
|
|
|
|
|
largest_component_label = component_sizes.argmax() |
|
|
|
|
|
largest_component_mask = (labeled_mask == largest_component_label) |
|
|
|
|
|
volume[mask] = 0 |
|
volume[largest_component_mask] = label_value |
|
|
|
|
|
if X.ndim == 4: |
|
X[:, 0, :, :] = volume |
|
else: |
|
X = volume |
|
return X |
|
|
|
|
|
def parse_option(): |
|
    parser = argparse.ArgumentParser('CT Segmentation Demo', add_help=False)
|
parser.add_argument('--model_path', default="ckpt/model_19.pth", metavar="FILE", help='path to model file') |
|
|
|
cfg = parser.parse_args() |
|
return cfg |
|
|
|
''' |
|
build args |
|
''' |
|
cfg = parse_option() |
|
|
|
pretrained_pth = cfg.model_path |
|
|
|
def load_tif_images(file_path): |
|
vol = imageio.imread(file_path) |
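    # Volumes stored with intensities in [0, 1] are rescaled to [0, 255] for later uint8 processing.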
|
if np.max(vol) <= 1: |
|
vol = vol * 255 |
|
return vol |
|
|
|
def overlay_image_with_mask(image, segmentation_map, path='test.png', ax=None): |
|
color_seg = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) |
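    # Paint each predicted class with its assigned color, then alpha-blend it 50/50 with the input slice.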
|
for label, color in id2color.items(): |
|
color_seg[segmentation_map == label, :] = color |
|
|
|
|
|
img = np.array(image) * 0.5 + color_seg * 0.5 |
|
img = img.astype(np.uint8) |
|
return img |
|
|
|
def resize_volume(vol, size, max_frames, nearest_neighbor=False): |
|
W, H, F = vol.shape |
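    # Rescale isotropically so the first in-plane dimension becomes `size`; order=0 (nearest) is used
    # when nearest_neighbor=True so label values are preserved.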
|
|
|
zoom_rate = size / W |
|
vol_reshape = scipy.ndimage.zoom( |
|
vol, (zoom_rate, zoom_rate, zoom_rate), order=3 if not nearest_neighbor else 0 |
|
) |
|
resizeW, resizeH, resizeF = vol_reshape.shape |
|
if resizeF > max_frames: |
|
vol_reshape = vol_reshape[:, :, :max_frames] |
|
resizeF = max_frames |
|
else: |
|
resized_max_fr = int(math.ceil(max_frames * zoom_rate)) |
|
vol_reshape = np.concatenate([vol_reshape, np.zeros((resizeW, resizeH, resized_max_fr - resizeF))], axis=-1) |
|
return vol_reshape, resizeF, zoom_rate |
|
|
|
val_transform = monai_transforms.Compose([monai_transforms.Resized(keys=['image'], spatial_size=(256, 256), mode=['bilinear'])]) |
|
def process_volume(vol: np.ndarray, keep_frames: Callable=lambda x: x > 0.025): |
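    # Preprocess a raw CT volume: crop to the foreground, drop near-empty slices,
    # histogram-equalize, pad/crop to 512x512, then resize to the 256x256 model input.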
|
initial_resize = monai.transforms.ResizeWithPadOrCrop((512, 512)) |
|
    transform = monai.transforms.CropForeground(return_coords=True)  # array-based transform; keys/source_key belong to the dict-based CropForegroundd
|
crop_vol, start_coords, end_coords = transform(vol) |
|
    # Drop near-empty slices whose mean intensity fails the `keep_frames` predicate.
    keep_idx = np.where(keep_frames(np.mean(np.mean(crop_vol, axis=-1), axis=-1)))[0]

    crop_vol = crop_vol[keep_idx]
|
    # Axis 0 is the frame/slice axis after cropping; equalize the histogram across the whole volume.
    F, H, W = crop_vol.shape

    proc_vol = cv2.equalizeHist(crop_vol.reshape(F, -1).astype(np.uint8)).reshape(F, H, W)
|
proc_vol = initial_resize(proc_vol).detach().cpu().numpy().transpose((1, 2, 0)) |
|
proc_vol, max_fr = resize_volume(proc_vol, 256, max_frames=512)[:2] |
|
|
|
images = [] |
|
for i in range(proc_vol.shape[2]): |
|
image = torch.from_numpy(proc_vol[:, :, i]).unsqueeze(0) |
|
image_transformed = val_transform({"image": image})["image"] |
|
images.append(image_transformed) |
|
images = torch.stack(images) |
|
if images.max() > 1: |
|
images = images / 255.0 |
|
|
|
images = images.repeat(1, 3, 1, 1) |
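    # Replicate the grayscale channel to RGB and normalize with the training mean/std (_MEAN/_STD).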
|
for c in range(len(_MEAN)): |
|
images[:, c, :, :] = (images[:, c, :, :] - _MEAN[c]) / _STD[c] |
|
return images, max_fr |
|
|
|
def untransform(img): |
|
for c in range(len(_MEAN)): |
|
img[c] = img[c] * _STD[c] + _MEAN[c] |
|
if img.max() <= 1: |
|
img = img * 255 |
|
return img.long() |
|
|
|
def process_ct(ct_path: str): |
|
vol = load_tif_images(ct_path) |
|
images, frame_indices = process_volume(vol, keep_frames=lambda x: x > 0.025) |
|
return images, frame_indices |
|
|
|
|
|
examples = [["demo/CTseg_57_raw.tif"], |
|
["demo/CTrec-don_1101.tif"]] |
|
|
|
''' |
|
build model |
|
''' |
|
class_names = dataset_class["uwseg"] |
|
class_ids = [class_dict[class_name] for class_name in class_names] |
|
model = dinov2_vitl_transunet(pretrained="", num_classes=len(class_dict), img_size=256) |
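# Per its name, presumably a DINOv2 ViT-L backbone with a TransUNet-style decoder; the checkpoint
# given via --model_path is loaded below.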
|
state_dict = torch.load(pretrained_pth) |
|
model.load_state_dict(state_dict) |
|
model = model.cuda() |
|
|
|
@torch.no_grad() |
|
def inference(image_input): |
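    # `image_input` may be a plain file path (examples) or a Gradio file object (uploads).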
|
if isinstance(image_input, str): |
|
|
|
file_path = image_input |
|
else: |
|
|
|
file_path = image_input.name |
|
images, frame_indices = process_ct(file_path) |
|
with torch.no_grad(): |
|
with torch.cuda.amp.autocast(dtype=torch.float16): |
|
logits = model(images.cuda()) |
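            # Suppress logits of classes outside the uwseg label set so argmax only picks allowed classes.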
|
for j in range(len(class_dict)): |
|
if j + 1 not in class_ids: |
|
logits[:, j] = -1000 |
|
pred = torch.argmax(logits, dim=1) + 1 |
|
pred_mask = (torch.max(logits, dim=1)[0] > 0) |
|
pred = pred_mask * pred |
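    # Zero out predictions on the zero-padded frames beyond the real slice count, then drop stray components.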
|
pred[frame_indices:] = 0 |
|
pred = torch.from_numpy(clean_mask(pred.cpu().numpy())) |
|
volume_size = torch.sum(pred==2).item() |
|
|
|
volume_size = volume_size / 1000 |
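    # Rough voxel-count-to-cm^3 conversion; this assumes roughly 1 mm^3 voxels after resampling,
    # so the reported heart volume should be treated as approximate.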
|
|
|
|
|
sizes = pred.view(pred.shape[0], -1).sum(dim=1).cpu().numpy() |
|
|
|
segmentation_results = [] |
|
raw_images = [] |
|
for i in range(len(images)): |
|
images[i] = untransform(images[i]) |
|
raw_image = Image.fromarray(images[i].cpu().permute(1, 2, 0).numpy().astype(np.uint8)) |
|
raw_images.append(raw_image) |
|
image_with_mask = overlay_image_with_mask(images[i].cpu().permute(1, 2, 0).numpy(), pred[i].squeeze(0).cpu().numpy()) |
|
image_with_mask = Image.fromarray(image_with_mask) |
|
segmentation_results.append(image_with_mask) |
|
initial_slice_index = 0 |
|
output_seg = segmentation_results[initial_slice_index] |
|
output_raw = raw_images[initial_slice_index] |
|
num_slices = len(segmentation_results) |
|
initial_size = sizes[initial_slice_index] |
|
return output_seg, output_raw, segmentation_results, raw_images, gr.update(maximum=num_slices - 1), sizes, f"Heart volume size: {volume_size} cm^3" |
|
|
|
def update_slice(slice_index, segmentation_results_state, raw_images_state, sizes_text): |
|
segmentation_results = segmentation_results_state |
|
raw_images = raw_images_state |
|
|
|
if segmentation_results is None or raw_images is None: |
|
return None, None, "" |
|
output_seg = segmentation_results[slice_index] |
|
output_raw = raw_images[slice_index] |
|
|
|
    return output_seg, output_raw, sizes_text
|
|
|
def load_example(example): |
|
image_file_path = example |
|
return inference(image_file_path) |
|
|
|
title = "CT Segmentation" |
|
description = """ |
|
|
|
<div style="text-align: left; font-weight: bold;"> |
|
<br> |
|
🌪 Note: The current model is run on <span style="color:blue;">CT Segmentation (UW)</span>
|
</div> |
|
""" |
|
|
|
article = "The Demo is Run on CT-Seg." |
|
with gr.Blocks(theme=gr.themes.Soft(), title=title, css=".gradio-container { max-width: 1000px; margin: auto; }") as demo: |
|
|
|
with gr.Row(): |
|
gr.Markdown(value="# <span style='color: #6366f1;'>UW CT segmentation</span>", elem_id="title") |
|
with gr.Row(): |
|
with gr.Column(scale=2): |
|
gr.Markdown(value=""" |
|
            Welcome to CT Segmentation, an AI model that segments out the thorax and heart and computes their volumes.
|
|
|
## How to Use: |
|
0. **Explore Default Examples**: Click on images in the right panel. |
|
1. **Upload Your Image**: something biomedical... but not your lovely pet! |
|
|
|
Click **Segment** and see what CT Seg finds for you! |
|
""", |
|
elem_id="instructions") |
|
gr.Markdown("## Step 1: Upload CT volume .tif image (Try examples on the right panel)") |
|
with gr.Row(equal_height = True): |
|
input_image = gr.File(label="Input Image", file_types=[".tif"]) |
|
|
|
slice_index_slider = gr.Slider(minimum=0, maximum=0, step=1, label="Slice Index") |
|
with gr.Row(equal_height = True): |
|
output_raw = gr.Image(label="Processed Image", interactive=False) |
|
output_seg = gr.Image(label="Segmentation Results", interactive=False) |
|
with gr.Row(): |
|
size_text = gr.Textbox(label="Heart volume Size", interactive=False) |
|
with gr.Row(): |
|
button = gr.Button("Segment", interactive=True, variant='primary') |
|
with gr.Column(scale=0.5): |
|
gr.Markdown("## Click Default Examples") |
|
|
|
segmentation_results_state = gr.State() |
|
raw_images_state = gr.State() |
|
sizes_state = gr.State() |
|
gr.Examples( |
|
examples=examples, |
|
inputs=[input_image], |
|
outputs=[output_seg, output_raw, segmentation_results_state, raw_images_state, slice_index_slider, sizes_state, size_text], |
|
fn=load_example, |
|
cache_examples=False, |
|
examples_per_page=1, |
|
run_on_click=True |
|
) |
|
|
|
button.click( |
|
fn=inference, |
|
inputs=[input_image], |
|
outputs=[output_seg, output_raw, segmentation_results_state, raw_images_state, slice_index_slider, sizes_state, size_text] |
|
) |
|
|
|
slice_index_slider.change( |
|
fn=update_slice, |
|
inputs=[slice_index_slider, segmentation_results_state, raw_images_state, size_text], |
|
outputs=[output_seg, output_raw, size_text] |
|
) |
|
|
|
if __name__ == "__main__": |
|
demo.queue().launch(share=True) |
|
|