File size: 2,505 Bytes
cc0dd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Tuple

import numpy as np
from mmengine.utils import is_method_overridden


class BaseKeypointCodec(metaclass=ABCMeta):
    """The base class of the keypoint codec.

    A keypoint codec is a module to encode keypoint coordinates to specific
    representation (e.g. heatmap) and vice versa. A subclass should implement
    the methods :meth:`encode` and :meth:`decode`, and may optionally
    implement :meth:`batch_decode` to support decoding batched data.
    """

    # pass additional encoding arguments to the `encode` method, beyond the
    # mandatory `keypoints` and `keypoints_visible` arguments.
    # NOTE: this is a class attribute shared by every subclass that does not
    # redefine it. Subclasses should assign their own set (e.g.
    # ``auxiliary_encode_keys = {'some_key'}``) instead of mutating this one.
    auxiliary_encode_keys = set()

    @abstractmethod
    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None) -> dict:
        """Encode keypoints.

        Note:

            - instance number: N
            - keypoint number: K
            - keypoint dimension: D

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray, optional): Keypoint visibility in
                shape (N, K)

        Returns:
            dict: Encoded items.
        """

    @abstractmethod
    def decode(self, encoded: Any) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoints.

        Args:
            encoded (any): Encoded keypoint representation using the codec

        Returns:
            tuple:
            - keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            - keypoints_visible (np.ndarray): Keypoint visibility in shape
                (N, K)
        """

    def batch_decode(self, batch_encoded: Any
                     ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """Decode a batch of encoded keypoint representations.

        The base implementation does not support batch decoding; subclasses
        that can decode batched data should override this method (see
        :attr:`support_batch_decoding`).

        Args:
            batch_encoded (any): A batch of encoded keypoint
                representations

        Returns:
            tuple:
            - batch_keypoints (List[np.ndarray]): Each element is keypoint
                coordinates in shape (N, K, D)
            - batch_keypoints_visible (List[np.ndarray]): Each element is
                keypoint visibility in shape (N, K)

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f'{self.__class__.__name__} does not support batch decoding; '
            'override `batch_decode` to enable it.')

    @property
    def support_batch_decoding(self) -> bool:
        """Return whether the codec supports decoding from batch data.

        This is ``True`` exactly when the concrete codec class overrides
        :meth:`batch_decode` of :class:`BaseKeypointCodec`.
        """
        return is_method_overridden('batch_decode', BaseKeypointCodec,
                                    self.__class__)