|
|
|
import torch |
|
import matplotlib.pyplot as plt |
|
|
|
from monai.networks.nets import SegResNet |
|
from monai.inferers import sliding_window_inference |
|
|
|
from monai.transforms import ( |
|
Activations, |
|
AsDiscrete, |
|
Compose, |
|
) |
|
|
|
# SegResNet segmentation network: 4 input channels -> 3 output channels.
# The architecture arguments must match those used when
# weights/model.pt was saved, otherwise load_state_dict will fail.
model = SegResNet(
    in_channels=4,
    out_channels=3,
    init_filters=16,
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    dropout_prob=0.2,
)

# Weights are mapped onto the CPU so the app can start without a GPU.
model.load_state_dict(torch.load("weights/model.pt", map_location=torch.device("cpu")))

# When True, inference() wraps the sliding-window pass in CUDA autocast.
VAL_AMP = True
|
|
|
def inference(input):
    """Run sliding-window inference of the global ``model`` over ``input``.

    The volume is processed in (240, 240, 160) windows, one window per
    forward pass, with 50% overlap between neighbouring windows.

    Args:
        input: Batched input tensor handed to ``sliding_window_inference``;
            callers add a leading batch axis before calling
            (presumably (N, C, H, W, D) — TODO confirm).

    Returns:
        The stitched network output produced by ``sliding_window_inference``.
    """

    def _compute(x):
        # One full sliding-window pass over the volume with the global model.
        return sliding_window_inference(
            inputs=x,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    # CUDA autocast is only entered when CUDA actually exists: on a CPU-only
    # host it cannot do anything and merely warns on every call.
    # NOTE(review): the model weights are mapped to the CPU at load time and no
    # device move is visible, so even with a GPU present this autocast appears
    # not to affect the computation — confirm whether a .cuda() move is intended.
    if VAL_AMP and torch.cuda.is_available():
        with torch.cuda.amp.autocast():
            return _compute(input)
    return _compute(input)
|
|
|
|
|
# Post-processing for raw network outputs: apply a per-element sigmoid,
# then binarise with a 0.5 threshold.
post_trans = Compose(
    transforms=[
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ]
)
|
|
|
import gradio as gr |
|
|
|
# Zero-argument callbacks for the four "Example N" buttons; each one simply
# forwards a fixed 1-based sample number to load_sample.
def load_sample1():
    """Load example sample 1 for display."""
    return load_sample(index=1)


def load_sample2():
    """Load example sample 2 for display."""
    return load_sample(index=2)


def load_sample3():
    """Load example sample 3 for display."""
    return load_sample(index=3)


def load_sample4():
    """Load example sample 4 for display."""
    return load_sample(index=4)
|
|
|
import torchvision |
|
|
|
def load_sample(index, slice_index=70):
    """Load a saved validation sample and render one slice of it as images.

    Reads ``samples/val{index - 1}.pt`` (so ``index`` is 1-based). Returns the
    values wired to the UI outputs. For each image channel (4) and each label
    channel (3), it takes the 2-D slice at position ``slice_index`` of the
    last axis.

    Args:
        index: 1-based sample number.
        slice_index: Position along the last axis to display (default 70).

    Returns:
        ``[index, img0, img1, img2, img3, label0, label1, label2]`` where the
        images are PIL images.
    """
    # map_location matches how the model weights are loaded, so samples that
    # were saved from a GPU still open on a CPU-only host.
    sample = torch.load(f"samples/val{index-1}.pt", map_location="cpu")
    to_pil = torchvision.transforms.functional.to_pil_image

    # NOTE(review): to_pil_image multiplies float tensors by 255 before casting
    # to uint8, so values outside [0, 1] will not display correctly — confirm
    # the saved image channels are already scaled to [0, 1].
    image_slices = [
        to_pil(sample["image"][channel, :, :, slice_index]) for channel in range(4)
    ]
    label_slices = [
        to_pil(sample["label"][channel, :, :, slice_index]) for channel in range(3)
    ]
    return [index, *image_slices, *label_slices]
|
|
|
|
|
def predict(sample_index, slice_index=70):
    """Segment a saved sample with the model and render one slice per channel.

    Args:
        sample_index: 1-based sample number held in the UI state; the file
            ``samples/val{sample_index - 1}.pt`` is read. Until an example
            button has been pressed the state holds a non-int placeholder, in
            which case nothing is predicted.
        slice_index: Position along the last axis to display (default 70).

    Returns:
        Three PIL images, one per output channel after sigmoid + threshold,
        or ``[None, None, None]`` (clearing the outputs) when no sample has
        been selected yet.
    """
    # Guard against pressing the predict button before choosing an example:
    # the placeholder state value is not an int and ``- 1`` would raise.
    if not isinstance(sample_index, int):
        return [None, None, None]

    sample = torch.load(f"samples/val{sample_index-1}.pt", map_location="cpu")
    model.eval()
    with torch.no_grad():
        # Add the leading batch axis expected by inference().
        val_input = sample["image"].unsqueeze(0)
        val_output = inference(val_input)
        # Drop the batch axis, then sigmoid + threshold into binary masks.
        val_output = post_trans(val_output[0])

    to_pil = torchvision.transforms.functional.to_pil_image
    return [to_pil(val_output[channel, :, :, slice_index]) for channel in range(3)]
|
|
|
with gr.Blocks(
    css=".gradio-container {background:lightyellow;color:red;}", title="テスト"
) as demo:
    # 1-based index of the currently selected sample. It is None until one of
    # the example buttons stores a value; predict() treats a non-int as
    # "nothing selected".
    sample_index = gr.State(None)

    # Header describing the deployed model. The text colour matches the
    # container's red so it stays legible on the light-yellow background.
    gr.HTML(
        '<div style="font-size:12pt; text-align:center; color:red;">'
        "SegResNet セグメンテーション</div>"
    )

    # One displayed slice per input channel of the selected sample.
    with gr.Row():
        input_image0 = gr.Image(label="image channel 0", type="pil", shape=(240, 240))
        input_image1 = gr.Image(label="image channel 1", type="pil", shape=(240, 240))
        input_image2 = gr.Image(label="image channel 2", type="pil", shape=(240, 240))
        input_image3 = gr.Image(label="image channel 3", type="pil", shape=(240, 240))

    # One displayed slice per label channel stored with the sample.
    with gr.Row():
        label_image0 = gr.Image(label="label channel 0", type="pil")
        label_image1 = gr.Image(label="label channel 1", type="pil")
        label_image2 = gr.Image(label="label channel 2", type="pil")

    with gr.Row():
        example1_btn = gr.Button("Example 1")
        example2_btn = gr.Button("Example 2")
        example3_btn = gr.Button("Example 3")
        example4_btn = gr.Button("Example 4")

    # Every example button fills the same outputs, in the order load_sample
    # returns them: the index (into the state), 4 image slices, 3 label slices.
    sample_outputs = [
        sample_index,
        input_image0, input_image1, input_image2, input_image3,
        label_image0, label_image1, label_image2,
    ]
    for _btn, _loader in (
        (example1_btn, load_sample1),
        (example2_btn, load_sample2),
        (example3_btn, load_sample3),
        (example4_btn, load_sample4),
    ):
        _btn.click(fn=_loader, inputs=None, outputs=sample_outputs)

    # One displayed slice per predicted output channel.
    with gr.Row():
        output_image0 = gr.Image(label="output channel 0", type="pil")
        output_image1 = gr.Image(label="output channel 1", type="pil")
        output_image2 = gr.Image(label="output channel 2", type="pil")

    send_btn = gr.Button("予測する")
    send_btn.click(fn=predict, inputs=[sample_index], outputs=[output_image0, output_image1, output_image2])
|
|
|
|
|
# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module (tests, reload tooling) does not block on launch().
if __name__ == "__main__":
    demo.launch(debug=True)
|
|
|
|
|
|
|
|
|
|