import gc

import torch
import torch.nn.functional as F

from einops import repeat, rearrange
from vidtome import merge
from utils.flow_utils import flow_warp, coords_grid

# AdaIN (adaptive instance normalization) helpers


def calc_mean_std(feat, eps=1e-5):
    """Return the per-sample, per-channel mean and std of a 4-D feature map.

    Args:
        feat: tensor of shape ``(N, C, H, W)``; it does not need to be
            contiguous.
        eps: small value added to the variance before the square root, to
            avoid divide-by-zero when the std is later used as a divisor.

    Returns:
        ``(mean, std)``, each of shape ``(N, C, 1, 1)``. The variance is the
        unbiased estimate over the ``H * W`` spatial positions.

    Raises:
        ValueError: if ``feat`` is not 4-dimensional.
    """
    if feat.dim() != 4:
        raise ValueError(
            f"calc_mean_std expects a 4-D tensor, got shape {tuple(feat.shape)}")
    N, C = feat.shape[:2]
    # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work.
    flat = feat.reshape(N, C, -1)
    feat_std = (flat.var(dim=2) + eps).sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std


class AttentionControl:
    """Per-step controller for cross-frame attention and latent post-processing.

    The instance is called on attention contexts (``__call__`` -> ``forward``)
    to store/restore self-attention contexts of the UNet up-blocks. It also
    exposes ``update_x0`` (flow-based latent warping), ``merge_x0`` and
    ``merge_x0_scores`` (latent token merging) for latents ``x0``. Active
    behaviours are chosen with ``set_task``. Each one is further gated by the
    current step ``cur_step`` relative to ``total_step`` via the ``*_period``
    fractions.
    """

    def __init__(self,
                 warp_period=(0.0, 0.0),
                 merge_period=(0.0, 0.0),
                 merge_ratio=(0.3, 0.3),
                 ToMe_period=(0.0, 1.0),
                 mask_period=(0.0, 0.0),
                 cross_period=(0.0, 0.0),
                 ada_period=(0.0, 0.0),
                 inner_strength=1.0,
                 loose_cfatnn=False,
                 flow_merge=True,
                 ):
        """Create a controller.

        Args:
            warp_period: ``(start, end)`` fractions of ``total_step`` in which
                ``update_x0`` warps latents (end exclusive, see ``_in_period``).
            merge_period: ``(start, end)`` fractions of ``total_step`` in which
                ``merge_x0`` / ``merge_x0_scores`` are active.
            cross_period: ``(start, end)`` fractions of ``total_step`` in which
                ``forward`` restores stored contexts (both ends inclusive).
            loose_cfatnn: if True, contexts with ``shape[2] == 1280`` are not
                stored/restored (threshold 1280 instead of 1281).
            merge_ratio, ToMe_period, mask_period, ada_period, inner_strength,
            flow_merge: stored as attributes; not read by this class itself.
        """
        self.cur_frame_idx = 0

        self.step_store = self.get_empty_store()
        self.cur_step = 0
        self.total_step = 0
        self.cur_index = 0  # next slot of step_store['previous'] used by forward
        self.init_store = False
        self.restore = False
        self.update = False
        self.flow = None
        self.mask = None
        self.cldm = None
        self.decoded_imgs = []
        self.restorex0 = True
        self.updatex0 = False
        # Overwritten by set_task; initialised here so the attribute always exists.
        self.restore_step = 1.0
        self.inner_strength = inner_strength
        self.cross_period = cross_period
        self.mask_period = mask_period
        self.ada_period = ada_period
        self.warp_period = warp_period
        self.ToMe_period = ToMe_period
        self.merge_period = merge_period
        self.merge_ratio = merge_ratio
        self.keyframe_idx = 0
        self.flow_merge = flow_merge
        self.distances = {}  # H -> cached distance weights, see set_distance
        self.flow_correspondence = {}  # H -> (src_idx, tar_idx), see set_flow_correspondence
        self.non_pad_ratio = (1.0, 1.0)  # (h, w) fractions of the latent that are not padding
        self.up_resolution = 1280 if loose_cfatnn else 1281

    @staticmethod
    def get_empty_store():
        """Return a fresh store dict (new list objects on every call)."""
        return {
            'first': [],
            'previous': [],
            'x0_previous': [],
            'first_ada': [],
            'pre_x0': [],
            "pre_keyframe_lq": None,
            "flows": None,
            "occ_masks": None,
            "flow_confids": None,
            "merge": None,
            "unmerge": None,
            "corres_scores": None,
            "flows2": None,
            "flow_confids2": None,
        }

    def _in_period(self, period):
        """True if ``cur_step`` is in ``[total_step * period[0], int(total_step * period[1]))``."""
        return (self.total_step * period[0] <= self.cur_step
                < int(self.total_step * period[1]))

    def _non_pad_index(self, H, W, device):
        """Return the flat indices of non-padded latent positions.

        Padding is taken to be the trailing ``H - int(H * non_pad_ratio[0])``
        rows and ``W - int(W * non_pad_ratio[1])`` columns.

        Returns:
            ``(non_pad_idx, non_pad_h, non_pad_w)``. ``non_pad_idx`` has shape
            ``(1, num_non_pad, 1)`` and indexes the row-major flattened
            ``H * W`` grid.
        """
        non_pad_ratio_h, non_pad_ratio_w = self.non_pad_ratio
        padding_size_w = W - int(W * non_pad_ratio_w)
        padding_size_h = H - int(H * non_pad_ratio_h)
        padding_mask = torch.zeros((H, W), device=device, dtype=torch.bool)
        # Guards are required: a slice such as ``[-0:]`` would select everything.
        if padding_size_w:
            padding_mask[:, -padding_size_w:] = 1
        if padding_size_h:
            padding_mask[-padding_size_h:, :] = 1
        padding_mask = padding_mask.reshape(-1)

        idx_buffer = torch.arange(H * W, device=device, dtype=torch.int64)
        non_pad_idx = idx_buffer[None, ~padding_mask, None]
        return non_pad_idx, H - padding_size_h, W - padding_size_w

    def forward(self, context, is_cross: bool, place_in_unet: str):
        """Store and/or restore self-attention contexts of the UNet up-blocks.

        Only handles contexts that satisfy all of: ``is_cross`` is False,
        ``place_in_unet == 'up'``, and ``context.shape[2] < self.up_resolution``.
        Every other context is returned unchanged. Each handled call uses slot
        ``self.cur_index`` of ``step_store['previous']`` and then advances it:

        * ``init_store``: append the (detached) context to ``'first'`` and
          ``'previous'``.
        * ``restore`` (while ``cur_step`` is within ``cross_period``): replace
          the context with a clone of the stored ``'previous'`` entry.
        * ``update``: overwrite the stored entry with the incoming context,
          captured before any restore.
        """
        cross_period = (self.total_step * self.cross_period[0],
                        self.total_step * self.cross_period[1])
        if not is_cross and place_in_unet == 'up' and \
                context.shape[2] < self.up_resolution:
            if self.init_store:
                self.step_store['first'].append(context.detach())
                self.step_store['previous'].append(context.detach())
            if self.update:
                # Snapshot the incoming context before it may be replaced below.
                tmp = context.clone().detach()
            if self.restore and cross_period[0] <= self.cur_step <= cross_period[1]:
                context = self.step_store['previous'][self.cur_index].clone()
            if self.update:
                self.step_store['previous'][self.cur_index] = tmp
            self.cur_index += 1
        return context

    def update_x0(self, x0, cur_frame=0):
        """Blend flow-warped latents into ``x0`` (in place) during ``warp_period``.

        Active only when ``restorex0`` is set and ``_in_period(warp_period)``.
        The frame ``mid = x0.shape[0] // 2`` is used as the anchor:

        1. If ``pre_x0`` already holds ``int(total_step * warp_period[1])``
           entries, the entry stored for ``cur_step`` is warped with
           ``flows[mid]`` and blended into ``x0[mid]`` using ``occ_masks[mid]``.
        2. Every other frame ``i`` is blended with ``x0[mid]`` warped by
           ``flows[i]``, using ``occ_masks[i]``.
        3. ``x0[mid]`` is recorded in ``pre_x0``. It is appended while the
           buffer is still filling and overwritten at ``cur_step`` afterwards.

        ``cur_frame`` is accepted for API compatibility and is unused.

        Returns:
            ``x0``, the same tensor object; it is modified in place when active.
        """
        if not self.restorex0 or not self._in_period(self.warp_period):
            return x0

        warp_end = int(self.total_step * self.warp_period[1])
        store = self.step_store
        flows, occ_masks = store["flows"], store["occ_masks"]
        mid = x0.shape[0] // 2

        if len(store["pre_x0"]) == warp_end:
            print(f"[INFO] keyframe latent warping @ step {self.cur_step}...")
            warped = flow_warp(store["pre_x0"][self.cur_step][None], flows[mid], mode='nearest')[0]
            x0[mid] = (1 - occ_masks[mid]) * x0[mid] + warped * occ_masks[mid]

        print(f"[INFO] local latent warping @ step {self.cur_step}...")
        for i in range(x0.shape[0]):
            if i == mid:
                continue
            warped = flow_warp(x0[mid][None], flows[i], mode='nearest')[0]
            x0[i] = (1 - occ_masks[i]) * x0[i] + warped * occ_masks[i]

        if len(store["pre_x0"]) < warp_end:
            store['pre_x0'].append(x0[mid])
        else:
            store['pre_x0'][self.cur_step] = x0[mid]
        return x0

    def merge_x0(self, x0, merge_ratio):
        """Apply flow-guided token merge/unmerge to ``x0`` during ``merge_period``.

        Steps:

        1. Gather the non-padded tokens of all ``B`` frames into one sequence.
        2. Pass that sequence, the stored flows (cropped to the non-padded
           area) and the stored flow confidences to
           ``merge.bipartite_soft_matching_randframe``.
        3. Write ``u(m(tokens))`` back into the non-padded positions.

        Args:
            x0: latent tensor of shape ``(B, C, H, W)``.
            merge_ratio: ratio forwarded to the matching function.

        Returns:
            Tensor of shape ``(B, C, H, W)`` (``x0`` itself when inactive).
        """
        if not self._in_period(self.merge_period):
            return x0
        print(f"[INFO] latent merging @ step {self.cur_step}...")

        B, C, H, W = x0.shape
        non_pad_idx, non_pad_h, non_pad_w = self._non_pad_index(H, W, x0.device)
        token_idx = non_pad_idx.expand(B, -1, C)
        x0 = rearrange(x0, 'b c h w -> b (h w) c', h=H)
        x_non_pad = torch.gather(x0, dim=1, index=token_idx)

        import copy
        # Deep-copy so that cropping (and anything the matcher does with the
        # flows) cannot alter self.step_store["flows"].
        flows = copy.deepcopy(self.step_store["flows"])
        for i in range(B):
            if flows[i] is not None:
                flows[i] = flows[i][:, :, :non_pad_h, :non_pad_w]

        x_non_pad = rearrange(x_non_pad, 'b a c -> 1 (b a) c')
        m, u, _ = merge.bipartite_soft_matching_randframe(
            x_non_pad, B, merge_ratio, 0, target_stride=B,
            H=H,
            flow=flows,
            flow_confid=self.step_store["flow_confids"],
        )
        x_non_pad = u(m(x_non_pad))
        x_non_pad = rearrange(x_non_pad, '1 (b a) c -> b a c', b=B)
        x0.scatter_(dim=1, index=token_idx, src=x_non_pad)
        return rearrange(x0, 'b (h w) c -> b c h w ', h=H)

    def merge_x0_scores(self, x0, merge_ratio, merge_mode="replace"):
        """Merge tokens of the other frames into frame ``B // 2`` during ``merge_period``.

        Non-padded tokens of frame ``mid = B // 2`` are destinations; tokens of
        all other frames are sources.

        Scores:

        * Start from ``step_store["corres_scores"]``, L2-normalised over the
          last dim. The expected shape is ``(1, num_src, num_dst)``; TODO
          confirm against the producer.
        * Add ``0.3 * shifted flow confidence`` at the flow correspondences
          ``self.flow_correspondence[H]`` (see ``set_flow_correspondence``).

        Merging:

        * The ``int(num_src * merge_ratio)`` sources with the highest best
          score are merged into their best-scoring destination. The rest keep
          their values.
        * ``merge_mode == "replace"``: merged sources are overwritten by their
          matched destination token.
        * Other modes: destinations are first combined with their matched
          sources via ``scatter_reduce(reduce=merge_mode, include_self=True)``.
          The merged sources then copy the updated destination.

        Returns:
            Tensor of shape ``(B, C, H, W)`` (``x0`` itself when inactive).
        """
        if not self._in_period(self.merge_period):
            return x0
        print(f"[INFO] latent merging @ step {self.cur_step}...")

        B, C, H, W = x0.shape
        non_pad_idx, _, _ = self._non_pad_index(H, W, x0.device)
        x0 = rearrange(x0, 'b c h w -> b (h w) c', h=H)
        x_non_pad = torch.gather(x0, dim=1, index=non_pad_idx.expand(B, -1, C))
        x_non_pad_A, x_non_pad_N = x_non_pad.shape[1], x_non_pad.shape[1] * B
        mid = B // 2
        x_non_pad = rearrange(x_non_pad, 'b a c -> 1 (b a) c')

        # Split the flattened sequence: frame ``mid`` -> destinations, others -> sources.
        idx_buffer = torch.arange(x_non_pad_N, device=x0.device, dtype=torch.int64)
        dst_select = torch.div(idx_buffer, x_non_pad_A, rounding_mode='floor') % B == mid
        a_idx = idx_buffer[None, ~dst_select, None]  # source positions
        b_idx = idx_buffer[None, dst_select, None]   # destination positions
        del idx_buffer
        num_dst = b_idx.shape[1]
        b = 1
        src = torch.gather(x_non_pad, dim=1, index=a_idx.expand(b, x_non_pad_N - num_dst, C))
        tar = torch.gather(x_non_pad, dim=1, index=b_idx.expand(b, num_dst, C))

        flow_src_idx = self.flow_correspondence[H][0]
        flow_tar_idx = self.flow_correspondence[H][1]
        flow_confid = self.step_store["flow_confids"][:mid] + self.step_store["flow_confids"][mid+1:]
        flow_confid = torch.cat(flow_confid, dim=0)
        flow_confid = rearrange(flow_confid, 'b h w -> 1 (b h w)')
        scores = F.normalize(self.step_store["corres_scores"], p=2, dim=-1)

        # Shift confidences so their maximum equals the maximum score, then
        # boost the scores of flow-corresponding (source, destination) pairs.
        flow_confid -= (torch.max(flow_confid) - torch.max(scores))
        scores[:, flow_src_idx[0, :, 0], flow_tar_idx[0, :, 0]] += (flow_confid[:, flow_src_idx[0, :, 0]] * 0.3)

        r = min(src.shape[1], int(src.shape[1] * merge_ratio))
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
        unm_idx = edge_idx[..., r:, :]  # unmerged source tokens
        src_idx = edge_idx[..., :r, :]  # merged source tokens
        tar_idx = torch.gather(node_idx[..., None], dim=-2, index=src_idx)
        unm = torch.gather(src, dim=-2, index=unm_idx.expand(-1, -1, C))
        if merge_mode != "replace":
            # Combine matched sources into their destination tokens (e.g. mean).
            src = torch.gather(src, dim=-2, index=src_idx.expand(-1, -1, C))
            tar = tar.scatter_reduce(-2, tar_idx.expand(-1, -1, C),
                                     src, reduce=merge_mode, include_self=True)
        # Merged source tokens take the value of their matched destination.
        src = torch.gather(tar, dim=-2, index=tar_idx.expand(-1, -1, C))

        # Write destinations, unmerged sources and merged sources back in place.
        x_non_pad.scatter_(dim=-2, index=b_idx.expand(b, -1, C), src=tar)
        x_non_pad.scatter_(dim=-2, index=torch.gather(a_idx.expand(b, -1, 1),
                           dim=1, index=unm_idx).expand(-1, -1, C), src=unm)
        x_non_pad.scatter_(dim=-2, index=torch.gather(a_idx.expand(b, -1, 1),
                           dim=1, index=src_idx).expand(-1, -1, C), src=src)

        x_non_pad = rearrange(x_non_pad, '1 (b a) c -> b a c', a=x_non_pad_A)
        x0.scatter_(dim=1, index=non_pad_idx.expand(B, -1, C), src=x_non_pad)
        return rearrange(x0, 'b (h w) c -> b c h w ', h=H)

    def set_distance(self, B, H, W, radius, device):
        """Cache ``exp(-floor(d / radius))`` for all pixel pairs of an ``H x W`` grid.

        ``d`` is the Euclidean distance between ``(row, col)`` coordinates. A
        ``radius`` of 0 is treated as 1. The ``(H*W, H*W)`` matrix is tiled
        ``B`` times along the first axis to shape ``(1, B*H*W, H*W)`` and
        stored in ``self.distances[H]``. Note that the cache key is ``H`` only.
        """
        y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
        coords = torch.stack((y, x), dim=-1).float().to(device)
        coords = rearrange(coords, 'h w c -> (h w) c')

        distances = torch.cdist(coords, coords)
        radius = 1 if radius == 0 else radius
        distances //= radius
        distances = torch.exp(-distances)
        self.distances[H] = repeat(distances, 'h a -> 1 (b h) a', b=B)

    def set_flow_correspondence(self, B, H, W, key_idx, flow_confid, flow):
        """Precompute flow-based source/target token indices for size ``H x W``.

        ``flow_confid`` and ``flow`` are sequences of per-frame tensors. When
        they do not have ``B - 1`` entries, the entry at ``key_idx`` is dropped.
        The remaining tensors are concatenated.

        Source tokens are flat indices ``frame * H * W + pixel`` over the
        non-key frames, ordered by descending flow confidence. Each target is
        the key-frame pixel reached by following the flow from the source
        pixel. The target is clamped to the grid and truncated to an integer.

        Stores ``self.flow_correspondence[H] = (src_idx, tar_idx)``, both of
        shape ``(1, num_src, 1)``.
        """
        if len(flow) != B - 1:
            flow_confid = flow_confid[:key_idx] + flow_confid[key_idx+1:]
            flow = flow[:key_idx] + flow[key_idx+1:]

        flow_confid = torch.cat(flow_confid, dim=0)
        flow = torch.cat(flow, dim=0)
        flow_confid = rearrange(flow_confid, 'b h w -> 1 (b h w)')

        src_idx = flow_confid.argsort(dim=-1, descending=True)[..., None]

        A = H * W
        src_flat = src_idx[0, :, 0]
        f = src_flat // A
        pix = src_flat % A
        x = pix % W
        y = pix // W

        grid = coords_grid(B - 1, H, W).to(flow.device) + flow  # [F-1, 2, H, W]
        # Separate names so the y lookup still uses the source x, not the target.
        tar_x = grid[f, 0, y, x].clamp(0, W - 1).long()
        tar_y = grid[f, 1, y, x].clamp(0, H - 1).long()
        tar_idx = rearrange(tar_y * W + tar_x, 'd -> 1 d 1')

        self.flow_correspondence[H] = (src_idx, tar_idx)

    def set_merge(self, merge, unmerge):
        """Store merge/unmerge callables in the step store."""
        self.step_store["merge"] = merge
        self.step_store["unmerge"] = unmerge

    def set_warp(self, flows, masks, flow_confids=None):
        """Store per-frame flows, masks and optionally flow confidences.

        ``update_x0`` reads ``flows``/``occ_masks``; ``merge_x0`` and
        ``merge_x0_scores`` read ``flows``/``flow_confids``.
        """
        self.step_store["flows"] = flows
        self.step_store["occ_masks"] = masks
        if flow_confids is not None:
            self.step_store["flow_confids"] = flow_confids

    def set_warp2(self, flows, flow_confids):
        """Store a second set of flows and confidences (not read by this class)."""
        self.step_store["flows2"] = flows
        self.step_store["flow_confids2"] = flow_confids

    def set_pre_keyframe_lq(self, pre_keyframe_lq):
        """Store ``pre_keyframe_lq`` in the step store (not read by this class)."""
        self.step_store["pre_keyframe_lq"] = pre_keyframe_lq

    def __call__(self, context, is_cross: bool, place_in_unet: str):
        """Alias for ``forward``."""
        return self.forward(context, is_cross, place_in_unet)

    def set_cur_frame_idx(self, frame_idx):
        """Set ``cur_frame_idx``."""
        self.cur_frame_idx = frame_idx

    def set_step(self, step):
        """Set the current step used by all period checks."""
        self.cur_step = step

    def set_total_step(self, total_step):
        """Set the total step count and reset ``cur_index`` to 0."""
        self.total_step = total_step
        self.cur_index = 0

    def clear_store(self):
        """Drop all stored tensors and reset ``step_store`` to an empty store."""
        del self.step_store
        # Collect garbage first so its CUDA blocks return to the caching
        # allocator before the cache is released.
        gc.collect()
        torch.cuda.empty_cache()
        self.step_store = self.get_empty_store()

    def set_task(self, task, restore_step=1.0):
        """Reset all behaviour flags, then enable those whose keyword occurs in ``task``.

        Keywords:

        * ``'initfirst'``: sets ``init_store`` and clears the store.
        * ``'updatestyle'``: sets ``update``.
        * ``'keepstyle'``: sets ``restore``.
        * ``'updatex0'``: sets ``updatex0``.
        * ``'keepx0'``: sets ``restorex0``.

        Also resets ``cur_index`` to 0. ``restore_step`` is recorded but not
        read by this class.
        """
        self.init_store = False
        self.restore = False
        self.update = False
        self.cur_index = 0
        self.restore_step = restore_step
        self.updatex0 = False
        self.restorex0 = False
        if 'initfirst' in task:
            self.init_store = True
            self.clear_store()
        if 'updatestyle' in task:
            self.update = True
        if 'keepstyle' in task:
            self.restore = True
        if 'updatex0' in task:
            self.updatex0 = True
        if 'keepx0' in task:
            self.restorex0 = True