import contextlib

import gradio as gr
import matplotlib.pyplot as plt
import torch
import torchvision
from monai.inferers import sliding_window_inference
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
)

# Index along the last axis of every (channel, ..., ..., depth) volume that is
# shown in the UI, for inputs, labels and predictions alike.
SLICE_INDEX = 70

model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
)
model.load_state_dict(
    torch.load("weights/model.pt", map_location=torch.device('cpu'))
)
# The network is only used for inference here, so disable dropout once at load
# time instead of on every request.
model.eval()

# Only request CUDA autocast when CUDA exists; without it the context manager
# does nothing but emit a warning.
# NOTE(review): the model and inputs stay on the CPU in this script, so CUDA
# autocast does not change the computation even on a GPU host. Move the model
# and inputs to "cuda" if GPU inference is intended.
VAL_AMP = torch.cuda.is_available()


def inference(input):
    """Run sliding-window inference of ``model`` over a batched input volume.

    Args:
        input: batched tensor handed to ``sliding_window_inference``;
            ``predict`` passes a single sample with a leading batch dim of 1.

    Returns:
        The raw network output (logits; ``post_trans`` applies the sigmoid).
    """
    amp_ctx = torch.cuda.amp.autocast() if VAL_AMP else contextlib.nullcontext()
    with amp_ctx:
        return sliding_window_inference(
            inputs=input,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )


# Sigmoid followed by a 0.5 threshold turns the 3-channel logits into binary masks.
post_trans = Compose(
    [Activations(sigmoid=True), AsDiscrete(threshold=0.5)]
)


def _load_volume(index):
    """Load the sample dict for 1-based example ``index`` from samples/val{index-1}.pt."""
    # map_location keeps tensors on the CPU, where the model lives.
    # NOTE: torch.load unpickles its input; only use it on the bundled, trusted samples.
    return torch.load(f"samples/val{index - 1}.pt", map_location="cpu")


def _slices_to_pil(volume, num_channels):
    """Return slice ``SLICE_INDEX`` of the first ``num_channels`` channels of ``volume`` as PIL images."""
    # Indexing assumes a channel-first 4-D layout (channel, ..., ..., depth) — TODO confirm.
    # NOTE(review): to_pil_image rescales float tensors by 255 before casting to
    # uint8, so image intensities outside [0, 1] may render incorrectly; verify.
    return [
        torchvision.transforms.functional.to_pil_image(volume[c, :, :, SLICE_INDEX])
        for c in range(num_channels)
    ]


def load_sample(index):
    """Load example ``index`` (1-based) for display.

    Returns:
        ``[index, image ch0..ch3, label ch0..ch2]``, with the images as PIL
        images, in the order of the outputs wired to the "Example N" buttons.
    """
    sample = _load_volume(index)
    return [
        index,
        *_slices_to_pil(sample["image"], 4),
        *_slices_to_pil(sample["label"], 3),
    ]


def load_sample1():
    """Load example 1."""
    return load_sample(1)


def load_sample2():
    """Load example 2."""
    return load_sample(2)


def load_sample3():
    """Load example 3."""
    return load_sample(3)


def load_sample4():
    """Load example 4."""
    return load_sample(4)


def load_sample5():
    """Load example 5."""
    return load_sample(5)


def load_sample6():
    """Load example 6."""
    return load_sample(6)


def load_sample7():
    """Load example 7."""
    return load_sample(7)


def load_sample8():
    """Load example 8."""
    return load_sample(8)


def predict(sample_index):
    """Segment the selected example and return its predicted masks.

    Args:
        sample_index: 1-based example index stored in the gr.State by
            ``load_sample``; it is still the initial ``[]`` if no example has
            been loaded yet.

    Returns:
        Three PIL images (output channels 0..2 at ``SLICE_INDEX``), or three
        ``None`` values (which clear the output images) when no example is
        selected.
    """
    if not isinstance(sample_index, int):
        # Previously this crashed on `[] - 1` when "Infer" was clicked first.
        return [None, None, None]
    sample = _load_volume(sample_index)
    with torch.no_grad():
        # Add a batch dimension of 1 for sliding_window_inference.
        val_input = sample["image"].unsqueeze(0)
        val_output = inference(val_input)
        # Drop the batch dimension, then binarise every channel.
        val_output = post_trans(val_output[0])
    return _slices_to_pil(val_output, 3)


with gr.Blocks(title="Brain tumor 3D segmentation with MONAI - ClassCat", css=".gradio-container {background:azure;}") as demo:
    # 1-based index of the currently loaded example; [] until one is chosen.
    sample_index = gr.State([])

    gr.HTML("""
Brain tumor 3D segmentation with MONAI
""")

    gr.HTML("""

1. Select an example, which includes input images and label images, by clicking "Example x" button.

""")

    with gr.Row():
        input_image0 = gr.Image(label="image channel 0", type="pil", shape=(240, 240))
        input_image1 = gr.Image(label="image channel 1", type="pil", shape=(240, 240))
        input_image2 = gr.Image(label="image channel 2", type="pil", shape=(240, 240))
        input_image3 = gr.Image(label="image channel 3", type="pil", shape=(240, 240))

    with gr.Row():
        label_image0 = gr.Image(label="label channel 0", type="pil")
        label_image1 = gr.Image(label="label channel 1", type="pil")
        label_image2 = gr.Image(label="label channel 2", type="pil")

    sample_loaders = [
        load_sample1, load_sample2, load_sample3, load_sample4,
        load_sample5, load_sample6, load_sample7, load_sample8,
    ]
    with gr.Row():
        example_btns = [
            gr.Button(f"Example {n}") for n in range(1, len(sample_loaders) + 1)
        ]

    # Every "Example N" button fills the same state slot and preview images.
    sample_outputs = [
        sample_index,
        input_image0, input_image1, input_image2, input_image3,
        label_image0, label_image1, label_image2,
    ]
    for btn, loader in zip(example_btns, sample_loaders):
        btn.click(fn=loader, inputs=None, outputs=sample_outputs)

    gr.HTML("""
""")

    gr.HTML("""

2. Then, click "Infer" button to predict segmentation images. It will take about 30 seconds (on cpu)

""")

    with gr.Row():
        output_image0 = gr.Image(label="output channel 0", type="pil")
        output_image1 = gr.Image(label="output channel 1", type="pil")
        output_image2 = gr.Image(label="output channel 2", type="pil")

    send_btn = gr.Button("Infer")
    send_btn.click(fn=predict, inputs=[sample_index], outputs=[output_image0, output_image1, output_image2])

    gr.HTML("""
""")

    gr.HTML("""

Reference

""")

    gr.HTML("""""")

# NOTE(review): inference takes ~30 s on CPU; consider enabling demo.queue() so
# long requests are not cut off by HTTP timeouts.
demo.launch(debug=True)

### EOF ###