File size: 5,967 Bytes
58f37a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2047c26
210af37
 
 
 
 
2047c26
210af37
 
 
 
 
4b17e5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58f37a3
 
 
 
210af37
 
4b17e5b
 
333a52b
210af37
58f37a3
 
210af37
58f37a3
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163

import torch, torchvision
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.inferers import sliding_window_inference
import PIL
from torchvision.utils import save_image
import numpy as np

# 3-D residual UNet: one input channel, two output channels (presumably
# background vs. spleen -- confirm against the training setup). The
# hyper-parameters must match the ones used to produce weights/model.pt,
# otherwise load_state_dict (strict by default) rejects the checkpoint.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2,) * 4,
    num_res_units=2,
    norm=Norm.BATCH,
)

# Map every stored tensor onto the CPU so the demo does not need a GPU.
model.load_state_dict(
    torch.load("weights/model.pt", map_location="cpu")
)

import gradio as gr

def load_image0():
    """Click handler for the "Example 1" button (example index 0)."""
    return load_image(index=0)


def load_image1():
    """Click handler for the "Example 2" button (example index 1)."""
    return load_image(index=1)


def load_image2():
    """Click handler for the "Example 3" button (example index 2)."""
    return load_image(index=2)


def load_image3():
    """Click handler for the "Example 4" button (example index 3)."""
    return load_image(index=3)


def load_image4():
    """Click handler for the "Example 5" button (example index 4)."""
    return load_image(index=4)


def load_image5():
    """Click handler for the "Example 6" button (example index 5)."""
    return load_image(index=5)


def load_image6():
    """Click handler for the "Example 7" button (example index 6)."""
    return load_image(index=6)


def load_image7():
    """Click handler for the "Example 8" button (example index 7)."""
    return load_image(index=7)


def load_image8():
    """Click handler for the "Example 9" button (example index 8)."""
    return load_image(index=8)


def load_image(index):
    """Build the UI values for example ``index``.

    Returns a 3-element list: the index itself (stored in the selection
    state), the path of the example's input thumbnail, and the path of its
    label thumbnail -- in the same order as the click handlers' outputs.
    """
    image_path = f"thumbnails/val_image{index}.png"
    label_path = f"thumbnails_label/val_label{index}.png"
    return [index, image_path, label_path]

def predict(index):
    """Segment a stored validation volume and return one slice of the result.

    Args:
        index: example index stored in the selection state by one of the
            "Example x" buttons. Before any example has been chosen the state
            still holds its initial value (``[]``).

    Returns:
        A PIL image of the predicted label map (argmax over the model's output
        channels) for batch item 0 at depth index 80, or ``None`` when no
        example has been selected yet (which clears the output image).
    """
    # Guard against "Infer" being clicked before an example is chosen: the
    # initial state [] would otherwise produce the path "samples/val_data[].pt"
    # and crash. 0 is a valid index, so check the type instead of truthiness.
    if not isinstance(index, int):
        return None

    # map_location mirrors the checkpoint load so that tensors saved from a GPU
    # still load on a CPU-only host.
    val_data = torch.load(f"samples/val_data{index}.pt", map_location=torch.device('cpu'))

    model.eval()
    with torch.no_grad():
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(val_data, roi_size, sw_batch_size, model)

    # argmax over dim 1 (output channels) gives a label map; keep batch item 0
    # and depth slice 80.
    # NOTE(review): assumes val_data is (N, C, H, W, D) with D > 80 -- confirm.
    meta_tsr = torch.argmax(val_outputs, dim=1)[0, :, :, 80]
    # to_pil_image scales floating-point tensors by 255, so label values 0/1
    # render as black/white.
    pil_image = torchvision.transforms.functional.to_pil_image(meta_tsr.to(torch.float32))

    return pil_image


with gr.Blocks(title="Spleen 3D segmentation with MONAI - ClassCat",
               css=".gradio-container {background:azure;}"
               ) as demo:
    # Index of the selected example: written by the "Example x" buttons and
    # read by "Infer". Holds [] until an example has been chosen.
    sample_index = gr.State([])

    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">Spleen 3D segmentation with MONAI</div>""")

    gr.HTML("""<h4 style="color:navy;">1. Select an example, which includes input images and label images, by clicking "Example x" button.</h4>""")

    with gr.Row():
        input_image = gr.Image(label="a piece of input image data", type="filepath")
        label_image = gr.Image(label="label image", type="filepath")
        output_image = gr.Image(label="predicted image", type="pil")

    # Click handlers for the nine examples, in display order. The existing
    # named load_imageN functions are reused so the event wiring is unchanged.
    example_loaders = (
        load_image0, load_image1, load_image2,
        load_image3, load_image4, load_image5,
        load_image6, load_image7, load_image8,
    )

    # One column per example: a button above its thumbnail. Components and
    # events are created in the same order as the previous hand-written
    # columns, so the layout and event ordering are unchanged.
    with gr.Row():
        for i, loader in enumerate(example_loaders):
            with gr.Column():
                ex_btn = gr.Button(f"Example {i + 1}")
                ex_btn.style(full_width=False, css="width:20px;")
                ex_image = gr.Image(value=f'thumbnails/val_image{i}.png', interactive=False, label=f'ex {i + 1}')
                ex_image.style(width=128, height=128)
            # Clicking fills the top row with this example and records its index.
            ex_btn.click(fn=loader, outputs=[sample_index, input_image, label_image])

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">2. Then, click "Infer" button to predict a segmentation image. It will take about 15 seconds (on cpu)</h4>""")

    send_btn = gr.Button("Infer")
    send_btn.click(fn=predict, inputs=[sample_index], outputs=[output_image])


#demo.queue()
demo.launch(debug=True)


### EOF ###