ClassCat committed
Commit 08ccc4a
Parent: be6dbe0

add app.py

Files changed (1)
  1. app.py  +170 -0
app.py ADDED
@@ -0,0 +1,170 @@

import torch
import torchvision
import gradio as gr

from monai.networks.nets import SegResNet
from monai.inferers import sliding_window_inference
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
)

# SegResNet configured for BraTS-style brain tumor segmentation:
# 4 input MRI channels, 3 output tumor sub-region channels
model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
)

# load the trained weights; map_location keeps everything on the CPU
model.load_state_dict(
    torch.load("model.pt", map_location=torch.device("cpu"))
)
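
# Optional sanity check (an addition, not in the original commit): this
# SegResNet downsamples 8x and upsamples back, so a dummy volume whose
# spatial dims are multiples of 8 should come back at full resolution
# with 3 channels. Uncomment to verify the network runs end to end:
#
# with torch.no_grad():
#     _out = model(torch.zeros(1, 4, 64, 64, 64))
#     assert _out.shape == (1, 3, 64, 64, 64)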

# define the inference method: sliding-window evaluation over the whole
# volume, optionally under CUDA automatic mixed precision
VAL_AMP = True

def inference(inputs):

    def _compute(inputs):
        return sliding_window_inference(
            inputs=inputs,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    if VAL_AMP:
        # effectively a no-op on CPU-only machines
        with torch.cuda.amp.autocast():
            return _compute(inputs)
    else:
        return _compute(inputs)


# post-processing: sigmoid activation, then binarize each channel at 0.5
post_trans = Compose(
    [Activations(sigmoid=True), AsDiscrete(threshold=0.5)]
)
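
# What post_trans does, as a standalone sketch (an addition for
# illustration): raw logits -> sigmoid probabilities -> binary masks.
#
# _logits = torch.tensor([-2.0, 0.1, 3.0])
# post_trans(_logits)   # sigmoid gives ~[0.12, 0.52, 0.95] -> [0., 1., 1.]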


def load_sample(index):
    # samples are stored 0-based on disk (val0.pt, val1.pt, ...),
    # while the example buttons are numbered 1-4
    sample = torch.load(f"val{index-1}.pt")

    # axial slice 70 of each of the 4 image channels, as PIL images
    pil_images = []
    for i in range(4):
        pil_images.append(
            torchvision.transforms.functional.to_pil_image(sample["image"][i, :, :, 70])
        )

    # the same slice of each of the 3 label channels
    pil_images_label = []
    for i in range(3):
        pil_images_label.append(
            torchvision.transforms.functional.to_pil_image(sample["label"][i, :, :, 70])
        )

    # the index comes first so it can be stored in the gr.State below
    return [index, pil_images[0], pil_images[1], pil_images[2], pil_images[3],
            pil_images_label[0], pil_images_label[1], pil_images_label[2]]


# one thin wrapper per example button, since Button.click takes a fixed fn
def load_sample1():
    return load_sample(1)

def load_sample2():
    return load_sample(2)

def load_sample3():
    return load_sample(3)

def load_sample4():
    return load_sample(4)
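
# Assumed structure of the bundled sample files, inferred from the indexing
# above (the commit itself does not document them):
#
# _s = torch.load("val0.pt")
# _s["image"].shape   # -> (4, H, W, D) float tensor, D > 70 (assumed)
# _s["label"].shape   # -> (3, H, W, D) binary tensor (assumed)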


def predict(sample_index):
    sample = torch.load(f"val{sample_index-1}.pt")
    model.eval()
    with torch.no_grad():
        # evaluate the selected volume and binarize the model output
        val_input = sample["image"].unsqueeze(0)  # add a batch dimension
        val_output = inference(val_input)
        val_output = post_trans(val_output[0])

    # axial slice 70 of each of the 3 predicted channels, as PIL images
    pil_images_output = []
    for i in range(3):
        pil_images_output.append(
            torchvision.transforms.functional.to_pil_image(val_output[i, :, :, 70])
        )

    return [pil_images_output[0], pil_images_output[1], pil_images_output[2]]
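
# Standalone usage sketch (an addition): with val0.pt next to this script,
# predict(1) returns the three per-channel segmentation slices as PIL images.
#
# ch0, ch1, ch2 = predict(1)
# ch0.save("pred_channel0.png")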


with gr.Blocks(css=".gradio-container {background:lightyellow;color:red;}",
               title="Brain Tumor 3D Segmentation") as demo:
    # holds the 1-based index of the currently loaded sample
    sample_index = gr.State([])

    gr.HTML('<div style="font-size:12pt; text-align:center; color:yellow;">Brain Tumor 3D Segmentation</div>')

    with gr.Row():
        input_image0 = gr.Image(label="image channel 0", type="pil", shape=(240, 240))
        input_image1 = gr.Image(label="image channel 1", type="pil", shape=(240, 240))
        input_image2 = gr.Image(label="image channel 2", type="pil", shape=(240, 240))
        input_image3 = gr.Image(label="image channel 3", type="pil", shape=(240, 240))

    with gr.Row():
        label_image0 = gr.Image(label="label channel 0", type="pil")
        label_image1 = gr.Image(label="label channel 1", type="pil")
        label_image2 = gr.Image(label="label channel 2", type="pil")

    with gr.Row():
        example1_btn = gr.Button("Example 1")
        example2_btn = gr.Button("Example 2")
        example3_btn = gr.Button("Example 3")
        example4_btn = gr.Button("Example 4")

    # each button fills the sample_index state plus the 4 image and 3 label
    # slots, in exactly the order load_sample() returns them
    sample_outputs = [sample_index,
                      input_image0, input_image1, input_image2, input_image3,
                      label_image0, label_image1, label_image2]
    example1_btn.click(fn=load_sample1, inputs=None, outputs=sample_outputs)
    example2_btn.click(fn=load_sample2, inputs=None, outputs=sample_outputs)
    example3_btn.click(fn=load_sample3, inputs=None, outputs=sample_outputs)
    example4_btn.click(fn=load_sample4, inputs=None, outputs=sample_outputs)
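
    # Note (an addition): gr.State starts out as an empty list, but the first
    # element returned by load_sample() overwrites it with the integer index,
    # which is what predict() receives below.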

    with gr.Row():
        output_image0 = gr.Image(label="output channel 0", type="pil")
        output_image1 = gr.Image(label="output channel 1", type="pil")
        output_image2 = gr.Image(label="output channel 2", type="pil")

    send_btn = gr.Button("Predict")

    send_btn.click(fn=predict, inputs=[sample_index],
                   outputs=[output_image0, output_image1, output_image2])

#demo.queue()
demo.launch(debug=True)


### EOF ###