import torch
import torch.nn as nn
import torchvision.transforms as tvf

from .modules import InterestPointModule, CorrespondenceModule


def warp_homography_batch(sources, homographies):
    """
    Batch warp keypoints given homographies. From https://github.com/TRI-ML/KP2D.

    Every (x, y) keypoint is lifted to homogeneous coordinates (x, y, 1),
    multiplied by the homography of its batch element and projected back by
    dividing by the resulting third component. The whole batch is processed
    with a single batched matrix product; the input is not modified.

    Parameters
    ----------
    sources: torch.Tensor (B,H,W,2)
        Keypoints vector, last dimension ordered as (x, y). Any memory layout
        is accepted (e.g. a permuted (B,2,H,W) tensor).
    homographies: torch.Tensor (B,3,3)
        Homographies.

    Returns
    -------
    warped_sources: torch.Tensor (B,H,W,2)
        Warped keypoints vector.
    """
    B, H, W, _ = sources.shape
    # reshape (unlike view) also handles inputs whose strides do not allow
    # merging H and W without a copy.
    points = sources.reshape(B, H * W, 2)

    # [X, Y, Z] = M @ [x, y, 1]  ==  [x, y] @ M[:, :2].T + M[:, 2]
    # so the constant 1 never has to be appended explicitly.
    linear = homographies[:, :, :2].transpose(1, 2)  # (B,2,3)
    offset = homographies[:, :, 2].unsqueeze(1)  # (B,1,3)
    projected = torch.matmul(points, linear) + offset  # (B,H*W,3)

    # Perspective divide: (X/Z, Y/Z). No guard against Z == 0, as before.
    warped = projected[:, :, :2] / projected[:, :, 2:3]
    return warped.reshape(B, H, W, 2)


class PointModel(nn.Module):
    """
    Interest point model with an inference path and a training path.

    In test mode (``is_test=True``) ``forward(img)`` normalises the image with
    ``norm_rgb`` and returns the ``(score, coord, desc)`` triple produced by the
    interest point module.

    In training mode ``forward(source_img, target_img, homographies)`` runs the
    interest point module on both images, warps the source keypoints into the
    target image with the given homographies, samples scores/descriptors at
    the relevant locations and returns everything in a dict.

    NOTE(review): ``norm_rgb`` is applied only in test mode; training inputs are
    fed to the network as given. Presumably the training data pipeline already
    normalises them -- confirm.

    NOTE(review): keypoint coordinates are normalised with the ``(size - 1) / 2``
    convention while every ``grid_sample`` call uses ``align_corners=False``.
    This matches the original behaviour (explicit for the score map, implicit
    default for the descriptors) and is kept unchanged -- confirm it is intended.
    """

    def __init__(self, is_test=True):
        """
        Parameters
        ----------
        is_test: bool
            Selects the inference path (True) or the training path (False) of
            ``forward``; also forwarded to ``InterestPointModule``.
        """
        super().__init__()
        self.is_test = is_test
        self.interestpoint_module = InterestPointModule(is_test=self.is_test)
        self.correspondence_module = CorrespondenceModule()
        self.norm_rgb = tvf.Normalize(mean=[0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225])

    @staticmethod
    def _normalize_coord(coord, height, width):
        """Map (B,2,h,w) pixel coords to a (B,h,w,2) grid; [0, size-1] -> [-1, 1]."""
        coord_norm = coord.clone()
        coord_norm[:, 0] = (coord_norm[:, 0] / (float(width - 1) / 2.0)) - 1.0
        coord_norm[:, 1] = (coord_norm[:, 1] / (float(height - 1) / 2.0)) - 1.0
        return coord_norm.permute(0, 2, 3, 1)

    @staticmethod
    def _denormalize_coord(grid, height, width):
        """Inverse of ``_normalize_coord``: (B,h,w,2) grid -> (B,2,h,w) pixel coords."""
        coord = grid.clone()
        coord[:, :, :, 0] = (coord[:, :, :, 0] + 1) * (float(width - 1) / 2.0)
        coord[:, :, :, 1] = (coord[:, :, :, 1] + 1) * (float(height - 1) / 2.0)
        return coord.permute(0, 3, 1, 2)

    @staticmethod
    def _interior_mask(batch, rows, cols, device):
        """Return a (batch, rows, cols) bool mask that is False on the outermost cells."""
        # Allocated directly as bool on the target device, avoiding a float
        # CPU tensor, a threshold and a host-to-device copy per forward pass.
        mask = torch.ones(batch, rows, cols, dtype=torch.bool, device=device)
        mask[:, 0] = False
        mask[:, rows - 1] = False
        mask[:, :, 0] = False
        mask[:, :, cols - 1] = False
        return mask

    @staticmethod
    def _sample(features, grid):
        """Bilinearly sample ``features`` at ``grid``; no gradient flows into the grid."""
        # align_corners=False is what the implicit default resolves to; being
        # explicit keeps the numerics and silences the default-value warning.
        return torch.nn.functional.grid_sample(
            features, grid.detach(), align_corners=False
        )

    @staticmethod
    def _fuse_descriptors(desc_a, desc_b, aware):
        """Blend two descriptor maps using channels 0 and 1 of ``aware`` as weights."""
        return torch.mul(desc_a, aware[:, 0:1]) + torch.mul(desc_b, aware[:, 1:2])

    def _forward_train(self, source_img, target_img, homographies):
        """Training path of ``forward``; see the class docstring for the outputs."""
        source_score, source_coord, source_desc_block = self.interestpoint_module(
            source_img
        )
        target_score, target_coord, target_desc_block = self.interestpoint_module(
            target_img
        )

        _, _, H, W = source_img.shape
        B, _, hc, wc = source_score.shape
        device = source_score.device

        # Normalize the coordinates from ([0, W-1], [0, H-1]) to ([-1, 1], [-1, 1]).
        source_coord_norm = self._normalize_coord(source_coord, H, W)
        target_coord_norm = self._normalize_coord(target_coord, H, W)

        # Source keypoints expressed in the target image (normalized, then pixels).
        target_coord_warped_norm = warp_homography_batch(
            source_coord_norm, homographies
        )
        target_coord_warped = self._denormalize_coord(target_coord_warped_norm, H, W)

        # Valid cells: not on the border of the score map and warped strictly
        # inside the target image.
        x_norm = target_coord_warped_norm[:, :, :, 0]
        y_norm = target_coord_warped_norm[:, :, :, 1]
        in_bounds = x_norm.lt(1) & x_norm.gt(-1) & y_norm.lt(1) & y_norm.gt(-1)
        border_mask = self._interior_mask(B, hc, wc, device) & in_bounds

        # Target score at the warped source locations.
        target_score_warped = self._sample(target_score, target_coord_warped_norm)

        # Source descriptors at the source keypoints.
        source_aware = source_desc_block[2]
        source_desc = self._fuse_descriptors(
            self._sample(source_desc_block[0], source_coord_norm),
            self._sample(source_desc_block[1], source_coord_norm),
            source_aware,
        )

        # Target descriptors at the target keypoints.
        target_aware = target_desc_block[2]
        target_desc = self._fuse_descriptors(
            self._sample(target_desc_block[0], target_coord_norm),
            self._sample(target_desc_block[1], target_coord_norm),
            target_aware,
        )

        # Target descriptors at the warped source locations; here the blending
        # weights are sampled at those locations as well.
        target_aware_warped = self._sample(
            target_desc_block[2], target_coord_warped_norm
        )
        target_desc_warped = self._fuse_descriptors(
            self._sample(target_desc_block[0], target_coord_warped_norm),
            self._sample(target_desc_block[1], target_coord_warped_norm),
            target_aware_warped,
        )

        confidence_matrix = self.correspondence_module(source_desc, target_desc)
        # Keep values strictly inside (0, 1).
        confidence_matrix = torch.clamp(confidence_matrix, 1e-12, 1 - 1e-12)

        return {
            "source_score": source_score,
            "source_coord": source_coord,
            "source_desc": source_desc,
            "source_aware": source_aware,
            "target_score": target_score,
            "target_coord": target_coord,
            "target_score_warped": target_score_warped,
            "target_coord_warped": target_coord_warped,
            "target_desc_warped": target_desc_warped,
            "target_aware_warped": target_aware_warped,
            "border_mask": border_mask,
            "confidence_matrix": confidence_matrix,
        }

    def forward(self, *args):
        """
        Run the model.

        Parameters
        ----------
        *args:
            Test mode: ``(img,)``.
            Training mode: ``(source_img, target_img, homographies)`` where the
            images are (B,C,H,W) tensors and ``homographies`` is passed to
            ``warp_homography_batch``. Extra positional arguments are ignored.

        Returns
        -------
        Test mode: tuple ``(score, coord, desc)`` from the interest point module.
        Training mode: dict with keys ``source_score``, ``source_coord``,
        ``source_desc``, ``source_aware``, ``target_score``, ``target_coord``,
        ``target_score_warped``, ``target_coord_warped``, ``target_desc_warped``,
        ``target_aware_warped``, ``border_mask`` and ``confidence_matrix``.
        """
        if self.is_test:
            img = self.norm_rgb(args[0])
            score, coord, desc = self.interestpoint_module(img)
            return score, coord, desc
        return self._forward_train(args[0], args[1], args[2])