import itertools

import torch
import torch.nn as nn

import pose_estimation


class MSE(nn.Module):
    """Squared error summed over kept rows, divided by the number of kept rows.

    Args:
        ignore: indices along the first dimension of the inputs whose error is
            excluded from both the summed error and the row count.
    """

    def __init__(self, ignore=None):
        super().__init__()

        self.mse = torch.nn.MSELoss(reduction="none")
        self.ignore = ignore if ignore is not None else []

    def forward(self, y_pred, y_data):
        """Return the summed squared error over non-ignored rows per kept row.

        Args:
            y_pred: predictions; rows are indexed along the first dimension.
            y_data: targets with the same shape as ``y_pred``.

        Returns:
            Scalar tensor: element-wise squared error summed over all kept rows,
            divided by the number of kept rows.
        """
        loss = self.mse(y_pred, y_data)

        if len(self.ignore) == 0:
            return loss.sum() / len(loss)

        # A boolean keep-mask (rather than zeroing ignored rows in place and
        # subtracting len(ignore)) counts every row at most once even when
        # `ignore` repeats an index or names the same row with a negative and
        # a positive index. It also keeps NaN/inf values from ignored rows out
        # of the returned value (NaN * 0 is still NaN).
        keep = torch.ones(len(loss), dtype=torch.bool, device=loss.device)
        keep[torch.as_tensor(self.ignore, dtype=torch.long, device=loss.device)] = False
        return loss[keep].sum() / keep.sum()


class Parallel(nn.Module):
    """Bone-direction, ground-alignment, 3D-parallelism and 2D-contact losses.

    ``forward`` returns ``(Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d)``.
    Each term is the mean over the items that contributed to it, or ``0`` when
    none did.

    Args:
        skeleton: sequence of ``(src, dst)`` joint-index pairs. A bone's
            position in this sequence is also its index into ``self.cos``.
        ignore: bones (``(src, dst)`` pairs) skipped by the per-bone terms.
        ground_parallel: sequence of ``((src, dst), value)`` entries; the
            absolute first component of each unit predicted 2D bone is pulled
            towards ``value``.
    """

    # Bones excluded from the length-ratio / parallel terms when self.cos is set.
    _COS_EXCLUDED = (
        (1, 2),  # Neck + Right Shoulder
        (1, 5),  # Neck + Left Shoulder
        (9, 10),  # Hips + Right Upper Leg
        (9, 13),  # Hips + Left Upper Leg
    )

    def __init__(self, skeleton, ignore=None, ground_parallel=None):
        super().__init__()

        self.skeleton = skeleton
        self.ignore = set(ignore) if ignore is not None else set()

        self.ground_parallel = ground_parallel if ground_parallel is not None else []
        # ((i, j), (k, l)) pairs whose directions in `z` should coincide;
        # starts empty and is filled in by callers.
        self.parallel_in_3d = []

        # Optional per-bone targets for the predicted 2D/3D bone-length ratio.
        # None selects the tangent/normal terms instead.
        self.cos = None
        # `contact_2d` is optional and not created here; forward() only uses
        # it if a caller has set it.

    def forward(self, y_pred3d, y_data, z, spine_j, global_step=0):
        """Compute all loss terms.

        Args:
            y_pred3d: predicted joints, one row per joint; ``y_pred3d[:, :2]``
                is used as the 2D prediction.
            y_data: reference 2D joints, indexed by joint id.
            z: per-joint values used by the ``parallel_in_3d`` term.
            spine_j: ``(rleg, lleg)`` vectors used in place of the predicted
                bones (10, 11) and (13, 14) in the tangent/normal terms.
            global_step: unused.

        Returns:
            Tuple ``(Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d)``.
        """
        y_pred = y_pred3d[:, :2]

        Lcon2d = self._contact_2d_loss(y_pred, y_data)
        Ltan, Lcos, Lpar, Lspine = self._bone_losses(y_pred3d, y_pred, y_data, spine_j)
        Lgr = self._ground_parallel_loss(y_pred)
        Lstraight3d = self._parallel_3d_loss(z)

        return Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d

    def _contact_2d_loss(self, y_pred, y_data):
        """Mean squared mismatch between predicted and reference contact-point offsets."""
        total = count = 0
        for c2d in getattr(self, "contact_2d", ()):
            # Every pair of (src, dst, t) points in a group; each point lies at
            # fraction t along the segment src -> dst.
            for (src_1, dst_1, t_1), (src_2, dst_2, t_2) in itertools.combinations(c2d, 2):
                ref = torch.lerp(y_data[src_2], y_data[dst_2], t_2) - torch.lerp(
                    y_data[src_1], y_data[dst_1], t_1
                )
                pred = torch.lerp(y_pred[src_2], y_pred[dst_2], t_2) - torch.lerp(
                    y_pred[src_1], y_pred[dst_1], t_1
                )
                total = total + ((ref - pred) ** 2).sum()
                count += 1

        if count > 0:
            total = total / count
        return total

    def _bone_losses(self, y_pred3d, y_pred, y_data, spine_j):
        """Per-bone terms, averaged over non-ignored bones: (Ltan, Lcos, Lpar, Lspine)."""
        rleg, lleg = spine_j

        ltan = lpar = lcos = count = 0
        lspine = 0  # never accumulated here; kept for the return signature

        for i, bone in enumerate(self.skeleton):
            if bone in self.ignore:
                continue

            src, dst = bone

            ref = y_data[dst] - y_data[src]
            t = nn.functional.normalize(ref, dim=0)
            # In-plane normal: the unit reference direction rotated by 90 degrees.
            n = torch.stack([-t[1], t[0]])

            if self.cos is not None:
                if bone not in self._COS_EXCLUDED:
                    a = y_pred[dst] - y_pred[src]
                    l2d = torch.norm(a, dim=0)
                    l3d = torch.norm(y_pred3d[dst] - y_pred3d[src], dim=0)
                    # NOTE(review): a zero-length predicted bone makes these
                    # ratios inf/NaN.
                    lcos = lcos + (l2d / l3d - self.cos[i]) ** 2
                    lpar = lpar + ((a / l2d) * n).sum() ** 2
            else:
                if src == 10 and dst == 11:  # right leg
                    a = rleg
                elif src == 13 and dst == 14:  # left leg
                    a = lleg
                else:
                    a = y_pred[dst] - y_pred[src]

                c = a - ref
                ltan = ltan + ((c * t).sum()) ** 2
                lpar = lpar + (c * n).sum() ** 2

            count += 1

        if count > 0:
            ltan = ltan / count
            lcos = lcos / count
            lpar = lpar / count
            lspine = lspine / count

        return ltan, lcos, lpar, lspine

    def _ground_parallel_loss(self, y_pred):
        """Mean squared gap between |x-component of unit predicted bone| and its target."""
        total = count = 0
        for (src, dst), value in self.ground_parallel:
            direction = nn.functional.normalize(y_pred[dst] - y_pred[src], dim=0)
            total = total + (torch.abs(direction[0]) - value) ** 2
            count += 1

        if count > 0:
            total = total / count
        return total

    def _parallel_3d_loss(self, z):
        """Mean of (cos(angle) - 1)^2 between paired normalized segments of ``z``."""
        total = count = 0
        for (i, j), (k, m) in self.parallel_in_3d:
            a = nn.functional.normalize(z[j] - z[i], dim=0)
            b = nn.functional.normalize(z[m] - z[k], dim=0)
            total = total + ((a * b).sum() - 1) ** 2
            count += 1

        if count > 0:
            total = total / count
        return total


class MimickedSelfContactLoss(nn.Module):
    """Loss that lets vertices in contact on the presented mesh attract vertices that are close.

    For every vertex in ``presented_contact`` the nearest contact vertex allowed
    by the geodesic mask is selected (without gradient), and the distance to it
    is penalized.

    Args:
        geodesics_mask: 2-D mask indexed by vertex id on both axes; pairs whose
            entry is False are never selected as a match. Used with ``~`` as an
            index mask, so presumably a bool tensor — TODO confirm. Registered
            as the ``geomask`` buffer.
    """

    def __init__(self, geodesics_mask):
        super().__init__()
        # geodesic distance mask
        self.register_buffer("geomask", geodesics_mask)

    def forward(
        self,
        presented_contact,
        vertices,
        v2v=None,
        contact_mode="dist_tanh",
        contact_thresh=1,
    ):
        """Return the mimicked self-contact loss.

        Args:
            presented_contact: indices of vertices in contact. It is itself
                indexed with a NumPy array, so presumably a 1-D NumPy integer
                array — TODO confirm against callers.
            vertices: vertex positions, only used when ``v2v`` is None. After
                ``squeeze()`` it must be ``(V, 3)``, i.e. a batch size of 1.
            v2v: optional precomputed ``(V, V)`` pairwise distance matrix.
            contact_mode: ``"dist_tanh"`` for the saturating tanh penalty; any
                other value uses the plain mean distance.
            contact_thresh: scale of the tanh saturation.

        Returns:
            Scalar tensor loss, or the float ``0.0`` when ``presented_contact``
            is empty.
        """
        if len(presented_contact) == 0:
            return 0.0

        if v2v is None:
            # Pairwise Euclidean distances between all vertices.
            verts = vertices.contiguous()
            nv = verts.shape[1]
            flat = verts.squeeze()
            diff = flat.unsqueeze(1).expand(nv, nv, 3) - flat.unsqueeze(0).expand(nv, nv, 3)
            v2v = torch.norm(diff, 2, 2)

        # Pick, for each contact vertex, the closest allowed contact vertex.
        # This selection is not differentiated; only v2v_min below is.
        with torch.no_grad():
            cverts_to_body = v2v[presented_contact, :][:, presented_contact]
            maskgeo = self.geomask[presented_contact, :][:, presented_contact]
            # ones_like already matches cverts_to_body's device (the original
            # referenced `verts`, which is unbound when v2v is supplied).
            weights = torch.ones_like(cverts_to_body)
            weights[~maskgeo] = float("inf")
            # +1 keeps zero distances from becoming 0 * inf = NaN when masked.
            min_idx = torch.min((cverts_to_body + 1) * weights, 1)[1]
            min_idx = presented_contact[min_idx.cpu().numpy()]

        v2v_min = v2v[presented_contact, min_idx]

        # tanh will not pull vertices that are ~more than contact_thresh far apart
        if contact_mode == "dist_tanh":
            return (contact_thresh * torch.tanh(v2v_min / contact_thresh)).mean()
        return v2v_min.mean()