import torch
import torchvision
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.inferers import sliding_window_inference

import gradio as gr

# Rebuild the 3D UNet with the same architecture used for training:
# one input channel (CT volume), two output channels (background / spleen).
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
)
# Load the trained weights onto the CPU so the demo also runs without a GPU.
model.load_state_dict(torch.load("weights/model.pt", map_location=torch.device('cpu')))


def load_image(index):
    """Return the sample index and the file paths of its input/label thumbnails."""
    return [index, f"thumbnails/val_image{index}.png", f"thumbnails_label/val_label{index}.png"]

# Parameterless wrappers, one per example button (Button.click passes no
# arguments to its fn). Wrappers 6-8 cover extra samples that are not wired
# to any button below.
def load_image0(): return load_image(0)
def load_image1(): return load_image(1)
def load_image2(): return load_image(2)
def load_image3(): return load_image(3)
def load_image4(): return load_image(4)
def load_image5(): return load_image(5)
def load_image6(): return load_image(6)
def load_image7(): return load_image(7)
def load_image8(): return load_image(8)


def predict(index):
    # Each sample is a pre-processed volume tensor saved with torch.save().
    val_data = torch.load(f"samples/val_data{index}.pt")
    model.eval()
    with torch.no_grad():
        # Tile the full volume with 160^3 patches, 4 windows per batch.
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(val_data, roi_size, sw_batch_size, model)
        # Per-voxel argmax over the two channels, then take one axial slice
        # (depth index 80) and convert it to a PIL image for display.
        meta_tsr = torch.argmax(val_outputs, dim=1)[0, :, :, 80]
        pil_image = torchvision.transforms.functional.to_pil_image(meta_tsr.to(torch.float32))
    return pil_image


with gr.Blocks(title="Spleen 3D segmentation with MONAI - ClassCat",
               css=".gradio-container {background:azure;}") as demo:
    sample_index = gr.State([])  # index of the currently selected sample

    gr.HTML("""
    <h1>Spleen 3D segmentation with MONAI</h1>
    """)

    gr.HTML("""
    <p>1. Select an example, which includes input images and label images, by clicking an "Example x" button.</p>
    """)

    with gr.Row():
        input_image = gr.Image(label="a piece of input image data", type="filepath")
        label_image = gr.Image(label="label image", type="filepath")
        output_image = gr.Image(label="predicted image", type="pil")

    with gr.Row():
        with gr.Column():
            ex_btn0 = gr.Button("Example 1")
            ex_btn0.style(full_width=False)
            ex_image0 = gr.Image(value='thumbnails/val_image0.png', interactive=False, label='ex 1')
            ex_image0.style(width=128, height=128)
        with gr.Column():
            ex_btn1 = gr.Button("Example 2")
            ex_btn1.style(full_width=False)
            ex_image1 = gr.Image(value='thumbnails/val_image1.png', interactive=False, label='ex 2')
            ex_image1.style(width=128, height=128)
        with gr.Column():
            ex_btn2 = gr.Button("Example 3")
            ex_btn2.style(full_width=False)
            ex_image2 = gr.Image(value='thumbnails/val_image2.png', interactive=False, label='ex 3')
            ex_image2.style(width=128, height=128)
        with gr.Column():
            ex_btn3 = gr.Button("Example 4")
            ex_btn3.style(full_width=False)
            ex_image3 = gr.Image(value='thumbnails/val_image3.png', interactive=False, label='ex 4')
            ex_image3.style(width=128, height=128)
        with gr.Column():
            ex_btn4 = gr.Button("Example 5")
            ex_btn4.style(full_width=False)
            ex_image4 = gr.Image(value='thumbnails/val_image4.png', interactive=False, label='ex 5')
            ex_image4.style(width=128, height=128)
        with gr.Column():
            ex_btn5 = gr.Button("Example 6")
            ex_btn5.style(full_width=False)
            ex_image5 = gr.Image(value='thumbnails/val_image5.png', interactive=False, label='ex 6')
            ex_image5.style(width=128, height=128)

    ex_btn0.click(fn=load_image0, outputs=[sample_index, input_image, label_image])
    ex_btn1.click(fn=load_image1, outputs=[sample_index, input_image, label_image])
    ex_btn2.click(fn=load_image2, outputs=[sample_index, input_image, label_image])
    ex_btn3.click(fn=load_image3, outputs=[sample_index, input_image, label_image])
    ex_btn4.click(fn=load_image4, outputs=[sample_index, input_image, label_image])
    ex_btn5.click(fn=load_image5, outputs=[sample_index, input_image, label_image])
""") gr.HTML("""

2. Then, click "Infer" button to predict a segmentation image. It will take about 15 seconds (on cpu)
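    # predict() reads the stored sample index and returns a PIL image, which
    # gr.Image(type="pil") renders directly.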

""") send_btn = gr.Button("Infer") send_btn.click(fn=predict, inputs=[sample_index], outputs=[output_image]) #demo.queue() demo.launch(debug=True) ### EOF ###