# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor

from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec
from .utils import (batch_heatmap_nms, generate_displacement_heatmap,
                    generate_gaussian_heatmaps, get_diagonal_lengths,
                    get_instance_root)


@KEYPOINT_CODECS.register_module()
class SPR(BaseKeypointCodec):
    """Encode/decode keypoints with Structured Pose Representation (SPR).

    See the paper `Single-stage multi-person pose machines`_
    by Nie et al (2019) for details

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - image size: [w, h]
        - heatmap size: [W, H]

    Encoded:

        - heatmaps (np.ndarray): The generated heatmap in shape (1, H, W)
            where [W, H] is the `heatmap_size`. If the keypoint heatmap is
            generated together, the output heatmap shape is (K+1, H, W),
            with the K keypoint heatmaps first and the root heatmap last
        - heatmap_weights (np.ndarray): The target weights for heatmaps which
            has same shape with heatmaps.
        - displacements (np.ndarray): The dense keypoint displacement in
            shape (K*2, H, W).
        - displacement_weights (np.ndarray): The target weights for
            displacements which has same shape with displacements.

    Args:
        input_size (tuple): Image size in [w, h]
        heatmap_size (tuple): Heatmap size in [W, H]
        sigma (float or tuple, optional): The sigma values of the Gaussian
            heatmaps. If sigma is a tuple, it includes both sigmas for root
            and keypoint heatmaps. ``None`` means the sigmas are computed
            automatically from the heatmap size. Defaults to ``None``
        generate_keypoint_heatmaps (bool): Whether to generate Gaussian
            heatmaps for each keypoint. Defaults to ``False``
        root_type (str): The method to generate the instance root. Options
            are:

            - ``'kpt_center'``: Average coordinate of all visible keypoints.
            - ``'bbox_center'``: Center point of bounding boxes outlined by
                all visible keypoints.

            Defaults to ``'kpt_center'``

        minimal_diagonal_length (int or float): The threshold of diagonal
            length of instance bounding box (measured in heatmap space).
            Small instances will not be used in training. Defaults to 5
        background_weight (float): Loss weight of background pixels.
            Defaults to 0.1
        decode_nms_kernel (int): The kernel size of the NMS during decoding,
            which should be an odd integer. Defaults to 5
        decode_max_instances (int): The maximum number of instances
            to decode. Defaults to 30
        decode_thr (float): The threshold of keypoint response value in
            heatmaps. Defaults to 0.01

    .. _`Single-stage multi-person pose machines`:
        https://arxiv.org/abs/1908.09220
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        heatmap_size: Tuple[int, int],
        sigma: Optional[Union[float, Tuple[float]]] = None,
        generate_keypoint_heatmaps: bool = False,
        root_type: str = 'kpt_center',
        minimal_diagonal_length: Union[int, float] = 5,
        background_weight: float = 0.1,
        decode_nms_kernel: int = 5,
        decode_max_instances: int = 30,
        decode_thr: float = 0.01,
    ):
        super().__init__()

        self.input_size = input_size
        self.heatmap_size = heatmap_size
        self.generate_keypoint_heatmaps = generate_keypoint_heatmaps
        self.root_type = root_type
        self.minimal_diagonal_length = minimal_diagonal_length
        self.background_weight = background_weight
        self.decode_nms_kernel = decode_nms_kernel
        self.decode_max_instances = decode_max_instances
        self.decode_thr = decode_thr

        # per-axis ratio between input image size and heatmap size, used to
        # map coordinates between the two spaces
        self.scale_factor = (np.array(input_size) /
                             heatmap_size).astype(np.float32)

        if sigma is None:
            # derive the root sigma from the geometric mean of the heatmap
            # width and height
            sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 32
            if generate_keypoint_heatmaps:
                # sigma for root heatmap and keypoint heatmaps; the keypoint
                # sigma is the floor of half the root sigma
                self.sigma = (sigma, sigma // 2)
            else:
                self.sigma = (sigma, )
        else:
            # normalize a scalar sigma into a 1-element tuple
            if not isinstance(sigma, (tuple, list)):
                sigma = (sigma, )
            if generate_keypoint_heatmaps:
                assert len(sigma) == 2, 'sigma for keypoints must be given ' \
                                        'if `generate_keypoint_heatmaps` ' \
                                        'is True. e.g. sigma=(4, 2)'
            self.sigma = sigma

    def _get_heatmap_weights(self,
                             heatmaps: np.ndarray,
                             fg_weight: float = 1,
                             bg_weight: float = 0) -> np.ndarray:
        """Generate weight array for heatmaps.

        Pixels with a strictly positive heatmap response are treated as
        foreground; all other pixels are background.

        Args:
            heatmaps (np.ndarray): Root and keypoint (optional) heatmaps
            fg_weight (float): Weight for foreground pixels. Defaults to 1.0
            bg_weight (float): Weight for background pixels. Defaults to 0.0

        Returns:
            np.ndarray: Heatmap weight array (float64) in the same shape
            with heatmaps
        """
        heatmap_weights = np.full(heatmaps.shape, bg_weight, dtype=np.float64)
        heatmap_weights[heatmaps > 0] = fg_weight
        return heatmap_weights

    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None) -> dict:
        """Encode keypoints into root heatmaps and keypoint displacement
        fields. Note that the original keypoint coordinates should be in the
        input image space.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K). If ``None``, all keypoints are treated as visible

        Returns:
            dict:
            - heatmaps (np.ndarray): The generated heatmap in shape
                (1, H, W) where [W, H] is the `heatmap_size`. If keypoint
                heatmaps are generated together, the shape is (K+1, H, W)
                with the root heatmap in the last channel
            - heatmap_weights (np.ndarray): The pixel-wise weight for heatmaps
                 which has same shape with `heatmaps`
            - displacements (np.ndarray): The generated displacement fields in
                shape (K*D, H, W). The vector on each pixels represents the
                displacement of keypoints belong to the associated instance
                from this pixel.
            - displacement_weights (np.ndarray): The pixel-wise weight for
                displacements which has same shape with `displacements`
        """

        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)

        # keypoint coordinates in heatmap
        _keypoints = keypoints / self.scale_factor

        # compute the root and scale of each instance
        roots, roots_visible = get_instance_root(_keypoints, keypoints_visible,
                                                 self.root_type)
        diagonal_lengths = get_diagonal_lengths(_keypoints, keypoints_visible)

        # discard the small instances
        roots_visible[diagonal_lengths < self.minimal_diagonal_length] = 0

        # generate heatmaps
        heatmaps, _ = generate_gaussian_heatmaps(
            heatmap_size=self.heatmap_size,
            keypoints=roots[:, None],
            keypoints_visible=roots_visible[:, None],
            sigma=self.sigma[0])
        heatmap_weights = self._get_heatmap_weights(
            heatmaps, bg_weight=self.background_weight)

        if self.generate_keypoint_heatmaps:
            keypoint_heatmaps, _ = generate_gaussian_heatmaps(
                heatmap_size=self.heatmap_size,
                keypoints=_keypoints,
                keypoints_visible=keypoints_visible,
                sigma=self.sigma[1])

            keypoint_heatmaps_weights = self._get_heatmap_weights(
                keypoint_heatmaps, bg_weight=self.background_weight)

            # keypoint heatmaps go first and the root heatmap last; `decode`
            # relies on this channel order
            heatmaps = np.concatenate((keypoint_heatmaps, heatmaps), axis=0)
            heatmap_weights = np.concatenate(
                (keypoint_heatmaps_weights, heatmap_weights), axis=0)

        # generate displacements
        displacements, displacement_weights = \
            generate_displacement_heatmap(
                self.heatmap_size,
                _keypoints,
                keypoints_visible,
                roots,
                roots_visible,
                diagonal_lengths,
                self.sigma[0],
            )

        encoded = dict(
            heatmaps=heatmaps,
            heatmap_weights=heatmap_weights,
            displacements=displacements,
            displacement_weights=displacement_weights)

        return encoded

    def decode(
        self, heatmaps: Tensor, displacements: Tensor
    ) -> Tuple[Tensor, Tuple[Tensor, Optional[Tensor]]]:
        """Decode the keypoint coordinates from heatmaps and displacements. The
        decoded keypoint coordinates are in the input image space.

        Args:
            heatmaps (Tensor): Encoded root and keypoints (optional) heatmaps
                in shape (1, H, W) or (K+1, H, W). The root heatmap is
                expected in the last channel
            displacements (Tensor): Encoded keypoints displacement fields
                in shape (K*2, H, W)

        Returns:
            tuple:
            - keypoints (Tensor): Decoded keypoint coordinates in shape
                (N, K, D)
            - scores (tuple):
                - root_scores (Tensor): The root scores in shape (N, )
                - keypoint_scores (Tensor): The keypoint scores in
                    shape (N, K). If keypoint heatmaps are not generated,
                    `keypoint_scores` will be `None`
        """
        _k, h, w = displacements.shape
        k = _k // 2
        # `reshape` (rather than `view`) also accepts non-contiguous inputs
        displacements = displacements.reshape(k, 2, h, w)

        # convert displacements to a dense keypoint prediction.
        # The pixel grid is built by broadcasting, which yields the same
        # values as an 'ij'-indexed meshgrid without depending on the
        # `indexing` argument of `torch.meshgrid`.
        ys = torch.arange(h).view(h, 1).expand(h, w)
        xs = torch.arange(w).view(1, w).expand(h, w)
        regular_grid = torch.stack([xs, ys], dim=0).to(displacements)
        posemaps = (regular_grid[None] + displacements).flatten(2)

        # find local maximum on root heatmap
        root_heatmap_peaks = batch_heatmap_nms(heatmaps[None, -1:],
                                               self.decode_nms_kernel)
        flat_peaks = root_heatmap_peaks.flatten()
        # `topk` raises if asked for more candidates than there are pixels
        num_candidates = min(self.decode_max_instances, flat_peaks.numel())
        root_scores, pos_idx = flat_peaks.topk(num_candidates)
        mask = root_scores > self.decode_thr
        root_scores, pos_idx = root_scores[mask], pos_idx[mask]

        # gather the pose predicted at each selected root pixel: (N, K, 2)
        keypoints = posemaps[:, :, pos_idx].permute(2, 0, 1).contiguous()

        if self.generate_keypoint_heatmaps and heatmaps.shape[0] == 1 + k:
            # compute scores for each keypoint
            keypoint_scores = self.get_keypoint_scores(heatmaps[:k], keypoints)
        else:
            keypoint_scores = None

        # rescale each coordinate axis from heatmap space to input space
        keypoints = torch.cat([
            kpt * self.scale_factor[i]
            for i, kpt in enumerate(keypoints.split(1, -1))
        ],
                              dim=-1)
        return keypoints, (root_scores, keypoint_scores)

    def get_keypoint_scores(self, heatmaps: Tensor,
                            keypoints: Tensor) -> Tensor:
        """Calculate the keypoint scores with keypoints heatmaps and
        coordinates.

        Each keypoint's score is sampled from its own heatmap channel at
        the keypoint location.

        Args:
            heatmaps (Tensor): Keypoint heatmaps in shape (K, H, W)
            keypoints (Tensor): Keypoint coordinates in shape (N, K, D),
                in heatmap space

        Returns:
            Tensor: Keypoint scores in [N, K]
        """
        k, h, w = heatmaps.shape
        # normalize pixel coordinates to the [-1, 1] range used by
        # `grid_sample`
        keypoints = torch.stack((
            keypoints[..., 0] / (w - 1) * 2 - 1,
            keypoints[..., 1] / (h - 1) * 2 - 1,
        ),
                                dim=-1)
        # (N, K, 2) -> (K, 1, N, 2): one sampling grid per keypoint channel
        keypoints = keypoints.transpose(0, 1).unsqueeze(1).contiguous()

        # NOTE(review): the (w - 1) / (h - 1) normalization above follows the
        # `align_corners=True` convention, while sampling uses
        # `align_corners=False` (the library default, passed explicitly here
        # to keep the existing behavior) - confirm which one is intended.
        keypoint_scores = torch.nn.functional.grid_sample(
            heatmaps.unsqueeze(1),
            keypoints,
            padding_mode='border',
            align_corners=False).view(k, -1).transpose(0, 1).contiguous()

        return keypoint_scores