from pathlib import Path
from types import SimpleNamespace
import warnings
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from typing import Optional, List, Callable

# Optional dependency: the standalone flash-attn package. Catch ImportError
# (the parent class of ModuleNotFoundError) so that an installed-but-unloadable
# flash-attn, e.g. a compiled extension that fails to load against the current
# torch, falls back to the torch kernels instead of breaking this import.
try:
    from flash_attn.modules.mha import FlashCrossAttention
except ImportError:
    FlashCrossAttention = None

# A fast attention path exists if flash-attn imported, or if this torch
# provides F.scaled_dot_product_attention (added in PyTorch 2.0).
FLASH_AVAILABLE = FlashCrossAttention is not None or hasattr(
    F, "scaled_dot_product_attention"
)

# Module-level side effect: request deterministic cuDNN algorithms globally.
torch.backends.cudnn.deterministic = True


@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def normalize_keypoints(kpts: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
    """Center pixel keypoints and scale them by half of the longer image side.

    Each coordinate has half of the matching ``size`` entry subtracted and is
    then divided by half of the largest ``size`` entry, so points inside the
    image land in [-1, 1]. Under autocast the decorator casts tensor inputs to
    float32 and runs the body with autocast disabled.

    Args:
        kpts: keypoints [B x N x 2], in the same coordinate order as ``size``.
        size: per-batch image size [B x 2]. A single size shared by the whole
            batch may also be given as a 1-D tensor, ``torch.Size``, tuple,
            list or array of length 2.

    Returns:
        Normalized keypoints with the same shape as ``kpts``.
    """
    if not isinstance(size, torch.Tensor):
        # torch.Size / tuple / list / array -> tensor of shape [2]
        size = torch.tensor(size)
    if size.dim() == 1:
        # one size for the whole batch -> [1 x 2], broadcast over B below
        size = size[None]
    shift = size.float().to(kpts) / 2
    scale = size.max(-1).values.float().to(kpts) / 2
    return (kpts - shift[:, None]) / scale[:, None, None]


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate each consecutive feature pair of the last dimension by 90 degrees.

    The last dimension is read as pairs ``(a, b)``; each becomes ``(-b, a)``.
    """
    pairs = x.unflatten(-1, (-1, 2))
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    return rotated.flatten(start_dim=-2)


def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Apply a precomputed rotary embedding to ``t``.

    ``freqs[0]`` multiplies ``t`` and ``freqs[1]`` multiplies the pair-rotated
    ``t``. These are the cosine and sine terms that
    LearnableFourierPositionalEncoding stacks along dim 0.
    """
    cos_term = freqs[0]
    sin_term = freqs[1]
    return t * cos_term + rotate_half(t) * sin_term


class LearnableFourierPositionalEncoding(nn.Module):
    """Learnable Fourier features used as rotary position encodings.

    Positions with ``M`` coordinates are projected (linear, no bias) to
    ``F_dim // 2`` frequencies. ``forward`` returns their cosines and sines,
    each repeated twice along the last axis so that they line up with the
    (even, odd) pairing used by ``rotate_half``.
    """

    def __init__(
        self, M: int, dim: int, F_dim: Optional[int] = None, gamma: float = 1.0
    ) -> None:
        super().__init__()
        # F_dim (output feature width) defaults to dim.
        F_dim = F_dim if F_dim is not None else dim
        self.gamma = gamma
        self.Wr = nn.Linear(M, F_dim // 2, bias=False)
        # Projection weights start as N(0, gamma**-2).
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode position vectors.

        For ``x`` of shape [B, N, M] the result has shape
        [2, B, 1, N, 2 * (F_dim // 2)]. Index 0 holds the cosines and index 1
        the sines; the singleton dim broadcasts over attention heads.
        """
        projected = self.Wr(x)
        cosines, sines = torch.cos(projected), torch.sin(projected)
        emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
        return emb.repeat_interleave(2, dim=-1)


class TokenConfidence(nn.Module):
    """Per-token confidence head: a single linear unit followed by a sigmoid."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        layers = [nn.Linear(dim, 1), nn.Sigmoid()]
        self.token = nn.Sequential(*layers)

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """Score every token of both descriptor sets with a value in (0, 1).

        The descriptors are detached, so no gradient reaches them through this
        head, and they are promoted to float32. The trailing singleton output
        dimension is dropped. Returns a tuple ``(conf0, conf1)``.
        """
        return tuple(
            self.token(desc.detach().float()).squeeze(-1) for desc in (desc0, desc1)
        )


class Attention(nn.Module):
    """Multi-head scaled dot-product attention with optional fast kernels.

    Inputs are laid out as (batch, heads, seq, head_dim), which is the layout
    Transformer and CrossTransformer pass in. ``forward`` tries, in order:

    1. flash-attn ``FlashCrossAttention`` on CUDA, when installed and allowed;
    2. ``F.scaled_dot_product_attention`` under
       ``torch.backends.cuda.sdp_kernel(enable_flash=True)`` on CUDA, when allowed;
    3. ``F.scaled_dot_product_attention`` in the input precision, if torch has it;
    4. an explicit ``softmax(q k^T / sqrt(d)) v`` written with einsum.

    Paths 1 and 2 run in float16 and cast the result back to ``q.dtype``.
    """

    def __init__(self, allow_flash: bool) -> None:
        super().__init__()
        flash_requested_but_missing = allow_flash and not FLASH_AVAILABLE
        if flash_requested_but_missing:
            warnings.warn(
                "FlashAttention is not available. For optimal speed, "
                "consider installing torch >= 2.0 or flash-attn.",
                stacklevel=2,
            )
        self.enable_flash = allow_flash and FLASH_AVAILABLE
        if allow_flash and FlashCrossAttention:
            self.flash_ = FlashCrossAttention()

    def forward(self, q, k, v) -> torch.Tensor:
        fast_cuda = self.enable_flash and q.device.type == "cuda"

        if fast_cuda and FlashCrossAttention:
            # flash-attn expects (batch, seq, heads, dim) with k/v packed on dim 2
            q_t, k_t, v_t = (t.transpose(-2, -3) for t in (q, k, v))
            packed_kv = torch.stack([k_t, v_t], 2).half()
            out = self.flash_(q_t.half(), packed_kv)
            return out.transpose(-2, -3).to(q.dtype)

        if fast_cuda:
            # torch 2.0 scaled_dot_product_attention with the flash kernel
            half_args = [t.half().contiguous() for t in (q, k, v)]
            with torch.backends.cuda.sdp_kernel(enable_flash=True):
                return F.scaled_dot_product_attention(*half_args).to(q.dtype)

        if hasattr(F, "scaled_dot_product_attention"):
            dense_args = [t.contiguous() for t in (q, k, v)]
            return F.scaled_dot_product_attention(*dense_args).to(q.dtype)

        scale = q.shape[-1] ** -0.5
        logits = torch.einsum("...id,...jd->...ij", q, k) * scale
        weights = F.softmax(logits, -1)
        return torch.einsum("...ij,...jd->...id", weights, v)


class Transformer(nn.Module):
    """Self-attention block applied independently to each of two token sets.

    Queries and keys optionally receive a cached rotary encoding. The attended
    message is projected by ``out_proj`` and merged through a residual FFN that
    sees ``[x, message]``.
    """

    def __init__(
        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0
        self.head_dim = self.embed_dim // num_heads
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.inner_attn = Attention(flash)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        hidden = 2 * embed_dim
        self.ffn = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(hidden, embed_dim),
        )

    def _forward(self, x: torch.Tensor, encoding: Optional[torch.Tensor] = None):
        # assumes x is (batch, tokens, embed_dim) -- TODO confirm with callers.
        # Split the fused projection into (batch, heads, tokens, head_dim, 3).
        fused = self.Wqkv(x).unflatten(-1, (self.num_heads, -1, 3))
        q, k, v = fused.transpose(1, 2).unbind(dim=-1)
        if encoding is not None:
            q = apply_cached_rotary_emb(encoding, q)
            k = apply_cached_rotary_emb(encoding, k)
        context = self.inner_attn(q, k, v)
        # Back to (batch, tokens, heads * head_dim) before the output projection.
        merged_heads = context.transpose(1, 2).flatten(start_dim=-2)
        message = self.out_proj(merged_heads)
        return x + self.ffn(torch.cat([x, message], -1))

    def forward(self, x0, x1, encoding0=None, encoding1=None):
        updated0 = self._forward(x0, encoding0)
        updated1 = self._forward(x1, encoding1)
        return updated0, updated1


class CrossTransformer(nn.Module):
    """Bidirectional cross-attention between two token sets with shared weights.

    Both sets use one query/key projection (``to_qk``). Set 0 attends to set 1
    and set 1 to set 0 through the same similarity matrix. Each message is
    projected by ``to_out`` and merged with a residual FFN over ``[x, message]``.
    """

    def __init__(
        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
    ) -> None:
        super().__init__()
        self.heads = num_heads
        dim_head = embed_dim // num_heads
        self.scale = dim_head**-0.5
        inner_dim = dim_head * num_heads
        self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
        hidden = 2 * embed_dim
        self.ffn = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(hidden, embed_dim),
        )
        self.flash = Attention(True) if (flash and FLASH_AVAILABLE) else None

    def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
        """Apply ``func`` to both inputs and return the pair of results."""
        return func(x0), func(x1)

    def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> List[torch.Tensor]:
        qk0, qk1 = self.map_(self.to_qk, x0, x1)
        v0, v1 = self.map_(self.to_v, x0, x1)
        # (batch, tokens, heads * dim) -> (batch, heads, tokens, dim)
        qk0, qk1, v0, v1 = [
            t.unflatten(-1, (self.heads, -1)).transpose(1, 2)
            for t in (qk0, qk1, v0, v1)
        ]
        if self.flash is None:
            # Split the 1/sqrt(dim_head) factor evenly across both operands.
            half_scale = self.scale**0.5
            qk0, qk1 = qk0 * half_scale, qk1 * half_scale
            sim = torch.einsum("b h i d, b h j d -> b h i j", qk0, qk1)
            attn01 = F.softmax(sim, dim=-1)
            attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
            m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
            m1 = torch.einsum("bhij, bhjd -> bhid", attn10, v0)
        else:
            m0 = self.flash(qk0, qk1, v1)
            m1 = self.flash(qk1, qk0, v0)
        m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
        m0, m1 = self.map_(self.to_out, m0, m1)
        x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
        x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
        return x0, x1


def sigmoid_log_double_softmax(
    sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
) -> torch.Tensor:
    """Combine similarity logits and matchability logits into log assignments.

    Args:
        sim: similarity logits [B x M x N].
        z0: matchability logits of the first set [B x M x 1].
        z1: matchability logits of the second set [B x N x 1].

    Returns:
        [B x M+1 x N+1] log scores. The inner M x N block is a row-wise plus a
        column-wise log-softmax of ``sim``, plus ``logsigmoid(z0) +
        logsigmoid(z1)``. The last column holds ``logsigmoid(-z0)``, the last
        row holds ``logsigmoid(-z1)``, and the bottom-right corner is 0.
    """
    batch, rows, cols = sim.shape
    log_rows = F.log_softmax(sim, 2)
    # column-wise normalisation via a contiguous transpose of sim
    log_cols = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
    log_matchable = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)

    out = sim.new_full((batch, rows + 1, cols + 1), 0)
    out[:, :rows, :cols] = log_rows + log_cols + log_matchable
    out[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
    out[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
    return out


class MatchAssignment(nn.Module):
    """Assignment head: similarity of projected descriptors + matchability logits."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        self.matchability = nn.Linear(dim, 1, bias=True)
        self.final_proj = nn.Linear(dim, dim, bias=True)

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """Build the log assignment matrix from two descriptor sets.

        Returns ``(scores, sim)``. ``sim`` [B x M x N] is the dot product of the
        projected descriptors, each divided by d**0.25. ``scores``
        [B x M+1 x N+1] is ``sigmoid_log_double_softmax(sim, z0, z1)``.
        """
        proj0 = self.final_proj(desc0)
        proj1 = self.final_proj(desc1)
        _, _, width = proj0.shape
        # Split 1/sqrt(d) evenly between both sides of the dot product.
        norm = width**0.25
        sim = torch.einsum("bmd,bnd->bmn", proj0 / norm, proj1 / norm)
        logits0 = self.matchability(desc0)
        logits1 = self.matchability(desc1)
        return sigmoid_log_double_softmax(sim, logits0, logits1), sim

    def scores(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """Per-point matchability probabilities for both sets, trailing dim dropped."""
        prob0, prob1 = (
            torch.sigmoid(self.matchability(desc)).squeeze(-1)
            for desc in (desc0, desc1)
        )
        return prob0, prob1


def filter_matches(scores: torch.Tensor, th: Optional[float]):
    """Extract mutual-nearest-neighbour matches from a log assignment matrix.

    Args:
        scores: log assignment matrix [B x M+1 x N+1]. The last row and column
            (dustbins) are ignored here.
        th: minimum match probability (``exp`` of the log score) for a mutual
            match to be kept; ``None`` keeps every mutual nearest neighbour.

    Returns:
        m0: [B x M] index into the second set for each point of the first set,
            -1 if unmatched.
        m1: [B x N] index into the first set for each point of the second set,
            -1 if unmatched.
        mscores0: [B x M] match probability for mutual matches, else 0.
        mscores1: [B x N] the same probabilities seen from the second set.
    """
    sim = scores[:, :-1, :-1]
    b, m, n = sim.shape
    if m == 0 or n == 0:
        # max() cannot reduce over an empty dimension, and nothing can match.
        m0 = torch.full((b, m), -1, dtype=torch.long, device=scores.device)
        m1 = torch.full((b, n), -1, dtype=torch.long, device=scores.device)
        return m0, m1, scores.new_zeros((b, m)), scores.new_zeros((b, n))
    max0, max1 = sim.max(2), sim.max(1)
    m0, m1 = max0.indices, max1.indices
    # mutual check: following i -> m0[i] -> m1[m0[i]] must lead back to i
    mutual0 = torch.arange(m, device=m0.device)[None] == m1.gather(1, m0)
    mutual1 = torch.arange(n, device=m1.device)[None] == m0.gather(1, m1)
    max0_exp = max0.values.exp()
    zero = max0_exp.new_tensor(0)
    mscores0 = torch.where(mutual0, max0_exp, zero)
    mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
    valid0 = mutual0 if th is None else mutual0 & (mscores0 > th)
    valid1 = mutual1 & valid0.gather(1, m1)
    m0 = torch.where(valid0, m0, m0.new_tensor(-1))
    m1 = torch.where(valid1, m1, m1.new_tensor(-1))
    return m0, m1, mscores0, mscores1


class LightGlue(nn.Module):
    """Attention-based matcher for two sets of keypoints and descriptors.

    Each of the ``n_layers`` layers applies self-attention (``Transformer``,
    with rotary encodings of the normalized keypoints) and then cross-attention
    (``CrossTransformer``). The ``MatchAssignment`` head of the last executed
    layer produces a log assignment matrix, and ``filter_matches`` extracts
    mutual matches from it. Between layers the model may stop early
    (``depth_confidence``) or drop keypoints (``width_confidence``).

    NOTE(review): the pruning path selects with boolean masks and re-adds a
    leading dimension of 1 (``desc0[mask0][None]``). This appears to assume
    batch size 1; confirm before running B > 1 with pruning enabled.
    """

    default_conf = {
        "name": "lightglue",  # just for interfacing
        "input_dim": 256,  # input descriptor dimension (autoselected from weights)
        "descriptor_dim": 256,
        "n_layers": 9,
        "num_heads": 4,
        "flash": True,  # enable FlashAttention if available.
        "mp": False,  # enable mixed precision
        "depth_confidence": 0.95,  # early stopping, disable with -1
        "width_confidence": 0.99,  # point pruning, disable with -1
        "filter_threshold": 0.1,  # match threshold
        "weights": None,
    }

    # Top-level keys that _forward asserts are present in the input dict.
    required_data_keys = ["image0", "image1"]

    # Release tag and URL template (formatted with version, features name)
    # used to download pretrained weights.
    version = "v0.1_arxiv"
    url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"

    # Supported features: name -> (weights name, input descriptor dimension).
    features = {
        "superpoint": ("superpoint_lightglue", 256),
        "disk": ("disk_lightglue", 128),
    }

    def __init__(self, features="superpoint", **conf) -> None:
        """Build the layers and load weights.

        Args:
            features: a key of ``LightGlue.features``. It overrides
                ``conf["weights"]`` and ``conf["input_dim"]`` and downloads the
                matching pretrained weights via torch.hub. If ``None``, weights
                are loaded from ``weights/<weights>.pth`` next to this file
                when ``weights`` is set in ``conf``; otherwise the freshly
                initialized parameters are kept.
            **conf: overrides for ``default_conf``.
        """
        super().__init__()
        self.conf = {**self.default_conf, **conf}
        if features is not None:
            assert features in list(self.features.keys())
            self.conf["weights"], self.conf["input_dim"] = self.features[features]
        self.conf = conf = SimpleNamespace(**self.conf)

        # Project descriptors to the internal width only when the sizes differ.
        if conf.input_dim != conf.descriptor_dim:
            self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
        else:
            self.input_proj = nn.Identity()

        # Rotary encoding of 2-D keypoints, sized to a single attention head.
        head_dim = conf.descriptor_dim // conf.num_heads
        self.posenc = LearnableFourierPositionalEncoding(2, head_dim, head_dim)

        h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
        self.self_attn = nn.ModuleList(
            [Transformer(d, h, conf.flash) for _ in range(n)]
        )
        self.cross_attn = nn.ModuleList(
            [CrossTransformer(d, h, conf.flash) for _ in range(n)]
        )
        self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
        # n - 1 heads: the last layer never early-stops or prunes.
        self.token_confidence = nn.ModuleList(
            [TokenConfidence(d) for _ in range(n - 1)]
        )

        if features is not None:
            # NOTE(review): replace() also turns the ".pth" suffix into "-pth".
            # This only names the local cache file, so it is left unchanged.
            fname = f"{conf.weights}_{self.version}.pth".replace(".", "-")
            state_dict = torch.hub.load_state_dict_from_url(
                self.url.format(self.version, features), file_name=fname
            )
            # strict=False: missing / unexpected keys are ignored without error.
            self.load_state_dict(state_dict, strict=False)
        elif conf.weights is not None:
            path = Path(__file__).parent
            path = path / "weights/{}.pth".format(self.conf.weights)
            state_dict = torch.load(str(path), map_location="cpu")
            self.load_state_dict(state_dict, strict=False)

        print("Loaded LightGlue model")

    def forward(self, data: dict) -> dict:
        """
        Match keypoints and descriptors between two images

        Input (dict):
            image0: dict
                keypoints: [B x M x 2]
                descriptors: [B x M x D]
                image: [B x C x H x W] or image_size: [B x 2]
            image1: dict
                keypoints: [B x N x 2]
                descriptors: [B x N x D]
                image: [B x C x H x W] or image_size: [B x 2]
        Output (dict):
            log_assignment: [B x M+1 x N+1]
            matches0: [B x M]
            matching_scores0: [B x M]
            matches1: [B x N]
            matching_scores1: [B x N]
            matches: List[[Si x 2]], scores: List[[Si]]
                (each match row is (index in image0, index in image1))
            stop: number of layers that were executed
            prune0: [1 x M], prune1: [1 x N] per-keypoint counters that are
                incremented after every non-final layer the point survives

        Runs inside CUDA autocast when ``conf.mp`` is set.
        """
        with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
            return self._forward(data)

    def _forward(self, data: dict) -> dict:
        """Run the matcher; see ``forward`` for the input / output format."""
        for key in self.required_data_keys:
            assert key in data, f"Missing key {key} in data"
        data0, data1 = data["image0"], data["image1"]
        kpts0_, kpts1_ = data0["keypoints"], data1["keypoints"]
        b, m, _ = kpts0_.shape
        b, n, _ = kpts1_.shape
        # Image size as (W, H): explicit "image_size", else the reversed (H, W)
        # of the image tensor.
        size0, size1 = data0.get("image_size"), data1.get("image_size")
        size0 = size0 if size0 is not None else data0["image"].shape[-2:][::-1]
        size1 = size1 if size1 is not None else data1["image"].shape[-2:][::-1]
        kpts0 = normalize_keypoints(kpts0_, size=size0)
        kpts1 = normalize_keypoints(kpts1_, size=size1)

        # Sanity check: normalized coordinates must lie in [-1, 1].
        assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
        assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)

        # Descriptors are treated as fixed inputs: no gradient flows back.
        desc0 = data0["descriptors"].detach()
        desc1 = data1["descriptors"].detach()

        assert desc0.shape[-1] == self.conf.input_dim
        assert desc1.shape[-1] == self.conf.input_dim

        # Under autocast, feed the network half-precision descriptors.
        if torch.is_autocast_enabled():
            desc0 = desc0.half()
            desc1 = desc1.half()

        desc0 = self.input_proj(desc0)
        desc1 = self.input_proj(desc1)

        # cache positional embeddings
        encoding0 = self.posenc(kpts0)
        encoding1 = self.posenc(kpts1)

        # GNN + final_proj + assignment
        # ind0 / ind1: original indices of the keypoints still active after pruning.
        ind0 = torch.arange(0, m).to(device=kpts0.device)[None]
        ind1 = torch.arange(0, n).to(device=kpts0.device)[None]
        prune0 = torch.ones_like(ind0)  # store layer where pruning is detected
        prune1 = torch.ones_like(ind1)
        dec, wic = self.conf.depth_confidence, self.conf.width_confidence
        # Token confidences are only computed when early stopping is enabled.
        token0, token1 = None, None
        for i in range(self.conf.n_layers):
            # self+cross attention
            desc0, desc1 = self.self_attn[i](desc0, desc1, encoding0, encoding1)
            desc0, desc1 = self.cross_attn[i](desc0, desc1)
            if i == self.conf.n_layers - 1:
                continue  # no early stopping or adaptive width at last layer
            if dec > 0:  # early stopping
                token0, token1 = self.token_confidence[i](desc0, desc1)
                # seql is the original m + n, not the number of surviving points.
                if self.stop(token0, token1, self.conf_th(i), dec, m + n):
                    break
            if wic > 0:  # point pruning
                match0, match1 = self.log_assignment[i].scores(desc0, desc1)
                # Masks are True for the points to keep.
                mask0 = self.get_mask(token0, match0, self.conf_th(i), 1 - wic)
                mask1 = self.get_mask(token1, match1, self.conf_th(i), 1 - wic)
                ind0, ind1 = ind0[mask0][None], ind1[mask1][None]
                desc0, desc1 = desc0[mask0][None], desc1[mask1][None]
                # Nothing left to match on one side.
                if desc0.shape[-2] == 0 or desc1.shape[-2] == 0:
                    break
                encoding0 = encoding0[:, :, mask0][:, None]
                encoding1 = encoding1[:, :, mask1][:, None]
            prune0[:, ind0] += 1
            prune1[:, ind1] += 1

        # `i` is now the index of the last layer that ran.
        if wic > 0:  # scatter with indices after pruning
            scores_, _ = self.log_assignment[i](desc0, desc1)
            dt, dev = scores_.dtype, scores_.device
            scores = torch.zeros(b, m + 1, n + 1, dtype=dt, device=dev)
            # Pairs involving a pruned point can never match (log-prob -inf);
            # a pruned point's dustbin entry keeps the initial value 0.
            scores[:, :-1, :-1] = -torch.inf
            scores[:, ind0[0], -1] = scores_[:, :-1, -1]
            scores[:, -1, ind1[0]] = scores_[:, -1, :-1]
            x, y = torch.meshgrid(ind0[0], ind1[0], indexing="ij")
            scores[:, x, y] = scores_[:, :-1, :-1]
        else:
            scores, _ = self.log_assignment[i](desc0, desc1)

        m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)

        # Per-batch list of (index in image0, index in image1) pairs.
        matches, mscores = [], []
        for k in range(b):
            valid = m0[k] > -1
            matches.append(torch.stack([torch.where(valid)[0], m0[k][valid]], -1))
            mscores.append(mscores0[k][valid])

        return {
            "log_assignment": scores,
            "matches0": m0,
            "matches1": m1,
            "matching_scores0": mscores0,
            "matching_scores1": mscores1,
            "stop": i + 1,
            "prune0": prune0,
            "prune1": prune1,
            "matches": matches,
            "scores": mscores,
        }

    def conf_th(self, i: int) -> float:
        """scaled confidence threshold

        Returns 0.8 + 0.1 * exp(-4 * i / n_layers), clipped to [0, 1]. The value
        decreases from 0.9 at i = 0 toward 0.8 for later layers. It is a NumPy
        scalar, not a Python float.
        """
        return np.clip(0.8 + 0.1 * np.exp(-4.0 * i / self.conf.n_layers), 0, 1)

    def get_mask(
        self,
        confidence: torch.Tensor,
        match: torch.Tensor,
        conf_th: float,
        match_th: float,
    ) -> torch.Tensor:
        """Return a boolean mask of the points to KEEP (the caller indexes with it).

        With a confidence available, a point whose confidence exceeds
        ``conf_th`` is kept only if its matchability exceeds ``match_th``. A
        point that is not yet confident is replaced by 1.0 and so is kept
        whenever ``match_th < 1``. Without confidence (or with a falsy
        ``conf_th``), only matchability is compared to ``match_th``.
        """
        if conf_th and confidence is not None:
            mask = (
                torch.where(confidence > conf_th, match, match.new_tensor(1.0))
                > match_th
            )
        else:
            mask = match > match_th
        return mask

    def stop(
        self,
        token0: torch.Tensor,
        token1: torch.Tensor,
        conf_th: float,
        inl_th: float,
        seql: int,
    ) -> torch.Tensor:
        """evaluate stopping condition

        Returns a boolean tensor. With ``conf_th``, it is True when
        1 - (#tokens below ``conf_th``) / ``seql`` exceeds ``inl_th``. Only the
        tokens passed in are counted, so if ``seql`` is the unpruned total,
        already-pruned points effectively count as confident. Without
        ``conf_th``, it compares the mean confidence to ``inl_th``.
        """
        tokens = torch.cat([token0, token1], -1)
        if conf_th:
            pos = 1.0 - (tokens < conf_th).float().sum() / seql
            return pos > inl_th
        else:
            return tokens.mean() > inl_th