import torch
import matplotlib.pyplot as plt

from monai.networks.nets import SegResNet
from monai.inferers import sliding_window_inference

from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
)

# SegResNet segmentation network: 4 input channels -> 3 output channels
# (presumably the four BraTS MRI modalities and three tumor sub-region masks,
# per the tutorial linked in the UI — TODO confirm).
model = SegResNet(
    in_channels=4,
    out_channels=3,
    init_filters=16,
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    dropout_prob=0.2,
)

# Restore trained weights; map_location forces all tensors onto the CPU so the
# checkpoint loads even on hosts without a GPU.
model.load_state_dict(torch.load("weights/model.pt", map_location="cpu"))

# define inference method
# Request mixed-precision (autocast) during inference. inference() honours this
# only when a CUDA device is available; otherwise it runs in full precision.
VAL_AMP = True


def inference(input):
    """Run sliding-window inference of the module-level ``model`` over ``input``.

    The volume is processed in windows of size (240, 240, 160) with 50 %
    overlap, one window per predictor call, and the window outputs are
    stitched back together by ``sliding_window_inference``.

    Args:
        input: batched input tensor passed straight to
            ``sliding_window_inference`` (assumes (batch, channels, H, W, D) —
            TODO confirm against the saved samples).

    Returns:
        The stitched model output returned by ``sliding_window_inference``.
    """

    def _compute(volume):
        return sliding_window_inference(
            inputs=volume,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    # torch.cuda.amp.autocast is CUDA-only: on a CPU-only host it emits a
    # warning and disables itself on every call, so only enter it when CUDA
    # actually exists.
    if VAL_AMP and torch.cuda.is_available():
        with torch.cuda.amp.autocast():
            return _compute(input)
    return _compute(input)


# Post-processing applied to the model output: a sigmoid (the raw output is
# presumably per-channel logits — TODO confirm) followed by binarization at 0.5,
# yielding one binary mask per channel.
post_trans = Compose(
    [
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ]
)

import gradio as gr

def load_sample1():
    """Callback for the "Example 1" button; see load_sample()."""
    return load_sample(index=1)

def load_sample2():
    """Callback for the "Example 2" button; see load_sample()."""
    return load_sample(index=2)

def load_sample3():
    """Callback for the "Example 3" button; see load_sample()."""
    return load_sample(index=3)

def load_sample4():
    """Callback for the "Example 4" button; see load_sample()."""
    return load_sample(index=4)

def load_sample5():
    """Callback for the "Example 5" button; see load_sample()."""
    return load_sample(index=5)

def load_sample6():
    """Callback for the "Example 6" button; see load_sample()."""
    return load_sample(index=6)

def load_sample7():
    """Callback for the "Example 7" button; see load_sample()."""
    return load_sample(index=7)

def load_sample8():
    """Callback for the "Example 8" button; see load_sample()."""
    return load_sample(index=8)

import torchvision

def load_sample(index):
    """Return the State value and thumbnail paths for example number ``index``.

    ``index`` is 1-based (matching the "Example N" buttons), while the thumbnail
    files on disk are numbered from 0, hence the ``index - 1``.

    Returns:
        A list of eight items: ``index`` itself, then the four input-image
        thumbnail paths (channels 0-3), then the three label thumbnail paths
        (channels 0-2).
    """
    file_no = index - 1
    image_paths = [f"thumbnails/image{file_no}_{ch}.png" for ch in range(4)]
    label_paths = [f"thumbnails_label/label{file_no}_{ch}.png" for ch in range(3)]
    return [index, *image_paths, *label_paths]


def predict(sample_index):
    """Segment validation sample ``sample_index`` and return one slice per channel.

    Args:
        sample_index: 1-based example number held in the ``gr.State`` that the
            "Example N" buttons set; ``samples/val{sample_index - 1}.pt`` is
            loaded from disk.

    Returns:
        A list of three PIL images, one per output channel of the binarized
        prediction, each taken at index 70 of the last axis.

    Raises:
        gr.Error: if no example has been selected yet (the State still holds
            its initial, non-integer value).
    """
    if not isinstance(sample_index, int):
        # The State starts out empty; without this guard the path f-string
        # below fails with an opaque TypeError when "Infer" is clicked first.
        raise gr.Error('Please select an example first by clicking an "Example x" button.')

    # NOTE(review): torch.load unpickles the file. That is acceptable here only
    # because the path comes from a fixed template, not from user uploads.
    # map_location keeps loading working on CPU-only hosts, where the model lives.
    sample = torch.load(f"samples/val{sample_index - 1}.pt", map_location="cpu")

    model.eval()
    with torch.no_grad():
        # Add a batch dimension, run sliding-window inference, then binarize
        # the single batch item.
        val_input = sample["image"].unsqueeze(0)
        val_output = post_trans(inference(val_input)[0])

    # val_output is indexed as (channel, :, :, depth) — assumes a
    # (C, H, W, D) layout, TODO confirm; slice 70 of the last axis is shown.
    slice_index = 70
    return [
        torchvision.transforms.functional.to_pil_image(val_output[ch, :, :, slice_index])
        for ch in range(3)
    ]

with gr.Blocks(title="Brain tumor 3D segmentation with MONAI - ClassCat",
               css=".gradio-container {background:azure;}") as demo:
    # Holds the 1-based example number set by an "Example N" button and read
    # by predict(); starts empty until an example is selected.
    sample_index = gr.State([])

    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">Brain tumor 3D segmentation with MONAI</div>""")

    gr.HTML("""<h4 style="color:navy;">1. Select an example, which includes input images and label images, by clicking "Example x" button.</h4>""")

    with gr.Row():
        input_image0 = gr.Image(label="image channel 0", type="filepath", shape=(240, 240))
        input_image1 = gr.Image(label="image channel 1", type="filepath", shape=(240, 240))
        input_image2 = gr.Image(label="image channel 2", type="filepath", shape=(240, 240))
        input_image3 = gr.Image(label="image channel 3", type="filepath", shape=(240, 240))

    with gr.Row():
        label_image0 = gr.Image(label="label channel 0", type="filepath", shape=(240, 240))
        label_image1 = gr.Image(label="label channel 1", type="filepath", shape=(240, 240))
        label_image2 = gr.Image(label="label channel 2", type="filepath", shape=(240, 240))

    # Components filled by every example callback, in the same order as the
    # list returned by load_sample().
    example_outputs = [sample_index, input_image0, input_image1, input_image2, input_image3,
                       label_image0, label_image1, label_image2]
    example_loaders = [load_sample1, load_sample2, load_sample3, load_sample4,
                       load_sample5, load_sample6, load_sample7, load_sample8]

    with gr.Row():
        # One "Example N" button per loader. The named load_sampleN functions
        # (rather than lambdas) are kept so each event keeps its function name.
        for example_number, example_loader in enumerate(example_loaders, start=1):
            example_btn = gr.Button(f"Example {example_number}")
            example_btn.click(fn=example_loader, inputs=None, outputs=example_outputs)

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">2. Then, click "Infer" button to predict segmentation images. It will take about 30 seconds (on cpu)</h4>""")

    with gr.Row():
        output_image0 = gr.Image(label="output channel 0", type="pil")
        output_image1 = gr.Image(label="output channel 1", type="pil")
        output_image2 = gr.Image(label="output channel 2", type="pil")

    send_btn = gr.Button("Infer")
    send_btn.click(fn=predict, inputs=[sample_index], outputs=[output_image0, output_image1, output_image2])

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">Reference</h4>""")
    # The list must live in a single HTML component: each gr.HTML renders its
    # own fragment, so splitting <ul> / <li> / </ul> across components leaves
    # the <li> outside any list.
    gr.HTML("""<ul><li><a href="https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb" target="_blank">Brain tumor 3D segmentation with MONAI</a></li></ul>""")


# Start the server only when this file is run as a script (e.g. `python app.py`),
# so importing the module to reuse `demo` or its helpers does not launch it.
if __name__ == "__main__":
    # demo.queue()  # optional request queuing; left disabled as in the original
    demo.launch(debug=True)



### EOF  ###