import os
import copy
from PIL import Image
import matplotlib 
import numpy as np
import gradio as gr
from utils import load_mask, load_mask_edit
from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean
from pathlib import Path
import subprocess
from PIL import Image
from functools import partial
from main import run_main
LENGTH=512 # side length used as both height and width of the gr.Image canvas
TRANSPARENCY = 150 # alpha value given to putalpha() for the segmentation overlay

def add_mask(mask_np_list_updated, mask_label_list):
    """Append a blank mask, labelled "new", to the mask and label lists.

    The blank mask is created with ``np.zeros_like`` from the first mask in
    ``mask_np_list_updated``, so it shares that mask's shape and dtype.
    Both lists are extended in place and the very same list objects are
    returned as a ``(masks, labels)`` tuple.
    """
    reference_mask = mask_np_list_updated[0]
    mask_np_list_updated.append(np.zeros_like(reference_mask))
    mask_label_list.append("new")
    return mask_np_list_updated, mask_label_list

def create_segmentation(mask_np_list):
    """Render a list of masks as one color-coded image.

    Every mask is painted with its own color, taken from the ``viridis``
    colormap resampled to one entry per mask, and the painted layers are
    summed into a single picture.

    Args:
        mask_np_list: non-empty sequence of 2-D NumPy arrays of identical
            shape (each one is indexed as ``m[:, :, np.newaxis]``).

    Returns:
        PIL.Image.Image: the summed layers scaled by 255 and cast to
        ``uint8``.

    Raises:
        ValueError: if ``mask_np_list`` is ``None`` or empty.
    """
    if mask_np_list is None or len(mask_np_list) == 0:
        raise ValueError("create_segmentation needs at least one mask")

    # Import the submodules explicitly: a bare ``import matplotlib`` does not
    # guarantee that ``matplotlib.pyplot`` has been loaded.
    from matplotlib import colors as mcolors
    from matplotlib import pyplot as plt

    viridis = plt.get_cmap(name='viridis', lut=len(mask_np_list))
    segmentation = 0
    for i, m in enumerate(mask_np_list):
        color = np.asarray(mcolors.to_rgb(viridis(i)))
        # (H, W, 1) * (3,) broadcasts to an (H, W, 3) colored layer.
        segmentation = segmentation + m[:, :, np.newaxis] * color

    # Overlapping masks can push the sum above 1.0; clip so the uint8 cast
    # saturates instead of wrapping around.
    segmentation = np.clip(segmentation, 0.0, 1.0)
    return Image.fromarray(np.uint8(segmentation * 255))

def load_mask_ui(input_folder="example_tmp",load_edit = False):
    """Load the masks stored in ``input_folder`` as NumPy arrays.

    ``load_mask_edit`` is used when ``load_edit`` is truthy and ``load_mask``
    otherwise. Every mask returned by the loader is converted with
    ``.cpu().numpy()``.

    Returns:
        tuple: ``(mask_np_list, mask_label_list)``.
    """
    loader = load_mask_edit if load_edit else load_mask
    loaded_masks, mask_label_list = loader(input_folder)
    mask_np_list = [mask.cpu().numpy() for mask in loaded_masks]
    return mask_np_list, mask_label_list

def load_image_ui(load_edit, input_folder="example_tmp"):
    """Load the working image and its masks for display in the UI.

    Opens ``img_512.png`` inside ``input_folder``, converts it to RGB, loads
    the masks through ``load_mask_ui`` and renders them with
    ``create_segmentation``.

    Args:
        load_edit: forwarded to ``load_mask_ui`` as ``load_edit``.
        input_folder: folder that holds the image and the mask files.

    Returns:
        tuple: ``(image, segmentation, mask_np_list, mask_label_list, image)``.
        The image is returned twice because the UI wires it to two separate
        outputs. If anything fails, a message is printed and five ``None``
        values are returned instead.
    """
    try:
        image_path = Path(input_folder) / "img_512.png"
        # Image.open is lazy; convert() reads the pixel data, so the file
        # handle can be released as soon as the ``with`` block ends.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit = load_edit)
        segmentation = create_segmentation(mask_np_list)
    except Exception as err:
        # Best effort: the UI expects five outputs, so report the real cause
        # and hand back empty values rather than raising.
        print("Image folder invalid: could not load img_512.png and its masks from {} ({}: {})".format(
            input_folder, type(err).__name__, err))
        return None, None, None, None, None
    return image, segmentation, mask_np_list, mask_label_list, image

def run_edit_text(
        num_tokens,
        num_sampling_steps,
        strength,
        edge_thickness,
        tgt_prompt,
        tgt_idx,
        guidance_scale,
        input_folder="example_tmp"
    ):
    """Run text-based editing through ``main.py`` and return the first result.

    ``main.py`` is launched as a ``python`` subprocess with the ``--text`` and
    ``--load_trained`` flags plus the given settings; the resolution (512),
    seed (2024), sampler ("sd") and image count (2) are fixed. Once the
    subprocess returns, ``<input_folder>/text/out_text_0.png`` is opened and
    returned as a PIL image.
    """
    command = [
        "python",
        "main.py",
        "--text",
        f"--name={input_folder}",
        f"--dpm={'sd'}",
        f"--resolution={512}",
        "--load_trained",
        f"--num_tokens={num_tokens}",
        f"--seed={2024}",
        f"--guidance_scale={guidance_scale}",
        f"--num_sampling_step={num_sampling_steps}",
        f"--strength={strength}",
        f"--edge_thickness={edge_thickness}",
        f"--num_imgs={2}",
        f"--tgt_prompt={tgt_prompt}",
        f"--tgt_index={tgt_idx}",
    ]
    subprocess.run(command)

    result_path = os.path.join(input_folder, "text", "out_text_0.png")
    return Image.open(result_path)


def run_optimization(
        num_tokens,
        embedding_learning_rate, 
        max_emb_train_steps, 
        diffusion_model_learning_rate, 
        max_diffusion_train_steps,
        train_batch_size,
        gradient_accumulation_steps,
        input_folder = "example_tmp"
    ):
    """Run the optimization stage by launching ``main.py`` as a subprocess.

    Every setting is passed on the command line as ``--<name>=<value>``; the
    sampler ("sd") and resolution (512) are fixed. Nothing is returned.
    """
    # Insertion order is the order of the options on the command line.
    options = {
        "name": input_folder,
        "dpm": "sd",
        "resolution": 512,
        "num_tokens": num_tokens,
        "embedding_learning_rate": embedding_learning_rate,
        "diffusion_model_learning_rate": diffusion_model_learning_rate,
        "max_emb_train_steps": max_emb_train_steps,
        "max_diffusion_train_steps": max_diffusion_train_steps,
        "train_batch_size": train_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
    }
    command = ["python", "main.py"]
    command.extend("--{}={}".format(key, value) for key, value in options.items())
    subprocess.run(command)
    return None


def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
    """Overlay ``foreimg`` on ``backimg`` semi-transparently, inside a mask only.

    ``foreimg`` is given a uniform alpha of ``transparency`` and pasted over a
    copy of ``backimg``. Pixels where ``mask_np`` is non-zero take the blended
    result; all other pixels keep the untouched background.

    Args:
        backimg: PIL image used as background; it is not modified.
        foreimg: PIL image pasted on top at (0, 0); it is not modified.
        mask_np: 2-D array with the same height and width as the images.
        transparency: alpha value passed to ``putalpha`` for the overlay.

    Returns:
        PIL.Image.Image: the composited image.

    Raises:
        ValueError: propagated from NumPy when the mask cannot be broadcast
            against the image arrays.
    """
    backimg_solid_np = np.array(backimg)

    # Work on copies so the caller's images stay untouched.
    bimg = backimg.copy()
    fimg = foreimg.copy()
    fimg.putalpha(transparency)
    bimg.paste(fimg, (0, 0), fimg)
    bimg_np = np.array(bimg)

    # Boolean selection keeps the image dtype whatever the mask dtype is,
    # unlike ``a * m + (1 - m) * b`` which only works for uint8 0/1 masks.
    selected = np.asarray(mask_np)[:, :, np.newaxis] != 0
    new_img_np = np.where(selected, bimg_np, backimg_solid_np)
    return Image.fromarray(new_img_np)

def show_segmentation(image, segmentation, flag):
    """Toggle the segmentation overlay on the displayed image.

    Args:
        image: PIL image without overlay.
        segmentation: PIL image of the color-coded segmentation.
        flag: toggle state; exactly ``False`` means the overlay is currently
            hidden and should now be shown.

    Returns:
        tuple: ``(image_to_display, new_flag)``. With ``flag is False`` the
        image with the overlay applied everywhere and ``True`` are returned;
        otherwise the plain image and ``False``.
    """
    if flag is False:
        # PIL's ``size`` is (width, height) whereas a NumPy image array is
        # (height, width); build the mask in array order so non-square
        # images work as well.
        width, height = image.size
        mask_np = np.ones((height, width), dtype=np.uint8)
        image_edit = transparent_paste_with_mask(image, segmentation, mask_np, transparency = TRANSPARENCY)
        return image_edit, True
    return image, False

def edit_mask_add(canvas,  image, idx, mask_np_list):
    """Merge the region drawn on the canvas into the mask at position ``idx``.

    The first channel of ``canvas["mask"]`` is divided by 255 and cast to
    ``uint8``, then combined with the selected mask through ``mask_union``.
    The selected mask gets priority 1 and every other mask priority 0 before
    the list is passed to ``process_mask_to_follow_priority``.

    Returns:
        tuple: ``(updated mask list, image with the new segmentation
        overlaid everywhere)``.
    """
    target_mask = mask_np_list[idx]
    drawn_mask = np.uint8(canvas["mask"][:, :, 0] / 255.)

    merged_masks = [
        mask_union(target_mask, drawn_mask) if position == idx else current
        for position, current in enumerate(mask_np_list)
    ]

    priorities = [0] * len(merged_masks)
    priorities[idx] = 1
    merged_masks = process_mask_to_follow_priority(merged_masks, priorities)

    full_cover = np.ones([target_mask.shape[0], target_mask.shape[1]]).astype(np.uint8)
    overlay = create_segmentation(merged_masks)
    blended = transparent_paste_with_mask(image, overlay, full_cover, transparency = TRANSPARENCY)
    return merged_masks, blended

def slider_release(index, image,  mask_np_list_updated, mask_label_list):
    """Highlight the mask selected by the slider and return its label.

    Args:
        index: position of the mask to highlight.
        image: PIL image to draw on.
        mask_np_list_updated: list of masks.
        mask_label_list: list of labels, parallel to the masks.

    Returns:
        tuple: ``(image, label)``. For a valid index the image shows the
        segmentation overlay restricted to the selected mask and ``label`` is
        that mask's label. For an index outside either list the unchanged
        image and the text ``"out of range"`` are returned.
    """
    # Valid positions are 0 .. len - 1; ``index == len`` is already out of
    # range, and so is any index when no masks are loaded.
    available = min(len(mask_np_list_updated), len(mask_label_list))
    if not 0 <= index < available:
        return image, "out of range"

    mask_np = mask_np_list_updated[index]
    mask_label = mask_label_list[index]
    segmentation = create_segmentation(mask_np_list_updated)
    new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency = TRANSPARENCY)
    return new_image, mask_label

def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save the masks as the original masks of ``input_folder``.

    Each mask is written to ``mask<idx>_<label>.npy`` and a visualization of
    all masks is written to ``seg_current.png`` through
    ``visualize_mask_list_clean``.

    Args:
        mask_np_list_updated: list of mask arrays.
        mask_label_list: list of labels, parallel to the masks.
        input_folder: destination folder.

    Raises:
        ValueError: if the masks do not sum to exactly 1 at every position
            (this also covers an empty list). Nothing is written then.
    """
    # Explicit check instead of ``assert`` + debugger: asserts vanish under
    # ``python -O`` and a pdb prompt would block the UI worker.
    if not np.all(sum(mask_np_list_updated) == 1):
        print("please check mask")
        raise ValueError("please check mask: the masks must sum to 1 at every pixel")

    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "mask{}_{}.npy".format(midx, mask_label)), mask)
    savepath = os.path.join(input_folder, "seg_current.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)
    
def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save the masks as the edited masks of ``input_folder``.

    Each mask is written to ``maskEdited<idx>_<label>.npy`` and a
    visualization of all masks is written to ``seg_edited.png`` through
    ``visualize_mask_list_clean``.

    Args:
        mask_np_list_updated: list of mask arrays.
        mask_label_list: list of labels, parallel to the masks.
        input_folder: destination folder.

    Raises:
        ValueError: if the masks do not sum to exactly 1 at every position
            (this also covers an empty list). Nothing is written then.
    """
    # Explicit check instead of ``assert`` + debugger: asserts vanish under
    # ``python -O`` and a pdb prompt would block the UI worker.
    if not np.all(sum(mask_np_list_updated) == 1):
        print("please check mask")
        raise ValueError("please check mask: the masks must sum to 1 at every pixel")

    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "maskEdited{}_{}.npy".format(midx, mask_label)), mask)
    savepath = os.path.join(input_folder, "seg_edited.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)
  
  
import shutil
# Remove any existing ./example_tmp directory before the UI is built.
_existing_workdir = Path("./example_tmp")
if _existing_workdir.is_dir():
    shutil.rmtree("./example_tmp")

from segment import run_segmentation
# Build the three-tab Gradio UI (mask editing, optimization, text-based
# editing) and launch it. Event handlers are wired right after the component
# they belong to, so statement order also defines the page layout.
with gr.Blocks() as demo:
    # Per-session state shared between the event handlers below.
    image = gr.State()
    image_loaded = gr.State()       # image returned by load_image_ui
    segmentation    = gr.State()    # color-coded segmentation image

    mask_np_list    = gr.State([])
    mask_label_list = gr.State([])
    mask_np_list_updated = gr.State([])
    # Constant inputs used to choose between original and edited masks.
    true    = gr.State(True)
    false    = gr.State(False)

    with gr.Row():
        gr.Markdown("""# D-Edit""")

    with gr.Tab(label="1 Edit mask"):
        with gr.Row():
            with gr.Column():
                canvas = gr.Image(value = "./img.png", type="numpy",  label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
                
                segment_button  = gr.Button("1.1 Run segmentation")
                segment_button.click(run_segmentation, 
                        [canvas] , 
                        [] )

                text_button  = gr.Button("1.2 Load original masks")
                text_button.click(load_image_ui, 
                        [ false] , 
                        [image_loaded, segmentation,  mask_np_list, mask_label_list, canvas] )

                load_edit_button = gr.Button("1.2 Load edited masks")    
                load_edit_button.click(load_image_ui, 
                        [ true] , 
                        [image_loaded, segmentation,  mask_np_list, mask_label_list, canvas] )
                
                show_segment = gr.Checkbox(label = "Show Segmentation")
                # show_segmentation toggles this flag itself on every select
                # event; the checkbox value is not passed to it.
                flag = gr.State(False)
                show_segment.select(show_segmentation,
                                    [image_loaded, segmentation, flag], 
                                    [canvas, flag])
 
            # From here on ``mask_np_list_updated`` is the same State object
            # as ``mask_np_list``: edits are written straight into the loaded
            # mask list and the State created at the top is no longer used.
            mask_np_list_updated = mask_np_list
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
                slider =  gr.Slider(0, 20, step=1,  interactive=True)
                label = gr.Textbox()
                slider.release(slider_release, 
                        inputs = [slider, image_loaded,   mask_np_list_updated, mask_label_list], 
                        outputs= [canvas, label]
                    )
                add_button  = gr.Button("Add")
                add_button.click( edit_mask_add, 
                        [canvas, image_loaded, slider, mask_np_list_updated] , 
                        [mask_np_list_updated, canvas]
                    )

                save_button2  = gr.Button("Set and Save as edited masks")
                save_button2.click( save_as_edit_mask, 
                        [mask_np_list_updated,  mask_label_list] , 
                        [] )  
                
                save_button  = gr.Button("Set and Save as original masks")
                save_button.click( save_as_orig_mask, 
                        [mask_np_list_updated,  mask_label_list] , 
                        [] )  
                
                # load_mask_ui is called with its defaults here, i.e. it
                # reloads the non-edited masks from "example_tmp".
                back_button  = gr.Button("Back to current seg")
                back_button.click( load_mask_ui, 
                                [] , 
                                [ mask_np_list_updated,mask_label_list] )

                add_mask_button = gr.Button("Add new empty mask")    
                add_mask_button.click(add_mask, 
                        [mask_np_list_updated, mask_label_list] , 
                        [mask_np_list_updated, mask_label_list] )
                
    with gr.Tab(label="2 Optimization"):
        with gr.Row():
            
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
                num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
                embedding_learning_rate = gr.Textbox(value="0.0001", label="Embedding optimization: Learning rate", interactive= True )
                max_emb_train_steps =  gr.Number(value="200", label="embedding optimization: Training steps", interactive= True )
                
                diffusion_model_learning_rate = gr.Textbox(value="0.00005", label="UNet Optimization: Learning rate", interactive= True )
                max_diffusion_train_steps = gr.Number(value="200", label="UNet Optimization: Learning rate: Training steps", interactive= True )
                
                train_batch_size = gr.Number(value="5", label="Batch size", interactive= True )
                gradient_accumulation_steps=gr.Number(value="5", label="Gradient accumulation", interactive= True )
                
                add_button  = gr.Button("Run optimization")
                def run_optimization_wrapper (                        
                        num_tokens,
                        embedding_learning_rate , 
                        max_emb_train_steps , 
                        diffusion_model_learning_rate , 
                        max_diffusion_train_steps,
                        train_batch_size,
                        gradient_accumulation_steps
                ):
                    """Convert the UI values to numbers and run ``run_main``."""
                    run_main(
                        num_tokens=int(num_tokens),
                        embedding_learning_rate = float(embedding_learning_rate), 
                        max_emb_train_steps = int(max_emb_train_steps), 
                        diffusion_model_learning_rate= float(diffusion_model_learning_rate), 
                        max_diffusion_train_steps = int(max_diffusion_train_steps),
                        train_batch_size=int(train_batch_size),
                        gradient_accumulation_steps=int(gradient_accumulation_steps)
                    )
                    
                add_button.click(run_optimization_wrapper, 
                        inputs = [
                            num_tokens,
                            embedding_learning_rate , 
                            max_emb_train_steps , 
                            diffusion_model_learning_rate , 
                            max_diffusion_train_steps,
                            train_batch_size,
                            gradient_accumulation_steps
                        ], 
                        outputs = []
                )


    with gr.Tab(label="3 Editing"):
        with gr.Tab(label="3.1 Text-based editing"):
        
            with gr.Row():
                with gr.Column():
                    canvas_text_edit = gr.Image(value = None, type = "pil", label="Editing results", show_label=True)
                
                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")
                    
                    tgt_prompt =  gr.Textbox(value="White bag", label="Editing: Text prompt", interactive= True )
                    tgt_index = gr.Number(value="0", label="Editing: Object index", interactive= True )
                    guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive= True )
                    num_sampling_steps = gr.Number(value="50", label="Editing: Sampling steps", interactive= True )
                    edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive= True )
                    strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
                    
                    add_button  = gr.Button("Run Editing")
                    def run_edit_text_wrapper(
                            num_tokens,
                            guidance_scale,
                            num_sampling_steps,
                            strength,
                            edge_thickness,
                            tgt_prompt,
                            tgt_index
                    ):
                        """Run text-based editing with the values currently in the UI.

                        The settings arrive as event inputs, so they are read
                        when the button is clicked rather than frozen to the
                        components' initial values when the page is built.
                        """
                        return run_main(
                            load_trained=True,
                            text=True,
                            num_tokens = int(num_tokens),
                            guidance_scale = float(guidance_scale),
                            num_sampling_steps = int(num_sampling_steps),
                            strength = float(strength),
                            edge_thickness = int(edge_thickness),
                            num_imgs = 1,
                            tgt_prompt = tgt_prompt,
                            tgt_index = int(tgt_index)
                        )
                        
                    add_button.click(run_edit_text_wrapper, 
                        inputs = [
                            num_tokens,
                            guidance_scale,
                            num_sampling_steps,
                            strength,
                            edge_thickness,
                            tgt_prompt,
                            tgt_index
                        ], 
                        outputs = [canvas_text_edit]
                    )
                    
                    def load_pil_img():
                        """Open the first text-editing result from example_tmp."""
                        from PIL import Image
                        return Image.open("example_tmp/text/out_text_0.png")
                    
                    load_button  = gr.Button("Load results")
                    load_button.click(load_pil_img, 
                        inputs = [], 
                        outputs = [canvas_text_edit]
                    )




demo.queue().launch(share=True, debug=True)