ClassCat's picture
update app.py
1757726
raw
history blame
5.37 kB
import torch
import matplotlib.pyplot as plt
from monai.networks.nets import SegResNet
from monai.inferers import sliding_window_inference
from monai.transforms import (
Activations,
AsDiscrete,
Compose,
)
# Segmentation network: 4 input channels, 3 output channels.
# load_state_dict() is strict by default, so these hyper-parameters have to
# agree with the checkpoint stored in weights/model.pt.
model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
)

# Map every tensor in the checkpoint onto the CPU while loading it.
model.load_state_dict(torch.load("weights/model.pt", map_location="cpu"))
# Request automatic mixed precision during inference (only honoured when
# CUDA is actually available, see inference()).
VAL_AMP = True


def inference(input):
    """Run sliding-window inference with the module-level ``model``.

    Args:
        input: Batched input forwarded unchanged to
            ``sliding_window_inference`` as ``inputs``.

    Returns:
        Whatever ``sliding_window_inference`` returns for ``model`` using a
        ``(240, 240, 160)`` window, a window batch size of 1 and 50% overlap.
    """

    def _compute():
        return sliding_window_inference(
            inputs=input,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    # CUDA autocast is pointless without CUDA: PyTorch merely warns and
    # disables it. Skip the context entirely in that case so CPU-only hosts
    # do not get a warning on every request.
    if VAL_AMP and torch.cuda.is_available():
        with torch.cuda.amp.autocast():
            return _compute()
    return _compute()
# Post-processing applied to the raw network output: sigmoid activation
# followed by discretisation with a 0.5 threshold.
post_trans = Compose(
    [
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ]
)
import gradio as gr
def load_sample1():
    """Zero-argument click handler that loads sample number 1."""
    return load_sample(index=1)
def load_sample2():
    """Zero-argument click handler that loads sample number 2."""
    return load_sample(index=2)
def load_sample3():
    """Zero-argument click handler that loads sample number 3."""
    return load_sample(index=3)
def load_sample4():
    """Zero-argument click handler that loads sample number 4."""
    return load_sample(index=4)
import torchvision
def load_sample(index):
    """Load a stored sample and render one slice of every channel as images.

    Args:
        index: 1-based sample number; the file ``samples/val{index-1}.pt``
            is read.

    Returns:
        A list ``[index, img0, img1, img2, img3, lbl0, lbl1, lbl2]`` where
        the ``img*`` entries are PIL images of the four ``"image"`` channels
        and the ``lbl*`` entries are PIL images of the three ``"label"``
        channels, all taken at position 70 of the last axis.
    """
    slice_pos = 70
    data = torch.load(f"samples/val{index-1}.pt")
    to_pil = torchvision.transforms.functional.to_pil_image

    image_slices = [data["image"][ch, :, :, slice_pos] for ch in range(4)]
    image_pils = [to_pil(plane) for plane in image_slices]

    label_slices = [data["label"][ch, :, :, slice_pos] for ch in range(3)]
    label_pils = [to_pil(plane) for plane in label_slices]

    return [index, *image_pils, *label_pils]
def predict(sample_index):
    """Segment the selected sample and return one slice per output channel.

    Args:
        sample_index: 1-based sample number; the file
            ``samples/val{sample_index-1}.pt`` is read.

    Returns:
        A list of three PIL images, one per output channel of the
        post-processed prediction, taken at position 70 of the last axis.

    Raises:
        TypeError: If ``sample_index`` is not an integer, which is the case
            while the UI state still holds its initial empty list because no
            example has been loaded yet.
    """
    print(sample_index)

    # Fail with an actionable message instead of the opaque
    # "unsupported operand type(s) for -" raised by the path arithmetic.
    if not isinstance(sample_index, int):
        raise TypeError(
            "No sample selected: load an example before running prediction."
        )

    sample = torch.load(f"samples/val{sample_index-1}.pt")

    model.eval()
    with torch.no_grad():
        # Add a leading batch dimension before handing the volume over.
        val_input = sample["image"].unsqueeze(0)
        val_output = inference(val_input)
        # Drop the batch dimension again and binarise the prediction.
        val_output = post_trans(val_output[0])

    output_slices = [val_output[ch, :, :, 70] for ch in range(3)]
    return [
        torchvision.transforms.functional.to_pil_image(plane)
        for plane in output_slices
    ]
# Gradio front end: four input-channel previews, three label previews,
# buttons that load a stored sample, and a button that runs the model.
with gr.Blocks(
    css=".gradio-container {background:lightyellow;color:red;}",
    title="テスト",
) as demo:
    # Number of the currently loaded sample; written by the example buttons
    # and read by the predict button.
    sample_index = gr.State([])

    # NOTE(review): the page title reads "test" and this banner reads
    # "MNIST classifier", which does not match the segmentation model
    # used here - confirm whether the texts are left over from another demo.
    gr.HTML('<div style="font-size:12pt; text-align:center; color:yellow;">MNIST 分類器</div>')

    with gr.Row():
        input_image0 = gr.Image(label="image channel 0", type="pil", shape=(240, 240))
        input_image1 = gr.Image(label="image channel 1", type="pil", shape=(240, 240))
        input_image2 = gr.Image(label="image channel 2", type="pil", shape=(240, 240))
        input_image3 = gr.Image(label="image channel 3", type="pil", shape=(240, 240))

    with gr.Row():
        label_image0 = gr.Image(label="label channel 0", type="pil")
        label_image1 = gr.Image(label="label channel 1", type="pil")
        label_image2 = gr.Image(label="label channel 2", type="pil")

    with gr.Row():
        example1_btn = gr.Button("Example 1")
        example2_btn = gr.Button("Example 2")
        example3_btn = gr.Button("Example 3")
        example4_btn = gr.Button("Example 4")

    # Every example button fills the same components, in the order in which
    # load_sample() returns its values.
    sample_outputs = [
        sample_index,
        input_image0, input_image1, input_image2, input_image3,
        label_image0, label_image1, label_image2,
    ]
    for button, loader in (
        (example1_btn, load_sample1),
        (example2_btn, load_sample2),
        (example3_btn, load_sample3),
        (example4_btn, load_sample4),
    ):
        button.click(fn=loader, inputs=None, outputs=sample_outputs)

    with gr.Row():
        output_image0 = gr.Image(label="output channel 0", type="pil")
        output_image1 = gr.Image(label="output channel 1", type="pil")
        output_image2 = gr.Image(label="output channel 2", type="pil")

    send_btn = gr.Button("予測する")
    send_btn.click(
        fn=predict,
        inputs=[sample_index],
        outputs=[output_image0, output_image1, output_image2],
    )

demo.launch(debug=True)
### EOF ###