Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Optional | |
import numpy as np | |
def get_instance_root(keypoints: np.ndarray,
                      keypoints_visible: Optional[np.ndarray] = None,
                      root_type: str = 'kpt_center') -> tuple:
    """Calculate the coordinates and visibility of instance roots.

    Args:
        keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
        keypoints_visible (np.ndarray, optional): Keypoint visibilities in
            shape (N, K). A keypoint counts as visible when its value is
            greater than 0. If ``None``, all keypoints are treated as
            visible. Defaults to ``None``
        root_type (str): Calculation of instance roots which should
            be one of the following options:

                - ``'kpt_center'``: The roots' coordinates are the mean
                    coordinates of visible keypoints
                - ``'bbox_center'``: The roots' coordinates are the center
                    of bounding boxes outlined by visible keypoints

            Defaults to ``'kpt_center'``

    Returns:
        tuple
        - roots_coordinate(np.ndarray): Coordinates of instance roots in
            shape [N, D] (float32). Rows of instances without any visible
            keypoint are left as zeros
        - roots_visible(np.ndarray): Visibility of instance roots in
            shape [N] (float32); 1 if the instance has at least one
            visible keypoint, otherwise 0

    Raises:
        ValueError: If ``root_type`` is not a supported option. The check
            happens when the first instance with visible keypoints is
            processed.
    """
    num_instances = keypoints.shape[0]

    # Size the output by the keypoint dimensionality instead of a fixed 2,
    # so that inputs with D != 2 (e.g. 3D keypoints) are supported as the
    # documented [N, D] return shape promises.
    roots_coordinate = np.zeros((num_instances, keypoints.shape[-1]),
                                dtype=np.float32)
    # Instances start as invisible and are flagged once a root is computed.
    roots_visible = np.zeros((num_instances, ), dtype=np.float32)

    for i in range(num_instances):
        # collect visible keypoints
        if keypoints_visible is not None:
            visible_keypoints = keypoints[i][keypoints_visible[i] > 0]
        else:
            visible_keypoints = keypoints[i]

        # no visible keypoint: keep zero coordinates and zero visibility
        if visible_keypoints.size == 0:
            continue

        # compute the instance root with visible keypoints
        if root_type == 'kpt_center':
            roots_coordinate[i] = visible_keypoints.mean(axis=0)
        elif root_type == 'bbox_center':
            roots_coordinate[i] = (visible_keypoints.max(axis=0) +
                                   visible_keypoints.min(axis=0)) / 2.0
        else:
            raise ValueError(
                f'the value of `root_type` must be \'kpt_center\' or '
                f'\'bbox_center\', but got \'{root_type}\'')
        roots_visible[i] = 1

    return roots_coordinate, roots_visible
def get_instance_bbox(keypoints: np.ndarray,
                      keypoints_visible: Optional[np.ndarray] = None
                      ) -> np.ndarray:
    """Build a pseudo bounding box (xyxy) around each instance's visible
    keypoints.

    Args:
        keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
        keypoints_visible (np.ndarray, optional): Keypoint visibilities in
            shape (N, K); entries greater than 0 mark visible keypoints.
            When omitted, every keypoint is used

    Returns:
        np.ndarray: float32 bounding boxes in [N, 4]. The first two columns
        hold the per-axis minimum and the last two the per-axis maximum of
        the visible keypoints. Instances with no visible keypoint keep an
        all-zero box.
    """
    boxes = np.zeros((keypoints.shape[0], 4), dtype=np.float32)

    for idx, instance_kpts in enumerate(keypoints):
        # select the keypoints that contribute to the box
        if keypoints_visible is None:
            points = instance_kpts
        else:
            points = instance_kpts[keypoints_visible[idx] > 0]

        # nothing visible -> leave the zero-initialised row untouched
        if points.size != 0:
            boxes[idx, :2] = points.min(axis=0)
            boxes[idx, 2:] = points.max(axis=0)

    return boxes
def get_diagonal_lengths(keypoints: np.ndarray,
                         keypoints_visible: Optional[np.ndarray] = None
                         ) -> np.ndarray:
    """Compute the diagonal length of each instance's pseudo bounding box,
    where the box is derived from the visible keypoints.

    Args:
        keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
        keypoints_visible (np.ndarray): Keypoint visibilities in shape
            (N, K)

    Returns:
        np.ndarray: bounding box diagonal length in [N]
    """
    boxes = get_instance_bbox(keypoints, keypoints_visible)

    # Split each xyxy box into its two corners and take their offset,
    # i.e. the box extent along each axis.
    corner_min = boxes[:, :2]
    corner_max = boxes[:, 2:]
    extent = corner_max - corner_min

    # Euclidean length of the extent vector is the box diagonal.
    squared_sum = np.power(extent, 2).sum(axis=1)
    return np.sqrt(squared_sum)