Spaces:
Runtime error
Runtime error
import torch, torchvision | |
from monai.networks.nets import UNet | |
from monai.networks.layers import Norm | |
from monai.inferers import sliding_window_inference | |
import PIL | |
from torchvision.utils import save_image | |
import numpy as np | |
def _build_model(weights_path="weights/model.pt"):
    """Construct the 3-D MONAI UNet and load its checkpoint onto the CPU.

    The architecture arguments below must match those the checkpoint at
    *weights_path* was trained with; ``load_state_dict`` (strict by default)
    raises if parameter names or shapes disagree.
    """
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,  # presumably background vs. spleen -- TODO confirm
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    )
    # Weights are always mapped to CPU so the app runs on hosts without a GPU.
    cpu_device = torch.device('cpu')
    checkpoint = torch.load(weights_path, map_location=cpu_device)
    net.load_state_dict(checkpoint)
    return net


# Module-level network shared by every inference request.
model = _build_model()
import gradio as gr | |
# Zero-argument callbacks for the "Example x" buttons: Gradio calls a
# button's ``fn`` with no inputs here, so each wrapper pins one example index.
def load_image0():
    """Return the UI payload for example 0."""
    return load_image(index=0)


def load_image1():
    """Return the UI payload for example 1."""
    return load_image(index=1)


def load_image2():
    """Return the UI payload for example 2."""
    return load_image(index=2)


def load_image3():
    """Return the UI payload for example 3."""
    return load_image(index=3)


def load_image4():
    """Return the UI payload for example 4."""
    return load_image(index=4)


def load_image5():
    """Return the UI payload for example 5."""
    return load_image(index=5)


# NOTE(review): load_image6..load_image8 are not wired to any button in this
# file (only six example buttons exist); kept for compatibility.
def load_image6():
    """Return the UI payload for example 6."""
    return load_image(index=6)


def load_image7():
    """Return the UI payload for example 7."""
    return load_image(index=7)


def load_image8():
    """Return the UI payload for example 8."""
    return load_image(index=8)
def load_image(index):
    """Return the values that populate the UI for example *index*.

    The result is ``[index, input thumbnail path, label thumbnail path]``,
    matching the ``[sample_index, input_image, label_image]`` outputs the
    example buttons are wired to. Paths are relative to the working directory.
    """
    image_path = "thumbnails/val_image{}.png".format(index)
    label_path = "thumbnails_label/val_label{}.png".format(index)
    return [index, image_path, label_path]
def predict(index, slice_index=80):
    """Segment example *index* and return one 2-D slice of the predicted mask.

    Args:
        index: Example number stored in the ``sample_index`` state by the
            "Example x" buttons; selects ``samples/val_data{index}.pt``.
        slice_index: Position along the last axis of the per-voxel label map
            to render. Defaults to 80, the slice this demo always showed.

    Returns:
        A PIL image of the label map for that slice. The labels are converted
        to float and passed to ``to_pil_image``, which scales floats by 255, so
        label 1 renders white and label 0 black.

    Raises:
        ValueError: if no example has been selected yet.
    """
    # The state holds no int until an example button has been clicked; without
    # this guard the path below would become e.g. "samples/val_data[].pt" and
    # surface as a confusing FileNotFoundError.
    if not isinstance(index, int):
        raise ValueError(
            'Please select an example first by clicking an "Example x" button.'
        )
    # Map to CPU like the model weights, so tensors saved from a GPU session
    # still load on a CPU-only host.
    val_data = torch.load(
        f"samples/val_data{index}.pt", map_location=torch.device('cpu')
    )
    model.eval()
    with torch.no_grad():
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(val_data, roi_size, sw_batch_size, model)
        # argmax over dim=1 (channels) yields a per-voxel class label; then take
        # batch item 0 and one slice along the last axis.
        meta_tsr = torch.argmax(val_outputs, dim=1)[0, :, :, slice_index]
    pil_image = torchvision.transforms.functional.to_pil_image(meta_tsr.to(torch.float32))
    return pil_image
def _apply_style(component, **style_kwargs):
    """Apply legacy ``Component.style(...)`` settings when the method exists.

    Older Gradio releases size components via ``.style()``; Gradio 4 removed
    that method, so calling it unconditionally raises AttributeError at
    startup. There the purely cosmetic sizing is skipped instead of crashing.
    """
    style = getattr(component, "style", None)
    if callable(style):
        style(**style_kwargs)


# Callbacks for the six example buttons, in button order (Example 1..6).
_EXAMPLE_LOADERS = (
    load_image0,
    load_image1,
    load_image2,
    load_image3,
    load_image4,
    load_image5,
)

with gr.Blocks(title="Spleen 3D segmentation with MONAI - ClassCat",
               css=".gradio-container {background:azure;}"
               ) as demo:
    # Selected example index; stays None until an "Example x" button is clicked.
    sample_index = gr.State(None)
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">Spleen 3D segmentation with MONAI</div>""")
    gr.HTML("""<h4 style="color:navy;">1. Select an example, which includes input images and label images, by clicking "Example x" button.</h4>""")
    with gr.Row():
        input_image = gr.Image(label="a piece of input image data", type="filepath")
        label_image = gr.Image(label="label image", type="filepath")
        output_image = gr.Image(label="predicted image", type="pil")
    with gr.Row():
        for i, loader in enumerate(_EXAMPLE_LOADERS):
            with gr.Column():
                ex_btn = gr.Button(f"Example {i + 1}")
                _apply_style(ex_btn, full_width=False, css="width:20px;")
                ex_image = gr.Image(value=f'thumbnails/val_image{i}.png',
                                    interactive=False, label=f'ex {i + 1}')
                _apply_style(ex_image, width=128, height=128)
            # Each button fills the state and the two preview images.
            ex_btn.click(fn=loader, outputs=[sample_index, input_image, label_image])
    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">2. Then, click "Infer" button to predict a segmentation image. It will take about 15 seconds (on cpu)</h4>""")
    send_btn = gr.Button("Infer")
    send_btn.click(fn=predict, inputs=[sample_index], outputs=[output_image])

# NOTE(review): request queuing (demo.queue()) was left disabled by the
# original author; consider enabling it if ~15 s CPU inferences hit request
# timeouts.
demo.launch(debug=True)
### EOF ### | |