"""Gradio demo comparing four UNet variants for medical image segmentation.

Each tab shows an input image, the segmentation produced by one model
version ("v0".."v3", forwarded to ``predict_unet.predict_model``), a
Run/Clear button pair and a gallery of bundled example images.
"""
import gradio as gr
from gradio_client import Client
import numpy as np
#import torch
import requests
from PIL import Image
#from torchvision import transforms
from predict_unet import predict_model

# NOTE(review): Client, requests and Image are not used by the code in this
# file; kept in case another part of the project relies on the import side
# effects — confirm before removing.

# Page heading shown at the top of the app.
title = "Medical Image Segmentation with UNet"

# Bundled example images; each inner list holds the value for the single
# input component of a tab.
examples = [["examples/50494616.jpg"], ["examples/50494676.jpg"],
            ["examples/56399783.jpg"], ["examples/56399789.jpg"],
            ["examples/56399831.jpg"], ["examples/56399959.jpg"],
            ["examples/56400014.jpg"], ["examples/56400119.jpg"],
            ["examples/56481903.jpg"], ["examples/70749195.jpg"]]


def _run_unet(image, version):
    """Segment *image* with the UNet variant *version*.

    Args:
        image: Value delivered by a ``gr.Image(type='numpy')`` component, or
            ``None`` when the user clicked Run without providing an image.
        version: Model tag forwarded verbatim to ``predict_model``.

    Returns:
        ``predict_model``'s output with every value clipped into [0, 1].
        This clips out-of-range values; it does not rescale them. The [0, 1]
        range is assumed to be what the output ``gr.Image`` expects for float
        data — confirm against predict_model's output dtype.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
    """
    if image is None:
        raise gr.Error("Please upload an image or pick an example first.")
    output = predict_model(image, version)
    return np.clip(output, 0, 1)


# The public wrappers below keep their original names and the `input`
# parameter (which shadows the builtin) so existing callers are unaffected.
# They stay real functions rather than functools.partial objects because
# Gradio derives each event's API endpoint name from ``fn.__name__``.
def run_unetv0(input):
    """Run the regular UNet (v0)."""
    return _run_unet(input, "v0")


def run_unetv1(input):
    """Run UNet3+ (v1)."""
    return _run_unet(input, "v1")


def run_unetv2(input):
    """Run UNet3+ with deep supervision (v2)."""
    return _run_unet(input, "v2")


def run_unetv3(input):
    """Run UNet3+ with deep supervision and cgm (v3)."""
    return _run_unet(input, "v3")


def clear():
    """Reset a tab: empty both the input and the segmented image."""
    return None, None


# Components are created up front and placed into the layout via .render().
input_img_v0 = gr.Image(label="Input", type='numpy')
segm_img_v0 = gr.Image(label="Segmented Image")
input_img_v1 = gr.Image(label="Input", type='numpy')
segm_img_v1 = gr.Image(label="Segmented Image")
input_img_v2 = gr.Image(label="Input", type='numpy')
segm_img_v2 = gr.Image(label="Segmented Image")
input_img_v3 = gr.Image(label="Input", type='numpy')
segm_img_v3 = gr.Image(label="Segmented Image")


def _build_tab(tab_label, input_img, segm_img, run_fn):
    """Lay out one model tab and wire its events.

    Must be called inside an active ``gr.Blocks`` context, because Gradio
    only allows rendering components and attaching listeners there.

    Args:
        tab_label: Text shown on the tab header.
        input_img: Pre-created input ``gr.Image`` to render in this tab.
        segm_img: Pre-created output ``gr.Image`` to render in this tab.
        run_fn: Callable mapping the input image to the segmented image.

    Returns:
        ``(segment_btn, clear_btn)`` for callers that need the buttons.
    """
    with gr.Tab(tab_label):
        # display input image and segmented image
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                input_img.render()
            with gr.Column(scale=1):
                segm_img.render()
        # submit and clear
        with gr.Row():
            with gr.Column():
                segment_btn = gr.Button("Run Segmentation", variant='primary')
                clear_btn = gr.Button("Clear", variant="secondary")
                # load examples
                gr.Markdown("Try some of the examples below")
                gr.Examples(examples=examples,
                            inputs=[input_img],
                            outputs=segm_img,
                            fn=run_fn,
                            cache_examples=False,
                            examples_per_page=5)
            # just a placeholder for second column
            with gr.Column():
                gr.Markdown("")
        segment_btn.click(run_fn, inputs=[input_img], outputs=segm_img)
        clear_btn.click(clear, outputs=[input_img, segm_img])
    return segment_btn, clear_btn


with gr.Blocks(title='UNet examples') as demo:
    gr.Markdown(f"# {title}")
    # v0: regular UNet
    segment_btn_v0, clear_btn_v0 = _build_tab(
        "Regular UNet (v0)", input_img_v0, segm_img_v0, run_unetv0)
    # v1: UNet3+
    segment_btn_v1, clear_btn_v1 = _build_tab(
        "UNet3+ (v1)", input_img_v1, segm_img_v1, run_unetv1)
    # v2: UNet3+ with deep supervision
    segment_btn_v2, clear_btn_v2 = _build_tab(
        "UNet3+(v2) with deep supervision", input_img_v2, segm_img_v2,
        run_unetv2)
    # v3: UNet3+ with deep supervision and cgm
    segment_btn_v3, clear_btn_v3 = _build_tab(
        "UNet3+(v3) with deep supervision and cgm", input_img_v3,
        segm_img_v3, run_unetv3)

demo.queue()

if __name__ == "__main__":
    demo.launch()