# liyy201912's picture
# Upload folder using huggingface_hub
# cc0dd3c
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Tuple
import numpy as np
from mmengine.utils import is_method_overridden
class BaseKeypointCodec(metaclass=ABCMeta):
    """The base class of the keypoint codec.

    A keypoint codec is a module to encode keypoint coordinates to specific
    representation (e.g. heatmap) and vice versa. A subclass should implement
    the methods :meth:`encode` and :meth:`decode`. Overriding
    :meth:`batch_decode` is optional; :attr:`support_batch_decoding` reports
    whether a concrete codec has done so.
    """

    # pass additional encoding arguments to the `encode` method, beyond the
    # mandatory `keypoints` and `keypoints_visible` arguments.
    # NOTE: this set lives on the class and is therefore shared by every
    # subclass that does not rebind it; subclasses should assign their own
    # set instead of mutating this one in place.
    auxiliary_encode_keys = set()

    @abstractmethod
    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None) -> dict:
        """Encode keypoints.

        Note:

            - instance number: N
            - keypoint number: K
            - keypoint dimension: D

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray, optional): Keypoint visibility in
                shape (N, K). Defaults to ``None``

        Returns:
            dict: Encoded items.
        """

    @abstractmethod
    def decode(self, encoded: Any) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoints.

        Args:
            encoded (any): Encoded keypoint representation using the codec

        Returns:
            tuple:
            - keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            - keypoints_visible (np.ndarray): Keypoint visibility in shape
                (N, K)
        """

    def batch_decode(self, batch_encoded: Any
                     ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """Decode a batch of keypoints.

        The base implementation is only a placeholder. A codec that can
        decode batched data should override this method.

        Args:
            batch_encoded (any): A batch of encoded keypoint
                representations

        Returns:
            tuple:
            - batch_keypoints (List[np.ndarray]): Each element is keypoint
                coordinates in shape (N, K, D)
            - batch_keypoints_visible (List[np.ndarray]): Each element is
                keypoint visibility in shape (N, K)

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    @property
    def support_batch_decoding(self) -> bool:
        """Return whether the codec supports decoding from batch data.

        The answer is obtained by asking ``is_method_overridden`` whether the
        concrete class of ``self`` overrides :meth:`batch_decode` relative to
        :class:`BaseKeypointCodec`.
        """
        return is_method_overridden('batch_decode', BaseKeypointCodec,
                                    self.__class__)