|
import os |
|
import numpy as np |
|
import gradio as gr |
|
import torch |
|
import monai |
|
import morphsnakes as ms |
|
from utils.sliding_window import sw_inference |
|
from utils.tumor_features import generate_features |
|
from monai.networks.nets import SegResNetVAE |
|
from monai.transforms import ( |
|
LoadImage, Orientation, Compose, ToTensor, Activations, |
|
FillHoles, KeepLargestConnectedComponent, AsDiscrete, ScaleIntensityRange |
|
) |
|
|
|
|
|
|
|
# Folder holding this script; bundled examples and checkpoints are resolved
# relative to it so the app does not depend on the working directory.
THIS_DIR = os.path.dirname(os.path.abspath(__file__))

# Example CT volumes offered in the gr.Examples gallery.
examples_path = [
    os.path.join(THIS_DIR, 'examples', f'{case}.nrrd')
    for case in ('HCC_003', 'HCC_006', 'HCC_007', 'HCC_018')
]

# state_dict checkpoints loaded into the liver and tumor SegResNetVAE models.
models_path = {
    organ: os.path.join(THIS_DIR, 'checkpoints', checkpoint)
    for organ, checkpoint in (
        ("liver", 'liver_3DSegResNetVAE.pth'),
        ("tumor", 'tumor_3DSegResNetVAE_weak_morp.pth'),
    )
}

# File-name suffixes for cached masks; the segmentation functions save them
# as f"cache/{image_name}_{suffix}".
cache_path = {
    "liver mask": "liver_mask.npy",
    "tumor mask": "tumor_mask.npy",
}

# Torch device used when loading model checkpoints.
device = "cpu"

# Session state for the currently loaded volume, keyed by image name; each
# entry may hold the 'img', 'liver mask' and 'tumor mask' arrays.
mydict = {}
|
|
|
|
|
def render(image_name, x, selected_slice):
    """Return one axial slice of the loaded volume plus its mask overlays.

    Args:
        image_name: either the bare image key used in ``mydict`` or the
            uploaded file value from ``gr.File`` (an object exposing the path
            as ``.name``, or a plain path string); the key is derived from it.
        x: 1-based slice index from the slider (the UI advertises
            ``zmin: 1``). It is converted to a 0-based index and clamped to
            the depth of the volume.
        selected_slice: unused; kept so the Gradio event signature is stable.

    Returns:
        ``(image, annotations)`` for ``gr.AnnotatedImage``: the 2-D slice
        min-max scaled to [0, 1] and a list of ``(mask_slice, label)`` pairs.
        When no image has been loaded yet, a blank 512x512 placeholder with
        no annotations is returned in the same shape of tuple.
    """
    blank = (np.zeros((512, 512)), [])

    if image_name is None:
        return blank
    if not isinstance(image_name, str) or '/' in image_name:
        # Gradio file values carry the full path in ``.name``; plain path
        # strings are used directly. The key derivation matches load_image.
        path = getattr(image_name, 'name', image_name)
        image_name = str(path).split('/')[-1].replace(".nrrd", "")

    entry = mydict.get(image_name, {})
    if 'img' not in entry:
        return blank

    volume = entry['img']
    # Slider values are 1-based (and may arrive as floats); map them to a
    # valid 0-based index along the last (z) axis.
    z = 1 if x is None else int(x)
    z = min(max(z - 1, 0), volume.shape[-1] - 1)

    img = volume[:, :, z]
    lo, hi = np.min(img), np.max(img)
    # A constant slice (e.g. padding) would otherwise give 0/0 -> NaN.
    if hi > lo:
        img = (img - lo) / (hi - lo)
    else:
        img = np.zeros_like(img, dtype=float)

    annotations = []
    if 'liver mask' in entry:
        annotations.append((entry['liver mask'][:, :, z], "segmented liver"))
    if 'tumor mask' in entry:
        annotations.append((entry['tumor mask'][:, :, z], "segmented tumor"))

    return img, annotations
|
|
|
|
|
def load_liver_model():
    """Build the liver SegResNetVAE, load its checkpoint and set inference mode.

    Returns:
        A ``SegResNetVAE`` (2 output channels) whose weights come from
        ``models_path['liver']``, mapped onto ``device`` and switched to
        ``eval()`` so dropout is disabled during prediction.
    """
    liver_model = SegResNetVAE(
        # Same as the sliding-window ROI size used by segment_liver.
        input_image_size=(512, 512, 16),
        vae_estimate_std=False,
        vae_default_std=0.3,
        vae_nz=256,
        spatial_dims=3,
        blocks_down=[1, 2, 2, 4],
        blocks_up=[1, 1, 1],
        init_filters=16,
        in_channels=1,
        norm='instance',
        out_channels=2,
        dropout_prob=0.2,
    )

    # Local, trusted checkpoint containing a state_dict.
    state_dict = torch.load(models_path['liver'], map_location=torch.device(device))
    liver_model.load_state_dict(state_dict)

    # Inference only: turn off dropout (p=0.2 above) and SegResNetVAE's
    # training-only VAE-loss branch so predictions are deterministic.
    liver_model.eval()

    return liver_model
|
|
|
|
|
def load_tumor_model():
    """Build the tumor SegResNetVAE, load its checkpoint and set inference mode.

    Returns:
        A ``SegResNetVAE`` (3 output channels) whose weights come from
        ``models_path['tumor']``, mapped onto ``device`` (consistent with
        ``load_liver_model``) and switched to ``eval()`` so dropout is
        disabled during prediction.
    """
    tumor_model = SegResNetVAE(
        # Same as the sliding-window ROI size used by segment_tumor.
        input_image_size=(256, 256, 32),
        vae_estimate_std=False,
        vae_default_std=0.3,
        vae_nz=256,
        spatial_dims=3,
        blocks_down=[1, 2, 2, 4],
        blocks_up=[1, 1, 1],
        init_filters=16,
        in_channels=1,
        norm='instance',
        out_channels=3,
        dropout_prob=0.2,
    )

    # Local, trusted checkpoint containing a state_dict. Use the shared
    # module-level device rather than a hard-coded 'cpu'.
    state_dict = torch.load(models_path['tumor'], map_location=torch.device(device))
    tumor_model.load_state_dict(state_dict)

    # Inference only: turn off dropout (p=0.2 above) and SegResNetVAE's
    # training-only VAE-loss branch so predictions are deterministic.
    tumor_model.eval()

    return tumor_model
|
|
|
|
|
def load_image(image, slider, selected_slice):
    """Read an uploaded NRRD volume into ``mydict`` and render the viewer.

    ``mydict`` is replaced so that it holds only the newly loaded image.

    Args:
        image: the ``gr.File`` value (path exposed as ``.name``, or a plain
            path string), or ``None`` when nothing has been uploaded.
        slider: current slider value, forwarded to ``render``.
        selected_slice: forwarded to ``render``.

    Returns:
        ``(message, viewer_value)`` for the guide textbox and the
        ``gr.AnnotatedImage``; ``viewer_value`` is ``None`` when no file was
        provided.
    """
    global mydict

    if image is None:
        return "Please choose a .nrrd file in Step 1 before clicking Upload.", None

    path = getattr(image, 'name', image)
    image_name = path.split('/')[-1].replace(".nrrd", "")

    preprocessing_liver = Compose([
        LoadImage(reader="NrrdReader", ensure_channel_first=True),
        # Reorient to PLI so the last axis runs inferior/superior; render()
        # slices along that last axis as z.
        Orientation(axcodes="PLI"),
        ToTensor(),
    ])

    volume = preprocessing_liver(path)
    # Only reset the session state once the read has succeeded.
    mydict = {image_name: {"img": volume[0].numpy()}}

    print("Loaded image", image_name)

    view_image, annotations = render(image_name, slider, selected_slice)

    return (
        "Your image is successfully loaded! Please use the slider to view the image "
        f"(zmin: 1, zmax: {volume.shape[-1]}).",
        (view_image, annotations),
    )
|
|
|
|
|
def segment_tumor(image_name):
    """Predict the tumor mask for ``image_name`` and store it in ``mydict``.

    The mask is cached as ``cache/<image_name>_tumor_mask.npy`` and reused on
    later calls. The tumor model only sees voxels inside the liver mask, so
    the liver is segmented first when its mask is not available yet.

    Args:
        image_name: key of a volume previously loaded into ``mydict`` by
            ``load_image``.
    """
    entry = mydict[image_name]
    # NOTE(review): the cache is keyed by file name only; a different volume
    # uploaded under the same name will reuse a stale mask.
    cache_file = f"cache/{image_name}_{cache_path['tumor mask']}"

    # Already computed in this session.
    if entry.get('tumor mask') is not None:
        return
    # Computed in an earlier session.
    if os.path.isfile(cache_file):
        entry['tumor mask'] = np.load(cache_file)
        return

    # The tumor input is masked by the liver; compute it on demand so that
    # selecting only "tumor mask" does not fail with a KeyError.
    if entry.get('liver mask') is None:
        segment_liver(image_name)
    liver_mask = entry['liver mask']

    tumor_model = load_tumor_model()

    preprocessing_tumor = Compose([
        # Clip to the [-200, 250] intensity window (presumably HU) and
        # rescale to [0, 1].
        ScaleIntensityRange(a_min=-200, a_max=250, b_min=0.0, b_max=1.0, clip=True)
    ])

    postprocessing_tumor = Compose([
        Activations(sigmoid=True),
        AsDiscrete(argmax=True, to_onehot=3),
        KeepLargestConnectedComponent(applied_labels=[2]),
        FillHoles(applied_labels=[2]),
        ToTensor()
    ])

    volume = preprocessing_tumor(torch.from_numpy(entry['img']))
    volume = torch.multiply(volume, torch.from_numpy(liver_mask))

    with torch.no_grad():
        logits = sw_inference(tumor_model, volume[None, None, :], (256, 256, 32), False,
                              discard_second_output=True, overlap=0.2)[0]

    # Last one-hot channel (label 2) is the tumor.
    tumor_mask = postprocessing_tumor(logits)[-1].numpy()
    # Refine the boundary with a short morphological Chan-Vese pass seeded by
    # the network prediction, then keep only voxels inside the liver.
    tumor_mask = ms.morphological_chan_vese(tumor_mask, iterations=2, init_level_set=tumor_mask)
    tumor_mask = np.multiply(tumor_mask, liver_mask)
    entry['tumor mask'] = tumor_mask

    # np.save does not create missing directories.
    os.makedirs("cache", exist_ok=True)
    np.save(cache_file, tumor_mask)
    print(f"tumor mask saved to {cache_file}")
|
|
|
|
|
def segment_liver(image_name):
    """Predict the liver mask for ``image_name`` and store it in ``mydict``.

    The mask is cached as ``cache/<image_name>_liver_mask.npy`` and reused on
    later calls.

    Args:
        image_name: key of a volume previously loaded into ``mydict`` by
            ``load_image``.
    """
    entry = mydict[image_name]
    # NOTE(review): the cache is keyed by file name only; a different volume
    # uploaded under the same name will reuse a stale mask.
    cache_file = f"cache/{image_name}_{cache_path['liver mask']}"

    # Already computed in this session.
    if entry.get('liver mask') is not None:
        return
    # Computed in an earlier session.
    if os.path.isfile(cache_file):
        entry['liver mask'] = np.load(cache_file)
        return

    liver_model = load_liver_model()

    preprocessing_liver = Compose([
        # Clip to the [-150, 250] intensity window (presumably HU) and
        # rescale to [0, 1].
        ScaleIntensityRange(a_min=-150, a_max=250, b_min=0.0, b_max=1.0, clip=True)
    ])

    postprocessing_liver = Compose([
        Activations(sigmoid=True),
        # argmax collapses the 2 output channels into one label map (1 = liver).
        AsDiscrete(argmax=True, to_onehot=None),
        KeepLargestConnectedComponent(applied_labels=[1]),
        FillHoles(applied_labels=[1]),
        ToTensor()
    ])

    volume = preprocessing_liver(torch.from_numpy(entry['img']))

    with torch.no_grad():
        logits = sw_inference(liver_model, volume[None, None, :], (512, 512, 16), False,
                              discard_second_output=True, overlap=0.2)[0]

    liver_mask = postprocessing_liver(logits)[0].numpy()
    entry['liver mask'] = liver_mask
    print(f"liver mask shape: {liver_mask.shape}")

    # np.save does not create missing directories.
    os.makedirs("cache", exist_ok=True)
    np.save(cache_file, liver_mask)
    print(f"liver mask saved to {cache_file}")
|
|
|
|
|
def segment(image, selected_mask, slider, selected_slice):
    """Run the selected segmentations and refresh the viewer and downloads.

    Args:
        image: the ``gr.File`` value of the uploaded volume (path exposed as
            ``.name``, or a plain path string), or ``None``.
        selected_mask: labels chosen in the checkbox group; may contain
            ``'liver mask'`` and/or ``'tumor mask'``.
        slider: current slider value, forwarded to ``render``.
        selected_slice: forwarded to ``render``.

    Returns:
        ``(message, liver_download_button, tumor_download_button, viewer)``.
        A download button is only made visible for a mask that was produced;
        ``viewer`` is ``None`` when no image is available.
    """
    download_liver = gr.DownloadButton(label="Download liver mask", visible=False)
    download_tumor = gr.DownloadButton(label="Download tumor mask", visible=False)

    if image is None:
        return ("Please upload a CT image (Step 1) before segmenting.",
                download_liver, download_tumor, None)

    image_name = getattr(image, 'name', image).split('/')[-1].replace(".nrrd", "")
    if image_name not in mydict:
        # The file was selected but "Upload" was never clicked.
        return ("Please click Upload to load this image before segmenting.",
                download_liver, download_tumor, None)

    selected_mask = selected_mask or []

    if 'liver mask' in selected_mask:
        print('Segmenting liver...')
        segment_liver(image_name)
        download_liver = gr.DownloadButton(
            label="Download liver mask",
            value=f"cache/{image_name}_{cache_path['liver mask']}",
            visible=True,
        )

    if 'tumor mask' in selected_mask:
        print('Segmenting tumor...')
        segment_tumor(image_name)
        download_tumor = gr.DownloadButton(
            label="Download tumor mask",
            value=f"cache/{image_name}_{cache_path['tumor mask']}",
            visible=True,
        )

    # Pass the derived key so render does not have to re-parse the file value.
    view_image, annotations = render(image_name, slider, selected_slice)

    return "Segmentation is completed! ", download_liver, download_tumor, (view_image, annotations)
|
|
|
|
|
def generate_summary(image=None):
    """Compute tumor features for the loaded image (currently only printed).

    Args:
        image: the ``gr.File`` value of the uploaded volume (path exposed as
            ``.name``, or a plain path string). Optional so the report button
            also works when wired without inputs; then the single image held
            in ``mydict`` is used, since ``load_image`` keeps only one.

    Returns:
        Text for the report box. If the image or its masks are missing, this
        is a message explaining what to do first. Otherwise it is ``""``,
        because the features are only printed for now.
    """
    if image is not None:
        image_name = getattr(image, 'name', image).split('/')[-1].replace(".nrrd", "")
    elif len(mydict) == 1:
        image_name = next(iter(mydict))
    else:
        return "Please upload and segment a CT image before generating a summary."

    entry = mydict.get(image_name, {})
    if "img" not in entry:
        return "Please upload and segment a CT image before generating a summary."
    if "liver mask" not in entry or "tumor mask" not in entry:
        return "Please segment both the liver and the tumor before generating a summary."

    features = generate_features(entry["img"], entry["liver mask"], entry["tumor mask"])
    print(features)

    # TODO: turn `features` into the AI-generated report shown in the UI.
    return ""
|
|
|
|
|
# Gradio UI. Workflow: upload a volume, segment liver/tumor, browse slices,
# download masks, and request a summary.
with gr.Blocks() as app:
    with gr.Column():
        # The models, checkpoints and HCC-TACE-Seg dataset are liver (HCC),
        # so the user-facing text refers to the liver, not the lung.
        gr.Markdown(
            """
            # Liver Tumor Segmentation App

            This tool is designed to assist in the identification and segmentation of the liver and liver tumors from medical images. By uploading a CT scan image, a pre-trained machine learning model will automatically segment the liver and tumor regions. The segmented tumor's characteristics such as shape, size, and location are then analyzed to produce an AI-generated diagnosis report of the liver cancer.

            ⚠️ Important disclaimer: these model outputs should NOT replace the medical diagnosis of healthcare professionals. For your reference, our model was trained on the [HCC-TACE-Seg dataset](https://www.cancerimagingarchive.net/collection/hcc-tace-seg/) and achieved a 0.954 Dice score for liver segmentation and a 0.570 Dice score for tumor segmentation. Improving tumor segmentation is still an active area of research!
            """)

        with gr.Row():
            comment = gr.Textbox(label='Your tool guide:', value="👋 Hi there, welcome to explore the power of AI for automated medical image analysis with our user-friendly app! Start by uploading a CT scan image. Note that for now we accept .nrrd formats only.")

        with gr.Row():
            with gr.Column(scale=2):
                image_file = gr.File(label="Step 1: Upload a CT image (.nrrd)", file_count='single', file_types=['.nrrd'], type='filepath')
                btn_upload = gr.Button("Upload")

            with gr.Column(scale=2):
                selected_mask = gr.CheckboxGroup(label='Step 2: Select mask to produce', choices=['liver mask', 'tumor mask'], value=['liver mask'])
                btn_segment = gr.Button("Segment")

        with gr.Row():
            slider = gr.Slider(1, 100, step=1, label="Slice (z)")
            selected_slice = gr.State(value=1)

        with gr.Row():
            myimage = gr.AnnotatedImage(label="Image Viewer", height=1000, width=1000, color_map={"segmented liver": "#0373fc", "segmented tumor": "#eb5334"})

        with gr.Row():
            with gr.Column(scale=2):
                btn_download_liver = gr.DownloadButton("Download liver mask", visible=False)
            with gr.Column(scale=2):
                btn_download_tumor = gr.DownloadButton("Download tumor mask", visible=False)

        with gr.Row():
            report = gr.Textbox(label='Step 3: Generate summary report using AI:')

        with gr.Row():
            btn_report = gr.Button("Generate summary")

    gr.Examples(
        examples_path,
        [image_file],
    )

    btn_upload.click(fn=load_image,
                     inputs=[image_file, slider, selected_slice],
                     outputs=[comment, myimage],
                     )

    btn_segment.click(fn=segment,
                      inputs=[image_file, selected_mask, slider, selected_slice],
                      outputs=[comment, btn_download_liver, btn_download_tumor, myimage],
                      )

    slider.change(
        render,
        inputs=[image_file, slider, selected_slice],
        outputs=[myimage]
    )

    # generate_summary takes the uploaded file; without inputs it would be
    # called with no arguments.
    btn_report.click(fn=generate_summary,
                     inputs=[image_file],
                     outputs=report
                     )

app.launch()
|
|
|
|
|
|
|
|
|
|