ClassCat commited on
Commit
58f37a3
·
1 Parent(s): ab612d6

add app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch, torchvision
3
+ from monai.networks.nets import UNet
4
+ from monai.networks.layers import Norm
5
+ from monai.inferers import sliding_window_inference
6
+ import PIL
7
+ from torchvision.utils import save_image
8
+ import numpy as np
9
+
10
+ model = UNet(
11
+ spatial_dims=3,
12
+ in_channels=1,
13
+ out_channels=2,
14
+ channels=(16, 32, 64, 128, 256),
15
+ strides=(2, 2, 2, 2),
16
+ num_res_units=2,
17
+ norm=Norm.BATCH,
18
+ )
19
+
20
+ model.load_state_dict(torch.load("weights/model.pt", map_location=torch.device('cpu')))
21
+
22
+ import gradio as gr
23
+
24
+ def load_image0():
25
+ return load_image(0)
26
+
27
+ def load_image1():
28
+ return load_image(1)
29
+
30
+ def load_image2():
31
+ return load_image(2)
32
+
33
+ def load_image3():
34
+ return load_image(3)
35
+
36
+ def load_image4():
37
+ return load_image(4)
38
+
39
+ def load_image5():
40
+ return load_image(5)
41
+
42
+ def load_image6():
43
+ return load_image(6)
44
+
45
+ def load_image7():
46
+ return load_image(7)
47
+
48
+ def load_image8():
49
+ return load_image(8)
50
+
51
+ def load_image(index):
52
+ return [index, f"thumbnails/val_image{index}.png", f"thumbnails_label/val_label{index}.png"]
53
+
54
+ def predict(index):
55
+ val_data = torch.load(f"samples/val_data{index}.pt")
56
+
57
+ model.eval()
58
+ with torch.no_grad():
59
+ roi_size = (160, 160, 160)
60
+ sw_batch_size = 4
61
+ val_outputs = sliding_window_inference(val_data, roi_size, sw_batch_size, model)
62
+
63
+ meta_tsr = torch.argmax(val_outputs, dim=1)[0, :, :, 80]
64
+ pil_image = torchvision.transforms.functional.to_pil_image(meta_tsr.to(torch.float32))
65
+
66
+ return pil_image
67
+
68
+
69
+ with gr.Blocks(title="Spleen 3D segmentation with MONAI - ClassCat",
70
+ css=".gradio-container {background:azure;}"
71
+ ) as demo:
72
+ sample_index = gr.State([])
73
+
74
+ gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">Spleen 3D segmentation with MONAI</div>""")
75
+
76
+ gr.HTML("""<h4 style="color:navy;">1. Select an example, which includes input images and label images, by clicking "Example x" button.</h4>""")
77
+
78
+ with gr.Row():
79
+ input_image = gr.Image(label="a piece of input image data", type="filepath")
80
+ label_image = gr.Image(label="label image", type="filepath")
81
+ output_image = gr.Image(label="predicted image", type="pil")
82
+
83
+
84
+ with gr.Row():
85
+ with gr.Column():
86
+ ex_btn0 = gr.Button("Example 1")
87
+ ex_btn0.style(full_width=False, css="width:20px;")
88
+ ex_image0 = gr.Image(value='thumbnails/val_image0.png', interactive=False, label='ex 1')
89
+ ex_image0.style(width=128, height=128)
90
+
91
+ with gr.Column():
92
+ ex_btn1 = gr.Button("Example 2")
93
+ ex_btn1.style(full_width=False, css="width:20px;")
94
+ ex_image1 = gr.Image(value='thumbnails/val_image1.png', interactive=False, label='ex 2')
95
+ ex_image1.style(width=128, height=128)
96
+
97
+ with gr.Column():
98
+ ex_btn2 = gr.Button("Example 3")
99
+ ex_btn2.style(full_width=False, css="width:20px;")
100
+ ex_image2 = gr.Image(value='thumbnails/val_image2.png', interactive=False, label='ex 3')
101
+ ex_image2.style(width=128, height=128)
102
+
103
+ with gr.Column():
104
+ ex_btn3 = gr.Button("Example 4")
105
+ ex_btn3.style(full_width=False, css="width:20px;")
106
+ ex_image3 = gr.Image(value='thumbnails/val_image3.png', interactive=False, label='ex 4')
107
+ ex_image3.style(width=128, height=128)
108
+
109
+ ex_btn0.click(fn=load_image0, outputs=[sample_index, input_image, label_image])
110
+ ex_btn1.click(fn=load_image1, outputs=[sample_index, input_image, label_image])
111
+ ex_btn2.click(fn=load_image2, outputs=[sample_index, input_image, label_image])
112
+ ex_btn3.click(fn=load_image3, outputs=[sample_index, input_image, label_image])
113
+
114
+ gr.HTML("""<br/>""")
115
+ gr.HTML("""<h4 style="color:navy;">2. Then, click "Infer" button to predict segmentation images. It will take about 30 seconds (on cpu)</h4>""")
116
+
117
+ send_btn = gr.Button("Infer")
118
+ send_btn.click(fn=predict, inputs=[sample_index], outputs=[output_image])
119
+
120
+
121
+ #demo.queue()
122
+ demo.launch(debug=True)
123
+
124
+
125
+ ### EOF ###