import os
import numpy as np
from matplotlib import cm
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import torch
from utils import myroll2d

def create_outer_edge_mask_torch(mask, edge_thickness = 20):
    """Return the band of pixels that surrounds ``mask`` on the outside.

    The mask is shifted with ``myroll2d`` by ``edge_thickness`` in eight
    directions (four axis-aligned, four diagonal).  A pixel belongs to the
    outer edge when it is set in at least one shifted copy but not in the
    original mask.

    Args:
        mask: Torch mask tensor accepted by ``myroll2d``.
        edge_thickness: Shift amount handed to ``myroll2d`` for each direction.

    Returns:
        ``torch.uint8`` tensor (as produced by ``mask_union_torch``) that is
        1 on outer-edge pixels and 0 elsewhere.
    """
    base = mask.to(torch.float)
    t = edge_thickness
    # (second, third) arguments for myroll2d.  Every direction has its own
    # entry so that each of the eight shifted copies contributes to the union.
    shifts = (
        (t, 0),     # down
        (-t, 0),    # up
        (0, -t),    # left
        (0, t),     # right
        (-t, t),    # up-right
        (-t, -t),   # up-left
        (t, t),     # down-right
        (t, -t),    # down-left
    )
    edges = [
        (myroll2d(mask, first, second).to(torch.float) - base) > 0
        for first, second in shifts
    ]
    return mask_union_torch(*edges)

def mask_substract_torch(mask1, mask2):
    """Return the pixels where ``mask1`` exceeds ``mask2`` as a CPU uint8 tensor.

    Both inputs are moved to the CPU and compared as floats; the result is 1
    where ``mask1 - mask2`` is strictly positive and 0 elsewhere.
    """
    minuend = mask1.cpu().to(torch.float)
    subtrahend = mask2.cpu().to(torch.float)
    remaining = (minuend - subtrahend) > 0
    return remaining.to(torch.uint8)

def check_mask_overlap_torch(*masks):
    """Assert that no pixel is covered by more than one of ``masks``.

    The masks are summed as floats; the check passes only when every
    element of the sum is at most 1.  Calling it without masks is a no-op.

    Raises:
        AssertionError: If any element of the summed masks exceeds 1.
    """
    if not masks:
        return
    coverage = sum(m.float() for m in masks)
    # Every pixel must satisfy the bound, so this has to be `all`, not `any`.
    assert torch.all(coverage <= 1), "masks overlap: a pixel is covered more than once"
    
def check_mask_overlap_numpy(*masks):
    """Assert that the float sum of the numpy ``masks`` never exceeds 1.

    Raises:
        AssertionError: If any element of the summed masks is greater than 1.
    """
    per_pixel_count = sum(layer.astype(float) for layer in masks)
    no_overlap = np.all(per_pixel_count <= 1)
    assert no_overlap
    
def check_cover_all_torch(*masks):
    """Assert that the torch ``masks`` sum to exactly 1 at every element.

    Each mask is moved to the CPU and converted to float before summing.

    Raises:
        AssertionError: If any element of the sum differs from 1.
    """
    per_pixel_count = sum(layer.cpu().float() for layer in masks)
    exactly_once = torch.all(per_pixel_count == 1)
    assert exactly_once
    
def process_mask_to_follow_priority(mask_list, priority_list):
    """Remove from each mask the area claimed by any higher-priority mask.

    For every mask, all masks whose priority is strictly greater are
    subtracted from it one after another, so the result excludes the area
    of *every* higher-priority mask.  Masks that have no higher-priority
    peer are left untouched (including their dtype).

    Args:
        mask_list: List of numpy masks; it is modified in place.
        priority_list: One priority per mask; a larger value wins.

    Returns:
        The same ``mask_list`` object; every entry that was reduced is a
        ``np.uint8`` array.
    """
    for idx1, (_, p1) in enumerate(zip(mask_list, priority_list)):
        for idx2, (_, p2) in enumerate(zip(mask_list, priority_list)):
            if p2 > p1:
                # Subtract from the current (already reduced) mask so that
                # earlier subtractions are not thrown away.
                current = mask_list[idx1].astype(float)
                blocker = mask_list[idx2].astype(float)
                mask_list[idx1] = ((current - blocker) > 0).astype(np.uint8)
    return mask_list

def mask_union(*masks):
    """Return the union of numpy ``masks`` as a ``np.uint8`` array.

    A pixel is 1 in the result when the float sum of all masks is positive
    at that pixel, otherwise 0.
    """
    coverage = sum(layer.astype(float) for layer in masks)
    covered = coverage > 0
    return covered.astype(np.uint8)

def mask_intersection(mask1, mask2):
    """Return the pixels where the two numpy masks agree and at least one is set.

    The result keeps the pixels of ``mask_union(mask1, mask2)`` at which
    ``mask1`` and ``mask2`` have equal values; for 0/1 masks that is exactly
    the set of pixels present in both.
    """
    either = mask_union(mask1, mask2)
    same_value = (mask1.astype(float) - mask2.astype(float)) == 0
    return same_value * either

def mask_union_torch(*masks):
    """Return the union of torch ``masks`` as a ``torch.uint8`` tensor.

    A pixel is 1 in the result when the float sum of all masks is positive
    at that pixel, otherwise 0.
    """
    coverage = sum(layer.float() for layer in masks)
    covered = coverage > 0
    return covered.to(torch.uint8)

def mask_intersection_torch(mask1, mask2):
    """Return the pixels where the two torch masks agree and at least one is set.

    The result keeps the pixels of ``mask_union_torch(mask1, mask2)`` at
    which both masks have equal values, and is returned as a CPU
    ``torch.uint8`` tensor.
    """
    either = mask_union_torch(mask1, mask2)
    same_value = (mask1.float() - mask2.float()) == 0
    both = same_value * either
    return both.cpu().to(torch.uint8)


def visualize_mask_list(mask_list, savepath):
    """Draw all masks as one label image with a legend and save it.

    Every mask contributes its list index to the pixels it covers
    (overlapping masks add their indices up).  The image is drawn with a
    viridis palette of ``len(mask_list)`` colours and the same palette is
    used for the legend, so legend entry ``i`` shows the colour of mask ``i``.

    Args:
        mask_list: Sequence of numpy arrays and/or torch tensors.
        savepath: Destination passed to ``Figure.savefig``.
    """
    label_image = 0
    for midx, m in enumerate(mask_list):
        if isinstance(m, torch.Tensor):
            # Tensors have no .astype; go through a CPU float copy instead.
            layer = m.detach().cpu().float().numpy()
        else:
            layer = np.asarray(m).astype(float)
        label_image = label_image + layer * midx

    n_masks = len(mask_list)
    # pyplot.get_cmap still accepts a LUT size; cm.get_cmap was removed in
    # Matplotlib 3.9.
    palette = plt.get_cmap('viridis', n_masks)
    fig, ax = plt.subplots()
    try:
        # Pin colormap and value range so image colours match the legend even
        # when some indices do not occur in the image.
        ax.imshow(label_image, cmap=palette, vmin=0, vmax=max(n_masks - 1, 0))
        handles = [
            mpatches.Patch(color=palette(idx), label=f"{idx}")
            for idx in range(n_masks)
        ]
        ax.legend(handles=handles)
        fig.savefig(savepath)
    finally:
        # Release the figure; pyplot otherwise keeps it alive across calls.
        plt.close(fig)

def visualize_mask_list_clean(mask_list, savepath):
    """Draw all masks as one label image without a legend and save it at 500 dpi.

    Every mask contributes its list index to the pixels it covers
    (overlapping masks add their indices up).  The image is drawn with
    Matplotlib's default colormap and normalisation.

    Args:
        mask_list: Sequence of numpy arrays and/or torch tensors.
        savepath: Destination passed to ``Figure.savefig``.
    """
    label_image = 0
    for midx, m in enumerate(mask_list):
        if isinstance(m, torch.Tensor):
            # Tensors have no .astype; go through a CPU float copy instead.
            layer = m.detach().cpu().float().numpy()
        else:
            layer = np.asarray(m).astype(float)
        label_image = label_image + layer * midx

    fig, ax = plt.subplots()
    try:
        ax.imshow(label_image)
        fig.savefig(savepath, dpi=500)
    finally:
        # Release the figure; pyplot otherwise keeps it alive across calls.
        plt.close(fig)

   
def move_mask(mask_select, delta_x, delta_y):
    """Shift ``mask_select`` by delegating to ``myroll2d``.

    Note the argument order: ``delta_y`` is forwarded as the second argument
    of ``myroll2d`` and ``delta_x`` as the third.  The exact shift and border
    behaviour is defined by ``myroll2d`` (imported from ``utils``).
    """
    return myroll2d(mask_select, delta_y, delta_x)

def stack_mask_with_priority(mask_list_np, priority_list, edit_idx_list):
    """Resolve overlaps between edited masks and the rest according to priority.

    Pass 1 handles masks that are not in ``edit_idx_list``: each is cast to
    uint8, and if the priority of the first edited mask is greater than or
    equal to its own, the area covered by the union of edited masks is
    removed from it first.

    Pass 2 handles pairs of edited masks: a mask loses its overlap with
    another edited mask whenever its priority is less than or equal to that
    of the other mask.

    Args:
        mask_list_np: List of numpy masks; modified in place.
        priority_list: One priority per mask index.
        edit_idx_list: Indices of the edited masks.

    Returns:
        The same ``mask_list_np`` object.
    """
    edited_region = mask_union(*[mask_list_np[eid] for eid in edit_idx_list])
    edited_region_bool = edited_region.astype(bool)

    # Pass 1: masks that were not edited.
    for idx, layer in enumerate(mask_list_np):
        if idx in edit_idx_list:
            continue
        if priority_list[edit_idx_list[0]] >= priority_list[idx]:
            covered = np.logical_and(layer.astype(bool), edited_region_bool)
            layer = layer.astype(float) - covered.astype(float)
        mask_list_np[idx] = layer.astype("uint8")

    # Pass 2: edited masks against each other.
    for idx in edit_idx_list:
        for other in edit_idx_list:
            if idx == other:
                continue
            if not (priority_list[idx] <= priority_list[other]):
                continue
            shared = np.logical_and(
                mask_list_np[idx].astype(bool), mask_list_np[other].astype(bool)
            )
            trimmed = mask_list_np[idx].astype(float) - shared.astype(float)
            mask_list_np[idx] = trimmed.astype("uint8")
    return mask_list_np

def _feasible_inner_edge(mask_feasible, edge_width):
    """Return a boolean map of feasible pixels that border a non-feasible area.

    ``mask_feasible`` is shifted with ``myroll2d`` by ``edge_width`` in eight
    directions; a pixel is on the edge when it is feasible itself but at
    least one shifted copy is not set there.
    """
    base = mask_feasible.astype(float)
    w = edge_width
    # (second, third) arguments for myroll2d.  Every direction has its own
    # entry so that each of the eight shifted copies contributes to the edge.
    shifts = (
        (w, 0),     # down
        (-w, 0),    # up
        (0, -w),    # left
        (0, w),     # right
        (-w, w),    # up-right
        (-w, -w),   # up-left
        (w, w),     # down-right
        (w, -w),    # down-left
    )
    edge = np.zeros(mask_feasible.shape, dtype=bool)
    for first, second in shifts:
        shifted = myroll2d(mask_feasible, first, second)
        edge |= (shifted.astype(float) - base) < 0
    return edge & (mask_feasible > 0)


def _assign_remain_to_nearest(mask_list, mask_remain, edit_idx_list):
    """Give each uncovered pixel to the mask owning the nearest feasible edge pixel.

    ``mask_list`` is updated in place; masks that receive pixels become
    boolean arrays.  Does nothing when ``mask_remain`` has no set pixel.
    """
    vec_ind_remain = np.nonzero(mask_remain.reshape(-1))[0]
    if vec_ind_remain.size == 0:
        # Nothing to fill; also avoids argmin on an empty distance matrix.
        return

    # Edited masks must not donate their label to uncovered pixels.
    mask_edit = np.zeros(mask_remain.shape, dtype=bool)
    if edit_idx_list is not None:
        for eidx in edit_idx_list:
            mask_edit |= mask_list[eidx].astype(float) > 0
    mask_feasible = (
        1 - mask_remain.astype(float) - mask_edit.astype(float)
    ).astype(np.uint8)

    feasible_edge = _feasible_inner_edge(mask_feasible, edge_width=2)
    vec_ind_feasible_edge = np.nonzero(feasible_edge.reshape(-1))[0]
    if vec_ind_feasible_edge.size == 0:
        raise ValueError(
            "cannot fill remaining pixels: no feasible mask edge to take labels from"
        )

    # Row-major flat index -> (row, column).
    n_cols = mask_remain.shape[1]
    y_remain, x_remain = np.divmod(vec_ind_remain, n_cols)
    y_edge, x_edge = np.divmod(vec_ind_feasible_edge, n_cols)

    # Squared Euclidean distance between every uncovered pixel and every
    # candidate edge pixel (shape: n_remain x n_edge).
    dis = (
        (x_remain[:, np.newaxis] - x_edge[np.newaxis, :]) ** 2
        + (y_remain[:, np.newaxis] - y_edge[np.newaxis, :]) ** 2
    )
    nearest_point = vec_ind_feasible_edge[np.argmin(dis, axis=1)]

    # Flat label map: each pixel holds the index of the mask covering it.
    # int64 keeps large mask indices from overflowing a uint8 mask dtype.
    vec_region_partition = np.zeros(mask_remain.size, dtype=np.int64)
    for mask_idx, mask in enumerate(mask_list):
        vec_region_partition += mask.reshape(-1).astype(np.int64) * mask_idx

    nearest_region = vec_region_partition[nearest_point]
    nearest_region_set = set(nearest_region.tolist())
    if edit_idx_list is not None:
        for edit_idx in edit_idx_list:
            assert edit_idx not in nearest_region_set

    for midx in range(len(mask_list)):
        if midx not in nearest_region_set:
            continue
        fill = np.zeros(mask_remain.size, dtype=bool)
        fill[vec_ind_remain[nearest_region == midx]] = True
        grown = mask_list[midx].astype(float) + fill.reshape(mask_list[midx].shape).astype(float)
        mask_list[midx] = grown > 0


def process_remain_mask(mask_list, edit_idx_list = None, force_mask_remain = None):
    """Assign the pixels that no mask covers to one of the masks.

    The uncovered area is ``1 - sum(mask_list)``.  When ``force_mask_remain``
    is given, the whole uncovered area is added to that mask.  Otherwise
    every uncovered pixel is given to the mask that owns the nearest edge
    pixel of the feasible region (covered pixels that do not belong to an
    edited mask).

    Args:
        mask_list: List of 2-D numpy masks; modified in place.
        edit_idx_list: Optional indices of edited masks; these never receive
            uncovered pixels.
        force_mask_remain: Optional index of the mask that should absorb the
            whole uncovered area.

    Returns:
        ``(mask_list, mask_remain)`` where ``mask_remain`` is the ``np.uint8``
        uncovered area computed before filling.  Masks enlarged by the
        nearest-neighbour path are boolean arrays; the forced mask is uint8.

    Raises:
        ValueError: If pixels are uncovered but no feasible edge pixel exists.
        AssertionError: If an uncovered pixel would be assigned to an edited
            mask.
    """
    print("Start to process remaining mask using nearest neighbor")
    mask_list[0].shape  # fail early (IndexError) on an empty mask list

    mask_remain = (1 - sum([m.astype(float) for m in mask_list])).astype(np.uint8)
    if force_mask_remain is not None:
        mask_list[force_mask_remain] = (
            mask_list[force_mask_remain].astype(float) + mask_remain.astype(float)
        ).astype(np.uint8)
    else:
        _assign_remain_to_nearest(mask_list, mask_remain, edit_idx_list)

    print("Finish processing remaining mask, if you want to edit, launch the ui")
    return mask_list, mask_remain
       
def resize_mask(mask_np, resize_ratio = 1):
    """Rescale a numpy mask while keeping the original array size.

    The mask is resampled with ``torch.nn.functional.interpolate`` (default
    mode) to ``int(dim * resize_ratio)`` along both axes.  If the first axis
    got smaller, the result is pasted into the top-left corner of a zero
    canvas of the original size; otherwise a centred window of the original
    size is cut out of the enlarged mask.

    Args:
        mask_np: 2-D numpy mask.
        resize_ratio: Scale factor applied to both axes.

    Returns:
        ``np.uint8`` array.
    """
    rows, cols = mask_np.shape[0], mask_np.shape[1]
    new_rows = int(rows * resize_ratio)
    new_cols = int(cols * resize_ratio)

    # interpolate expects a (batch, channel, H, W) tensor.
    batched = torch.from_numpy(mask_np)[None, None]
    scaled = torch.nn.functional.interpolate(batched, (new_rows, new_cols)).squeeze()

    if rows > new_rows:
        canvas = torch.zeros(rows, cols)
        canvas[:new_rows, :new_cols] = scaled
        return canvas.cpu().numpy().astype(np.uint8)

    assert cols <= new_cols
    top = new_rows // 2 - rows // 2
    left = new_cols // 2 - cols // 2
    window = scaled[top: top + rows, left: left + cols]
    return window.cpu().numpy().astype(np.uint8)

def process_mask_move_torch(
        mask_list,
        move_index_list,
        delta_x_list = None,
        delta_y_list = None, 
        edit_priority_list = None,
        force_mask_remain = None,
        resize_list = None
    ):
    """Move (and optionally resize) selected masks, then repair the layout.

    Steps: convert the torch masks to numpy, resize/move every selected
    mask, resolve overlaps with ``stack_mask_with_priority``, verify with
    ``check_mask_overlap_numpy`` that no pixel is claimed twice, and finally
    fill the uncovered area with ``process_remain_mask``.

    Args:
        mask_list: List of torch mask tensors.
        move_index_list: Indices of the masks to move.
        delta_x_list: Per-move ``delta_x`` for ``move_mask``; defaults to 0
            for every move.
        delta_y_list: Per-move ``delta_y`` for ``move_mask``; defaults to 0
            for every move.
        edit_priority_list: Per-move priority; defaults to 0 for every move
            (the priority all untouched masks have).
        force_mask_remain: Forwarded to ``process_remain_mask``.
        resize_list: Optional per-move resize ratio for ``resize_mask``.

    Returns:
        ``(mask_list, mask_remain)`` as ``torch.uint8`` tensors.
    """
    # The declared None defaults would otherwise make zip() raise TypeError.
    n_moves = len(move_index_list)
    if delta_x_list is None:
        delta_x_list = [0] * n_moves
    if delta_y_list is None:
        delta_y_list = [0] * n_moves
    if edit_priority_list is None:
        edit_priority_list = [0] * n_moves

    mask_list_np = [m.cpu().numpy() for m in mask_list]
    priority_list = [0 for _ in range(len(mask_list_np))]
    for idx, (move_index, delta_x, delta_y, priority) in enumerate(
        zip(move_index_list, delta_x_list, delta_y_list, edit_priority_list)
    ):
        priority_list[move_index] = priority
        if resize_list is not None:
            mask = resize_mask(mask_list_np[move_index], resize_list[idx])
        else:
            mask = mask_list_np[move_index]
        mask_list_np[move_index] = move_mask(mask, delta_x = delta_x, delta_y = delta_y)

    # After stacking, the vacated area is not covered by any mask yet.
    mask_list_np = stack_mask_with_priority(mask_list_np, priority_list, move_index_list)
    check_mask_overlap_numpy(*mask_list_np)
    mask_list_np, mask_remain = process_remain_mask(mask_list_np, move_index_list, force_mask_remain)

    mask_list = [torch.from_numpy(m).to(dtype=torch.uint8) for m in mask_list_np]
    mask_remain = torch.from_numpy(mask_remain).to(dtype=torch.uint8)
    return mask_list, mask_remain

def process_mask_remove_torch(mask_list, remove_idx):
    """Blank out one mask and let ``process_remain_mask`` fill the freed area.

    Args:
        mask_list: List of torch mask tensors.
        remove_idx: Index of the mask to clear.

    Returns:
        ``(mask_list, mask_remain)`` as ``torch.uint8`` tensors.
    """
    arrays = [tensor.cpu().numpy() for tensor in mask_list]
    arrays[remove_idx] = np.zeros_like(arrays[0])
    arrays, leftover = process_remain_mask(arrays)

    tensors = []
    for arr in arrays:
        tensors.append(torch.from_numpy(arr).to(dtype=torch.uint8))
    leftover_tensor = torch.from_numpy(leftover).to(dtype=torch.uint8)
    return tensors, leftover_tensor

def get_mask_difference_torch(mask_list1, mask_list2):
    """Return a mask of every pixel that differs between two mask lists.

    Masks are compared pairwise; the result is the union (via
    ``mask_union_torch``) of all pixels where a pair differs.

    Raises:
        AssertionError: If the two lists have different lengths.
    """
    assert len(mask_list1) == len(mask_list2)
    changed = torch.zeros_like(mask_list1[0])
    for before, after in zip(mask_list1, mask_list2):
        delta = before.float() - after.float()
        pair_changed = (delta != 0).to(torch.uint8)
        changed = mask_union_torch(changed, pair_changed)
    return changed

def save_mask_list_to_npys(folder, mask_list, mask_label_list, name = "mask"):
    """Save every mask as ``<name><index>_<label>.npy`` inside ``folder``.

    Masks and labels are paired positionally; the index in the file name is
    the position of the pair.
    """
    for position, pair in enumerate(zip(mask_list, mask_label_list)):
        mask, mask_label = pair
        filename = "{}{}_{}.npy".format(name, position, mask_label)
        target = os.path.join(folder, filename)
        np.save(target, mask)