import torch
import torchvision
import numpy as np
import PIL

from functools import partial

from torchvision.utils import save_image

from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.inferers import sliding_window_inference

import gradio as gr

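
# 3D UNet for spleen segmentation: one input channel, two output channels
# (background / spleen) - the same configuration as the MONAI spleen
# segmentation tutorial.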
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
)

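# Load the pre-trained weights; map_location keeps the model on the CPU.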
model.load_state_dict(torch.load("weights/model.pt", map_location=torch.device('cpu')))

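
# Callbacks for the Gradio UI: load_image() fills the example thumbnails and
# predict() runs inference on the selected sample.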
def load_image(index):
    # Return the selected sample index together with the file paths of the
    # pre-rendered input and label thumbnails for that sample.
    return [index, f"thumbnails/val_image{index}.png", f"thumbnails_label/val_label{index}.png"]

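
# Run sliding-window inference on the selected validation volume and return one
# slice of the predicted segmentation as a PIL image.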
def predict(index):
    # Load the pre-processed validation volume for the selected example.
    val_data = torch.load(f"samples/val_data{index}.pt")

    model.eval()
    with torch.no_grad():
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(val_data, roi_size, sw_batch_size, model)

    # Per-voxel class labels; keep a single 2D slice (index 80 along the last
    # spatial axis) for display.
    meta_tsr = torch.argmax(val_outputs, dim=1)[0, :, :, 80]
    pil_image = torchvision.transforms.functional.to_pil_image(meta_tsr.to(torch.float32))

    return pil_image


with gr.Blocks(title="Spleen 3D segmentation with MONAI - ClassCat",
               css=".gradio-container {background:azure;}"
               ) as demo:
    # Index of the example currently selected via the "Example n" buttons.
    sample_index = gr.State([])

    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">Spleen 3D segmentation with MONAI</div>""")

    gr.HTML("""<h4 style="color:navy;">1. Select an example, which includes input images and label images, by clicking an "Example x" button.</h4>""")

    with gr.Row():
        input_image = gr.Image(label="a piece of input image data", type="filepath")
        label_image = gr.Image(label="label image", type="filepath")
        output_image = gr.Image(label="predicted image", type="pil")

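    # One selector column per example: an "Example n" button plus a thumbnail of
    # the corresponding validation image.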
    example_buttons = []
    with gr.Row():
        for i in range(9):
            with gr.Column():
                btn = gr.Button(f"Example {i + 1}")
                btn.style(full_width=False)
                thumb = gr.Image(value=f"thumbnails/val_image{i}.png",
                                 interactive=False, label=f"ex {i + 1}")
                thumb.style(width=128, height=128)
                example_buttons.append(btn)

    # Clicking an example button stores its index in sample_index and shows the
    # corresponding input and label thumbnails.
    for i, btn in enumerate(example_buttons):
        btn.click(fn=partial(load_image, i),
                  outputs=[sample_index, input_image, label_image])

gr.HTML("""<br/>""") |
|
gr.HTML("""<h4 style="color:navy;">2. Then, click "Infer" button to predict a segmentation image. It will take about 15 seconds (on cpu)</h4>""") |
|
|
|
send_btn = gr.Button("Infer") |
|
send_btn.click(fn=predict, inputs=[sample_index], outputs=[output_image]) |
|
|
|
|
|
|
|
demo.launch(debug=True)