File size: 4,017 Bytes
cc0dd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import numpy as np


def get_instance_root(keypoints: np.ndarray,
                      keypoints_visible: Optional[np.ndarray] = None,
                      root_type: str = 'kpt_center'
                      ) -> Tuple[np.ndarray, np.ndarray]:
    """Calculate the coordinates and visibility of instance roots.

    Args:
        keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
        keypoints_visible (np.ndarray, optional): Keypoint visibilities in
            shape (N, K). A keypoint is treated as visible when its value
            is greater than 0. If ``None``, all keypoints are treated as
            visible. Defaults to ``None``
        root_type (str): Calculation of instance roots which should
            be one of the following options:

                - ``'kpt_center'``: The roots' coordinates are the mean
                    coordinates of visible keypoints
                - ``'bbox_center'``: The roots' are the center of bounding
                    boxes outlined by visible keypoints

            Defaults to ``'kpt_center'``

    Returns:
        tuple
        - roots_coordinate(np.ndarray): Coordinates of instance roots in
            shape [N, D] (float32). Instances without any visible keypoint
            keep an all-zero root.
        - roots_visible(np.ndarray): Visibility of instance roots in
            shape [N] (float32): 1 if the instance has at least one visible
            keypoint, otherwise 0

    Raises:
        ValueError: If ``root_type`` is not one of the supported options.
    """
    # Validate up front so an invalid option is reported even when there
    # are no instances or no instance has a visible keypoint.
    if root_type not in ('kpt_center', 'bbox_center'):
        raise ValueError(
            f'the value of `root_type` must be \'kpt_center\' or '
            f'\'bbox_center\', but got \'{root_type}\'')

    num_instances = keypoints.shape[0]
    # Size the output by the keypoint dimensionality D (last axis) rather
    # than assuming 2D, so (N, K, 3) keypoints work as documented.
    roots_coordinate = np.zeros((num_instances, keypoints.shape[-1]),
                                dtype=np.float32)
    roots_visible = np.zeros(num_instances, dtype=np.float32)

    for i in range(num_instances):

        # collect visible keypoints
        if keypoints_visible is not None:
            visible_keypoints = keypoints[i][keypoints_visible[i] > 0]
        else:
            visible_keypoints = keypoints[i]
        if visible_keypoints.size == 0:
            # no visible keypoint: root stays at zeros and is invisible
            continue

        # compute the instance root with visible keypoints
        if root_type == 'kpt_center':
            roots_coordinate[i] = visible_keypoints.mean(axis=0)
        else:  # 'bbox_center'
            roots_coordinate[i] = (visible_keypoints.max(axis=0) +
                                   visible_keypoints.min(axis=0)) / 2.0
        roots_visible[i] = 1

    return roots_coordinate, roots_visible


def get_instance_bbox(keypoints: np.ndarray,
                      keypoints_visible: Optional[np.ndarray] = None
                      ) -> np.ndarray:
    """Calculate the pseudo instance bounding box from visible keypoints. The
    bounding boxes are in the xyxy format.

    Each box spans the per-axis minimum and maximum of the instance's
    visible keypoints. Instances with no visible keypoint get an all-zero
    box.

    Args:
        keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
        keypoints_visible (np.ndarray, optional): Keypoint visibilities in
            shape (N, K). Values greater than 0 mark a keypoint as visible.
            If ``None``, every keypoint is used. Defaults to ``None``

    Returns:
        np.ndarray: bounding boxes in [N, 4] (float32)
    """
    bboxes = np.zeros((keypoints.shape[0], 4), dtype=np.float32)

    for idx, instance_kpts in enumerate(keypoints):
        if keypoints_visible is not None:
            instance_kpts = instance_kpts[keypoints_visible[idx] > 0]
        # instances without any visible keypoint keep an all-zero box
        if not instance_kpts.size:
            continue

        xy_min = instance_kpts.min(axis=0)
        xy_max = instance_kpts.max(axis=0)
        bboxes[idx, :2] = xy_min
        bboxes[idx, 2:] = xy_max

    return bboxes


def get_diagonal_lengths(keypoints: np.ndarray,
                         keypoints_visible: Optional[np.ndarray] = None
                         ) -> np.ndarray:
    """Calculate the diagonal length of instance bounding box from visible
    keypoints.

    The box is the pseudo xyxy box produced by ``get_instance_bbox``. An
    instance with no visible keypoint therefore has an all-zero box and a
    diagonal length of 0.

    Args:
        keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
        keypoints_visible (np.ndarray, optional): Keypoint visibilities in
            shape (N, K). Defaults to ``None``

    Returns:
        np.ndarray: bounding box diagonal length in [N]
    """
    bboxes = get_instance_bbox(keypoints, keypoints_visible)

    # (x_max - x_min, y_max - y_min) for every xyxy box
    extents = bboxes[:, 2:] - bboxes[:, :2]
    return np.sqrt(np.power(extents, 2).sum(axis=1))