from einops.einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from roma.utils.utils import get_gt_warp
import wandb
import roma
import math


class RobustLosses(nn.Module):
    def __init__(
        self,
        robust=False,
        center_coords=False,
        scale_normalize=False,
        ce_weight=0.01,
        local_loss=True,
        local_dist=4.0,
        local_largest_scale=8,
        smooth_mask=False,
        depth_interpolation_mode="bilinear",
        mask_depth_loss=False,
        relative_depth_error_threshold=0.05,
        alpha=1.0,
        c=1e-3,
    ):
        super().__init__()
        self.robust = robust  # measured in pixels
        self.center_coords = center_coords
        self.scale_normalize = scale_normalize
        self.ce_weight = ce_weight
        self.local_loss = local_loss
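        # local_dist is indexed per scale in forward (self.local_dist[scale]),
        # so a dict mapping scale -> distance threshold is expected there.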
        self.local_dist = local_dist
        self.local_largest_scale = local_largest_scale
        self.smooth_mask = smooth_mask
        self.depth_interpolation_mode = depth_interpolation_mode
        self.mask_depth_loss = mask_depth_loss
        self.relative_depth_error_threshold = relative_depth_error_threshold
        self.avg_overlap = dict()
        self.alpha = alpha
        self.c = c

    def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale):
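        """Global-matcher classification loss at a given scale.

        ``scale_gm_cls`` holds per-pixel logits over a ``cls_res x cls_res``
        grid of normalized coordinates in [-1, 1]; the target class is the
        grid cell closest to the ground-truth warp ``x2``. Cross entropy is
        only taken where the ground-truth validity ``prob`` exceeds 0.99, and
        the certainty head is supervised with BCE against ``prob``.
        """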
        with torch.no_grad():
            B, C, H, W = scale_gm_cls.shape
            device = x2.device
            cls_res = round(math.sqrt(C))
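            # Candidate match coordinates: cell centers of a cls_res x cls_res
            # grid over [-1, 1], stacked below in (x, y) order.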
            G = torch.meshgrid(
                *[
                    torch.linspace(
                        -1 + 1 / cls_res, 1 - 1 / cls_res, steps=cls_res, device=device
                    )
                    for _ in range(2)
                ],
                indexing="ij",
            )
            G = torch.stack((G[1], G[0]), dim=-1).reshape(C, 2)
            GT = (
                (G[None, :, None, None, :] - x2[:, None])
                .norm(dim=-1)
                .min(dim=1)
                .indices
            )
        certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:, 0], prob)
        cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction="none")[prob > 0.99]
        if not torch.any(cls_loss):
            cls_loss = certainty_loss * 0.0  # Prevent issues where prob is 0 everywhere

        losses = {
            f"gm_certainty_loss_{scale}": certainty_loss.mean(),
            f"gm_cls_loss_{scale}": cls_loss.mean(),
        }
        wandb.log(losses, step=roma.GLOBAL_STEP)
        return losses

    def delta_cls_loss(
        self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale
    ):
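        """Refinement (delta) classification loss at a given scale.

        The candidate grid is built as in ``gm_cls_loss`` but scaled by
        ``offset_scale`` and centered on the coarse estimate
        ``flow_pre_delta``, so each class corresponds to a local offset around
        the current flow. Cross entropy is restricted to ``prob`` > 0.99 and
        the certainty head gets a BCE loss against ``prob``.
        """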
        with torch.no_grad():
            B, C, H, W = delta_cls.shape
            device = x2.device
            cls_res = round(math.sqrt(C))
            G = torch.meshgrid(
                *[
                    torch.linspace(
                        -1 + 1 / cls_res, 1 - 1 / cls_res, steps=cls_res, device=device
                    )
                    for _ in range(2)
                ],
                indexing="ij",
            )
            G = torch.stack((G[1], G[0]), dim=-1).reshape(C, 2) * offset_scale
            GT = (
                (G[None, :, None, None, :] + flow_pre_delta[:, None] - x2[:, None])
                .norm(dim=-1)
                .min(dim=1)
                .indices
            )
        certainty_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
        cls_loss = F.cross_entropy(delta_cls, GT, reduction="none")[prob > 0.99]
        if not torch.any(cls_loss):
            cls_loss = certainty_loss * 0.0  # Prevent issues where prob is 0 everywhere
        losses = {
            f"delta_certainty_loss_{scale}": certainty_loss.mean(),
            f"delta_cls_loss_{scale}": cls_loss.mean(),
        }
        wandb.log(losses, step=roma.GLOBAL_STEP)
        return losses

    def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode="delta"):
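        """Robust regression loss on the end-point error (EPE).

        The EPE between the predicted ``flow`` and the ground-truth warp ``x2``
        is penalized with a generalized Charbonnier-style term controlled by
        ``self.alpha`` and ``self.c`` (the latter scaled by ``scale``),
        restricted to pixels with ``prob`` > 0.99. The certainty head is
        supervised with BCE against ``prob``, and PCK at a 0.5-pixel threshold
        is logged at the finest scale.
        """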
        epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1)
        if scale == 1:
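            # PCK at 0.5 px, assuming coordinates normalized to [-1, 1] over a
            # 512-pixel image side (one pixel = 2 / 512).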
            pck_05 = (epe[prob > 0.99] < 0.5 * (2 / 512)).float().mean()
            wandb.log({"train_pck_05": pck_05}, step=roma.GLOBAL_STEP)

        ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
        a = self.alpha
        cs = self.c * scale
        x = epe[prob > 0.99]
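        # Generalized Charbonnier penalty: cs**a * ((x / cs)**2 + 1)**(a / 2).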
        reg_loss = cs**a * ((x / cs) ** 2 + 1) ** (a / 2)
        if not torch.any(reg_loss):
            reg_loss = ce_loss * 0.0  # Prevent issues where prob is 0 everywhere
        losses = {
            f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
            f"{mode}_regression_loss_{scale}": reg_loss.mean(),
        }
        wandb.log(losses, step=roma.GLOBAL_STEP)
        return losses

    def forward(self, corresps, batch):
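        """Accumulate certainty, classification and regression losses over all scales.

        For each scale the ground-truth warp and validity mask are obtained
        from depths, intrinsics and relative pose via ``get_gt_warp``. At
        scales at or below ``local_largest_scale`` the mask is further
        restricted to points whose EPE at the previous (coarser) scale fell
        within the local window, so scales are expected to be processed from
        coarse to fine. Classification losses are used where class logits are
        available, regression losses otherwise, and the per-scale terms are
        summed with ``scale_weights``.
        """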
        scales = list(corresps.keys())
        tot_loss = 0.0
        # Per-scale weights to balance gradient magnitudes between the
        # regression and classification terms across scales.
        scale_weights = {1: 1, 2: 1, 4: 1, 8: 1, 16: 1}
        for scale in scales:
            scale_corresps = corresps[scale]
            (
                scale_certainty,
                flow_pre_delta,
                delta_cls,
                offset_scale,
                scale_gm_cls,
                scale_gm_certainty,
                flow,
                scale_gm_flow,
            ) = (
                scale_corresps["certainty"],
                scale_corresps["flow_pre_delta"],
                scale_corresps.get("delta_cls"),
                scale_corresps.get("offset_scale"),
                scale_corresps.get("gm_cls"),
                scale_corresps.get("gm_certainty"),
                scale_corresps["flow"],
                scale_corresps.get("gm_flow"),
            )
            flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
            b, h, w, d = flow_pre_delta.shape
            gt_warp, gt_prob = get_gt_warp(
                batch["im_A_depth"],
                batch["im_B_depth"],
                batch["T_1to2"],
                batch["K1"],
                batch["K2"],
                H=h,
                W=w,
            )
            x2 = gt_warp.float()
            prob = gt_prob

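            # Restrict supervision at finer scales to points that were matched
            # within the local window at the previous (coarser) scale; prev_epe
            # is set at the end of the loop body, so this branch relies on
            # coarse-to-fine ordering of the scales.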
            if self.local_largest_scale >= scale:
                prob = prob * (
                    F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[
                        :, 0
                    ]
                    < (2 / 512) * (self.local_dist[scale] * scale)
                )

            if scale_gm_cls is not None:
                gm_cls_losses = self.gm_cls_loss(
                    x2, prob, scale_gm_cls, scale_gm_certainty, scale
                )
                gm_loss = (
                    self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"]
                    + gm_cls_losses[f"gm_cls_loss_{scale}"]
                )
                tot_loss = tot_loss + scale_weights[scale] * gm_loss
            elif scale_gm_flow is not None:
                gm_flow_losses = self.regression_loss(
                    x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode="gm"
                )
                gm_loss = (
                    self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"]
                    + gm_flow_losses[f"gm_regression_loss_{scale}"]
                )
                tot_loss = tot_loss + scale_weights[scale] * gm_loss

            if delta_cls is not None:
                delta_cls_losses = self.delta_cls_loss(
                    x2,
                    prob,
                    flow_pre_delta,
                    delta_cls,
                    scale_certainty,
                    scale,
                    offset_scale,
                )
                delta_cls_loss = (
                    self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"]
                    + delta_cls_losses[f"delta_cls_loss_{scale}"]
                )
                tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
            else:
                delta_regression_losses = self.regression_loss(
                    x2, prob, flow, scale_certainty, scale
                )
                reg_loss = (
                    self.ce_weight
                    * delta_regression_losses[f"delta_certainty_loss_{scale}"]
                    + delta_regression_losses[f"delta_regression_loss_{scale}"]
                )
                tot_loss = tot_loss + scale_weights[scale] * reg_loss
            prev_epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1).detach()
        return tot_loss
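

# Minimal usage sketch (illustrative only): the exact keys of `corresps` and
# `batch` must match what the model decoder and dataloader produce, and the
# hyperparameter values below are assumptions, not a recommended configuration.
#
#   loss_fn = RobustLosses(
#       ce_weight=0.01,
#       local_dist={1: 4, 2: 4, 4: 8, 8: 8},
#       local_largest_scale=8,
#   )
#   tot_loss = loss_fn(corresps, batch)
#   tot_loss.backward()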