import os
import copy
#import spaces
from main import run_main
from PIL import Image
import matplotlib 
import numpy as np
import gradio as gr
from utils import load_mask, load_mask_edit
from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean
from pathlib import Path
from PIL import Image
from functools import partial
import time

# Side length, in pixels, of the square image widgets used to display/edit images.
LENGTH = 512
# Alpha value (0-255) given to the colored mask overlay before it is pasted on the image.
TRANSPARENCY = 150


def add_mask(mask_np_list_updated, mask_label_list):
    """Append an all-zero mask labelled ``"new"`` to the given lists.

    Both lists are extended in place and also returned. The new mask copies the
    shape and dtype of the first existing mask, so ``mask_np_list_updated`` must
    not be empty.
    """
    template = mask_np_list_updated[0]
    mask_np_list_updated.extend([np.zeros_like(template)])
    mask_label_list.extend(["new"])
    return (mask_np_list_updated, mask_label_list)

def create_segmentation(mask_np_list):
    """Render a list of masks as a single color-coded RGB PIL image.

    Mask ``i`` is painted with the ``i``-th of ``len(mask_np_list)`` colors
    sampled from matplotlib's ``viridis`` colormap, scaled by the mask value.

    Args:
        mask_np_list: non-empty sequence of equally shaped 2-D arrays; presumably
            0/1 masks from run_segmentation (TODO confirm dtype).

    Returns:
        PIL.Image.Image with the masks' height and width and 3 channels.

    Raises:
        ValueError: if ``mask_np_list`` is empty.
    """
    # `import matplotlib` alone does not load the pyplot submodule; import it
    # explicitly instead of relying on some other module having done so.
    from matplotlib import pyplot as plt

    if len(mask_np_list) == 0:
        raise ValueError("create_segmentation needs at least one mask")

    viridis = plt.get_cmap(name='viridis', lut=len(mask_np_list))
    segmentation = np.zeros(np.shape(mask_np_list[0]) + (3,), dtype=np.float64)
    for i, m in enumerate(mask_np_list):
        color = np.asarray(matplotlib.colors.to_rgb(viridis(i)))
        # (H, W, 1) * (3,) broadcasts to an (H, W, 3) color layer for this mask.
        segmentation += np.asarray(m, dtype=np.float64)[:, :, np.newaxis] * color
    # Overlapping masks can push a channel above 1.0; clip so the uint8 cast cannot wrap.
    segmentation = np.clip(segmentation, 0.0, 1.0)
    return Image.fromarray(np.uint8(segmentation * 255))


def _segmentation_failure_outputs(message):
    """Warn the user and return the 8 UI outputs of a failed segmentation run."""
    gr.Warning(message)
    hidden_slider = gr.Slider(value = 0, minimum=0, maximum=1, step=1, visible=False)
    return None, None, None, None, None, hidden_slider, hidden_slider, message


#@spaces.GPU
def run_segmentation_wrapper(image):
    """Segment the uploaded image and prepare the mask-selection UI.

    Args:
        image: array from the input ``gr.Image`` (type="numpy"), or ``None`` when
            nothing has been uploaded.

    Returns:
        8-tuple ``(image, segmentation, mask_np_list, mask_label_list, canvas_image,
        slider, slider2, message)``. On success, ``image``/``canvas_image`` are the
        image returned by run_segmentation, and both sliders become visible and range
        over the mask ids. On failure, the first five entries are ``None``, the sliders
        are hidden, and ``message`` explains what went wrong.
    """
    if image is None:
        return _segmentation_failure_outputs('Please upload an image before proceeding.')

    try:
        print(image.shape)
        image, mask_np_list, mask_label_list = run_segmentation(image)
        segmentation = create_segmentation(mask_np_list)
    except Exception as e:
        # A real segmentation failure: report it instead of asking for an upload.
        print(e)
        return _segmentation_failure_outputs(f'Segmentation failed: {e}')

    num_masks = len(mask_np_list)
    print("number of masks:", num_masks)
    mask_slider = gr.Slider(value = 0, minimum=0, maximum=num_masks - 1, step=1, visible=True)
    message = 'Segmentation finish. Select mask id and move to the next step.'
    gr.Info(message)
    return image, segmentation, mask_np_list, mask_label_list, image, mask_slider, mask_slider, message
        

def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
    """Blend ``foreimg`` over ``backimg`` at a fixed alpha, only where ``mask_np`` is set.

    Args:
        backimg: background PIL image (left unmodified).
        foreimg: foreground PIL image of the same size (left unmodified; a copy
            receives the alpha channel).
        mask_np: 2-D array of shape (height, width). Pixels with value 1 show the
            blend and pixels with value 0 keep the background. Boolean, integer and
            float masks are accepted.
        transparency: integer alpha (0-255) applied to the whole foreground.

    Returns:
        A new PIL image.
    """
    background = np.array(backimg)
    blended_img = backimg.copy()
    overlay = foreimg.copy()
    overlay.putalpha(transparency)
    # The overlay's own alpha band acts as the paste mask.
    blended_img.paste(overlay, (0, 0), overlay)
    blended = np.array(blended_img)

    # Work in float: numpy rejects `1 - bool_array`, and float math cannot wrap around
    # like uint8 would. For 0/1 masks the result is exactly the same as before.
    weight = np.asarray(mask_np, dtype=np.float64)[:, :, np.newaxis]
    combined = blended * weight + (1 - weight) * background
    return Image.fromarray(np.uint8(combined))

def show_segmentation(image, segmentation, flag):
    """Toggle the colored segmentation overlay on ``image``.

    Args:
        image: PIL image shown without the overlay.
        segmentation: PIL image with the colored masks (e.g. from create_segmentation).
        flag: ``False`` means the overlay is currently hidden.

    Returns:
        ``(overlay_image, True)`` if the overlay was hidden, otherwise ``(image, False)``.
    """
    if flag is not False:
        return image, False

    # PIL reports size as (width, height), while numpy arrays are (height, width);
    # the mask must match the image array, or non-square images fail to broadcast.
    width, height = image.size
    full_mask = np.ones((height, width), dtype=np.uint8)
    image_edit = transparent_paste_with_mask(image, segmentation, full_mask, transparency = TRANSPARENCY)
    return image_edit, True

def edit_mask_add(canvas,  image, idx, mask_np_list):
    """Union the strokes drawn on ``canvas`` into mask ``idx`` and re-render the overlay.

    After the union, ``process_mask_to_follow_priority`` is called with priority 1
    for mask ``idx`` and 0 for every other mask (presumably so the edited mask wins
    where masks now overlap; see utils_mask).

    Args:
        canvas: dict whose ``"mask"`` entry is a 3-D array; only channel 0 is read.
        image: PIL image used as the overlay background.
        idx: index of the mask being edited.
        mask_np_list: current list of 2-D masks (the list itself is not modified).

    Returns:
        ``(updated_mask_list, overlay_image)``.
    """
    target_mask = mask_np_list[idx]
    # Divide by 255 and truncate to uint8: assuming 0-255 strokes, only fully
    # painted (255) pixels become 1.
    drawn = np.uint8(canvas["mask"][:, :, 0] / 255.)
    updated = [
        mask_union(target_mask, drawn) if position == idx else current
        for position, current in enumerate(mask_np_list)
    ]

    priorities = [0] * len(updated)
    priorities[idx] = 1
    updated = process_mask_to_follow_priority(updated, priorities)

    # An all-ones mask applies the colored overlay to the entire image.
    everywhere = np.ones((target_mask.shape[0], target_mask.shape[1]), dtype=np.uint8)
    colored = create_segmentation(updated)
    overlay = transparent_paste_with_mask(image, colored, everywhere, transparency = TRANSPARENCY)
    return updated, overlay

def slider_release(index, image,  mask_np_list_updated, mask_label_list):
    """Preview mask ``index`` over ``image`` and derive its label and a default prompt.

    Args:
        index: mask id from the slider.
        image: background image for the preview; presumably a PIL image
            (transparent_paste_with_mask calls ``.copy()``/``.paste()`` on it).
        mask_np_list_updated: list of 2-D masks.
        mask_label_list: labels parallel to the masks, presumably of the form
            ``"<class>-<suffix>"`` such as ``"wall-other-merged-3"``.

    Returns:
        ``(preview_image, label, prompt)``. For an out-of-range index, returns
        ``(image, "out of range", "")``.
    """
    index = int(index)
    if not 0 <= index < len(mask_np_list_updated):
        return image, "out of range", ""

    mask_np = mask_np_list_updated[index]
    raw_label = mask_label_list[index]
    # Drop the trailing "-<suffix>". Labels with no '-' at all (e.g. "new" from
    # add_mask) are kept whole; slicing at rfind('-') == -1 would cut their last char.
    head, sep, _ = raw_label.rpartition('-')
    mask_label = head if sep else raw_label

    # Hand-picked prompts for some labels; any other label is used verbatim.
    prompt_overrides = {
        'handbag': "white handbag",
        'person': "little boy",
        'wall-other-merged': "white wall",
        'table-merged': "table",
    }
    mask_prompt = prompt_overrides.get(mask_label, mask_label)

    segmentation = create_segmentation(mask_np_list_updated)
    new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency = TRANSPARENCY)
    gr.Info('Edit '+ mask_label)
    return new_image, mask_label, mask_prompt
def image_change():
    """Return a hidden mask-id slider reset to 0 (used when a new image is uploaded)."""
    hidden_settings = dict(value=0, minimum=0, maximum=1, step=1, visible=False)
    return gr.Slider(**hidden_settings)

def _save_mask_list(mask_np_list, mask_label_list, input_folder, file_prefix, seg_filename):
    """Save each mask as ``<file_prefix><idx>_<label>.npy`` plus a preview image.

    The masks are expected to partition the image, i.e. sum to exactly 1 at every
    pixel. A violation is reported on stdout, but the masks are still saved. Stopping
    in a debugger or raising here would stall or crash the UI handler.
    """
    try:
        is_partition = bool(np.all(sum(mask_np_list) == 1))
    except ValueError:  # masks of different shapes cannot be summed
        is_partition = False
    if not is_partition:
        print("please check mask")

    os.makedirs(input_folder, exist_ok=True)
    for midx, (mask, mask_label) in enumerate(zip(mask_np_list, mask_label_list)):
        np.save(os.path.join(input_folder, f"{file_prefix}{midx}_{mask_label}.npy"), mask)
    visualize_mask_list_clean(mask_np_list, os.path.join(input_folder, seg_filename))


def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save masks as ``mask<idx>_<label>.npy`` and render ``seg_current.png`` in ``input_folder``."""
    _save_mask_list(mask_np_list_updated, mask_label_list, input_folder, "mask", "seg_current.png")


def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save masks as ``maskEdited<idx>_<label>.npy`` and render ``seg_edited.png`` in ``input_folder``."""
    _save_mask_list(mask_np_list_updated, mask_label_list, input_folder, "maskEdited", "seg_edited.png")
  


def button_clickable(is_clickable):
    """Return a Button update that toggles whether the button accepts clicks."""
    updated_button = gr.Button(interactive=is_clickable)
    return updated_button



def load_pil_img(path="example_tmp/text/out_text_0.png"):
    """Load the editing result image from ``path``.

    The default is the file the editing step reads its result from (presumably
    written by ``run_main(..., text=True)``). ``Image.open`` is lazy and keeps the
    file open, so the pixels are copied into memory and the file is closed. The
    returned image stays valid when the next run overwrites or deletes the file.
    """
    with Image.open(path) as img:
        return img.copy()

def change_image(img):
    """No-op callback for ``gr.Examples``: selecting an example only fills the input."""
    del img  # intentionally unused
    return None
                

import shutil

# Start each launch with an empty scratch directory: remove ./example_tmp
# (where the mask files and editing results of earlier runs are written).
_scratch_dir = Path("./example_tmp")
if _scratch_dir.is_dir():
    shutil.rmtree(_scratch_dir)




from segment import run_segmentation

# Gradio UI. Components are laid out in the order they are created, so the
# statement order below defines the page layout.
with gr.Blocks() as demo:
    # NOTE(review): this `image` State is not referenced by any event wiring in this file.
    image = gr.State()
    # Image returned by run_segmentation_wrapper; input to slider_release and the edit run.
    image_loaded = gr.State()
    # Colored segmentation image returned by run_segmentation_wrapper.
    segmentation    = gr.State()

    # Masks and their labels returned by run_segmentation_wrapper.
    mask_np_list    = gr.State([])
    mask_label_list = gr.State([])
    # NOTE(review): rebound to `mask_np_list` further down, so this State is never used.
    mask_np_list_updated = gr.State([])
    # NOTE(review): `true`, `false` and `block_flag` are not referenced by any visible handler.
    true    = gr.State(True)
    false    = gr.State(False)
    block_flag = gr.State(0)
    # NOTE(review): rebound to the `num_tokens` Number component in the
    # optimization "Advanced settings" accordion below; this State is never used.
    num_tokens_global = gr.State(5)
    with gr.Row():
        gr.Markdown("""# D-Edit""")


    with gr.Row():
        # Left column: input image and examples.
        with gr.Column():
            canvas = gr.Image(value = None, type="numpy",  label="Show Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
            example_inps = [['./img.png'],['./img2.png'],['./img3.png'],['./img4.png']]
            # change_image is a no-op and outputs=[]: picking an example only fills `canvas`.
            gr.Examples(examples=example_inps, inputs=[canvas],
                        label='examples', cache_examples='lazy', outputs=[],
                        fn=change_image)
            gr.Markdown(f"Each image must first undergo segmentation. Afterwards, you can modify the \n mask ID and the prompt for image editing, then proceed with the editing process. \n The link of D-edit paper: [https://arxiv.org/abs/2403.04880v2](https://arxiv.org/abs/2403.04880v2), [https://huggingface.co/papers/2403.04880](https://huggingface.co/papers/2403.04880)")

        # Right column: segmentation, mask selection and optimization settings.
        with gr.Column():
            result_info0 = gr.Text(label="Response")
            segment_button  = gr.Button("Step 1. Run segmentation")
            # NOTE(review): show_segmentation's toggle state; not wired to any event here.
            flag = gr.State(False)

            # Alias, not a copy: handlers given `mask_np_list_updated` read the very same
            # State object as `mask_np_list`.
            mask_np_list_updated = mask_np_list
            gr.Markdown("""<p style="text-align: center; font-size: 20px">Edit Mask (Do not change it during the editing process)</p>""")
            # Mask-id selector; run_segmentation_wrapper makes it visible with the right range.
            slider =  gr.Slider(0, 20, step=1, label = 'mask id',  visible=False)
            # Label of the selected mask, filled by slider_release.
            label = gr.Text(label='label')




            # Status message of the "Step 2" run, set by run_total_wrapper.
            result_info = gr.Text(label="Response")

            # NOTE(review): passed to run_optimization_wrapper, which does not use it.
            opt_flag = gr.State(0)
            gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings</p>""")
            with gr.Accordion(label="Advanced settings", open=False):
                num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
                # Rebinds the name to the Number component above. It is listed in
                # add_button.click's inputs, so handlers receive its current value; the
                # component's `.value` attribute only holds the constructor default.
                num_tokens_global = num_tokens
                embedding_learning_rate = gr.Textbox(value="0.00025", label="Embedding optimization: Learning rate", interactive= True )
                max_emb_train_steps =  gr.Number(value="6", label="embedding optimization: Training steps", interactive= True )

                diffusion_model_learning_rate = gr.Textbox(value="0.0002", label="UNet Optimization: Learning rate", interactive= True )
                # NOTE(review): this label reads "Learning rate: Training steps" but the field is
                # the number of UNet training steps; looks like a copy-paste slip.
                max_diffusion_train_steps = gr.Number(value="28", label="UNet Optimization: Learning rate: Training steps", interactive= True )

                train_batch_size = gr.Number(value="20", label="Batch size", interactive= True )
                gradient_accumulation_steps=gr.Number(value="2", label="Gradient accumulation", interactive= True )
            
            def run_optimization_wrapper (
                    mask_np_list,
                    mask_label_list,
                    image,
                    opt_flag,
                    num_tokens,
                    embedding_learning_rate ,
                    max_emb_train_steps ,
                    diffusion_model_learning_rate ,
                    max_diffusion_train_steps,
                    train_batch_size,
                    gradient_accumulation_steps,
            ):
                """Optimize the object token embeddings and the UNet via ``run_main``.

                The UI values are converted with ``int()``/``float()`` before the call.
                ``opt_flag`` is accepted for interface compatibility but is not used.

                Returns:
                    A status message for the "Response" box. Failures are reported by
                    returning a message that starts with ``"Error:"``, not by raising.
                """
                try:
                    run_main(
                        mask_np_list=mask_np_list,
                        mask_label_list=mask_label_list,
                        image_gt=np.array(image),
                        num_tokens=int(num_tokens),
                        embedding_learning_rate = float(embedding_learning_rate),
                        max_emb_train_steps = int(max_emb_train_steps),
                        diffusion_model_learning_rate= float(diffusion_model_learning_rate),
                        max_diffusion_train_steps = int(max_diffusion_train_steps),
                        train_batch_size=int(train_batch_size),
                        gradient_accumulation_steps=int(gradient_accumulation_steps)
                    )
                except Exception as e:
                    print(e)
                    # gr.Error(...) without `raise` only builds an exception object, so
                    # the user saw nothing. gr.Warning shows the actual error right away
                    # and still lets us return a status string.
                    gr.Warning(f"Optimization failed: {e}")
                    return "Error: use a smaller batch size or try later."
                gr.Info("Optimization Finished! Move to the next step.")
                return "Optimization finished! Move to the next step."



    # Second row: editing result (left) and editing settings (right).
    # NOTE(review): `if 1:` is always true; it only adds an indentation level.
    if 1:
        with gr.Row():
            with gr.Column():
                # Shows the PIL image returned by run_total_wrapper.
                canvas_text_edit = gr.Image(value = None, type = "pil", label="Editing results", show_label=True,visible = True)
                # canvas_text_edit = gr.Gallery(label = "Edited results")

            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting</p>""")
                # Prefilled by slider_release with a suggested prompt for the selected mask.
                tgt_prompt =  gr.Textbox(value="text prompt", label="Editing: Text prompt", interactive= True )
                with gr.Accordion(label="Advanced settings", open=False):
                    # Mirror of `slider` (updated by slider.change); its value is passed to the
                    # edit step as tgt_index.
                    slider2 = gr.Slider(0, 20, step=1, label = 'mask id',  visible=False)
                    guidance_scale = gr.Textbox(value="5", label="Editing: CFG guidance scale", interactive= True )
                    num_sampling_steps = gr.Number(value="20", label="Editing: Sampling steps", interactive= True )
                    edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive= True )
                    strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )

                # Runs optimization followed by editing (see run_total_wrapper).
                add_button  = gr.Button("Step 2. Run Editing",interactive = True)
                def run_edit_text_wrapper(
                        mask_np_list,
                        mask_label_list,
                        image,
                        num_tokens,
                        guidance_scale,
                        num_sampling_steps ,
                        strength ,
                        edge_thickness,
                        tgt_prompt ,
                        tgt_index
                ):
                    """Edit the region of mask ``tgt_index`` according to ``tgt_prompt``.

                    Calls ``run_main`` with ``load_trained=True`` and ``text=True``. It then
                    returns the image that load_pil_img reads from disk, which is presumably
                    the output written by that run.
                    """
                    run_main(
                        mask_np_list=mask_np_list,
                        mask_label_list=mask_label_list,
                        image_gt=np.array(image),
                        load_trained=True,
                        text=True,
                        # Use the value passed in by the caller (the live "num tokens"
                        # input, the same value the optimization step received).
                        # `num_tokens_global.value` is only the component's default.
                        # It ignored user changes and could mismatch the trained token count.
                        num_tokens = int(num_tokens),
                        guidance_scale = float(guidance_scale),
                        num_sampling_steps = int(num_sampling_steps),
                        strength = float(strength),
                        edge_thickness = int(edge_thickness),
                        num_imgs = 1,
                        tgt_prompt = tgt_prompt,
                        tgt_index = int(tgt_index)
                    )
                    gr.Info('Image editing completed.')
                    return load_pil_img()



            def run_total_wrapper(
                    mask_np_list,
                    mask_label_list,
                    image_loaded,
                    opt_flag,
                    num_tokens,
                    embedding_learning_rate,
                    max_emb_train_steps,
                    diffusion_model_learning_rate,
                    max_diffusion_train_steps,
                    train_batch_size,
                    gradient_accumulation_steps,
                    num_tokens_global,
                    guidance_scale,
                    num_sampling_steps,
                    strength,
                    edge_thickness,
                    tgt_prompt,
                    slider2,
            ):
                """Handler for "Step 2. Run Editing": optimize, then edit.

                Returns:
                    ``(status_message, edited_image)``. If optimization reports an error
                    (its message starts with ``"Error"``), editing is skipped and
                    ``edited_image`` is ``None``, which clears the result view.
                """
                opt_message = run_optimization_wrapper(
                    mask_np_list, mask_label_list, image_loaded, opt_flag, num_tokens,
                    embedding_learning_rate, max_emb_train_steps,
                    diffusion_model_learning_rate, max_diffusion_train_steps,
                    train_batch_size, gradient_accumulation_steps,
                )
                if isinstance(opt_message, str) and opt_message.startswith("Error"):
                    # Editing runs with load_trained=True; after a failed optimization there
                    # is presumably nothing valid to load, so stop here.
                    return opt_message, None
                edited_image = run_edit_text_wrapper(
                    mask_np_list, mask_label_list, image_loaded, num_tokens_global,
                    guidance_scale, num_sampling_steps, strength, edge_thickness,
                    tgt_prompt, slider2,
                )
                return opt_message, edited_image


            # "Step 2": optimization followed by editing in a single handler.
            # The order of `inputs` must match run_total_wrapper's parameter order.
            add_button.click(
                run_total_wrapper, 
                inputs=[
                    mask_np_list,
                    mask_label_list,
                    image_loaded,
                    opt_flag,
                    num_tokens,
                    embedding_learning_rate,
                    max_emb_train_steps,
                    diffusion_model_learning_rate,
                    max_diffusion_train_steps,
                    train_batch_size,
                    gradient_accumulation_steps,
                    num_tokens_global,  # same component as `num_tokens` (rebound above)
                    guidance_scale,
                    num_sampling_steps,
                    strength,
                    edge_thickness,
                    tgt_prompt,
                    slider2
                ],
                outputs=[result_info, canvas_text_edit],
            )




        # A new upload hides the mask-id slider and resets it to 0.
        canvas.upload(image_change, inputs=[], outputs=[slider])       

        # Releasing the slider previews the chosen mask and suggests a label/prompt.
        # `mask_np_list_updated` is an alias of `mask_np_list` at this point.
        # NOTE(review): the preview is written into `canvas`, which is also the input of
        # Step 1, so re-running segmentation afterwards would presumably segment the preview.
        slider.release(slider_release, 
                        inputs = [slider, image_loaded,   mask_np_list_updated, mask_label_list], 
                        outputs= [canvas, label,tgt_prompt])

        # Keep the editing-side slider (tgt_index) in sync with the mask-id slider.
        slider.change(
            lambda x: x,
            inputs=[slider],
            outputs=[slider2]
        )


        # "Step 1": segmentation fills the states, the canvas, both sliders and the status box.
        segment_button.click(run_segmentation_wrapper, 
                [canvas] ,
                [image_loaded, segmentation,  mask_np_list, mask_label_list, canvas, slider, slider2, result_info0] )



# Enable Gradio's request queue on the app, then start the server.
queued_demo = demo.queue()
queued_demo.launch(debug=True)