# Spaces:
# Runtime error
# Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from itertools import zip_longest | |
from typing import Optional | |
from torch import Tensor | |
from mmpose.registry import MODELS | |
from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType, | |
OptMultiConfig, PixelDataList, SampleList) | |
from .base import BasePoseEstimator | |
# Register the class in the ``MODELS`` registry (imported at the top of this
# file) so it can be built from configs by name. Without this decorator the
# ``MODELS`` import is unused and the class is unreachable via the registry.
@MODELS.register_module()
class TopdownPoseEstimator(BasePoseEstimator):
    """Base class for top-down pose estimators.

    A top-down estimator runs on per-instance crops: features are extracted
    from the cropped inputs, the head predicts keypoints in crop (input)
    space, and :meth:`add_pred_to_datasample` maps them back to image space
    using each sample's ground-truth bbox center and scale.

    Args:
        backbone (dict): The backbone config
        neck (dict, optional): The neck config. Defaults to ``None``
        head (dict, optional): The head config. Defaults to ``None``
        train_cfg (dict, optional): The runtime config for training process.
            Defaults to ``None``
        test_cfg (dict, optional): The runtime config for testing process.
            Defaults to ``None``
        data_preprocessor (dict, optional): The data preprocessing config to
            build the instance of :class:`BaseDataPreprocessor`. Defaults to
            ``None``
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to ``None``
        metainfo (dict): Meta information for dataset, such as keypoints
            definition and properties. If set, the metainfo of the input data
            batch will be overridden. For more details, please refer to
            https://mmpose.readthedocs.io/en/latest/user_guides/
            prepare_datasets.html#create-a-custom-dataset-info-
            config-file-for-the-dataset. Defaults to ``None``
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 metainfo: Optional[dict] = None):
        # All construction is delegated to the base class.
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg,
            metainfo=metainfo)

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples.

        Returns:
            dict: A dictionary of losses. Empty when the model has no head.
        """
        feats = self.extract_feat(inputs)

        losses = {}
        if self.with_head:
            losses.update(
                self.head.loss(feats, data_samples, train_cfg=self.train_cfg))

        return losses

    def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W)
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples

        Returns:
            list[:obj:`PoseDataSample`]: The pose estimation results of the
            input images. The return value is `PoseDataSample` instances with
            ``pred_instances`` and ``pred_fields``(optional) field , and
            ``pred_instances`` usually contains the following keys:

                - keypoints (Tensor): predicted keypoint coordinates in shape
                    (num_instances, K, D) where K is the keypoint number and D
                    is the keypoint dimension
                - keypoint_scores (Tensor): predicted keypoint scores in shape
                    (num_instances, K)
        """
        assert self.with_head, (
            'The model must have head to perform prediction.')

        # NOTE(review): assumes ``self.test_cfg`` is a dict (never ``None``);
        # presumably the base class normalizes it -- confirm.
        if self.test_cfg.get('flip_test', False):
            # Flip-test: features of the original and the horizontally
            # flipped input (flip along the last, width, axis) are passed to
            # the head together as a two-element list.
            _feats = self.extract_feat(inputs)
            _feats_flip = self.extract_feat(inputs.flip(-1))
            feats = [_feats, _feats_flip]
        else:
            feats = self.extract_feat(inputs)

        preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg)

        # The head returns either ``(instances, fields)`` or just instances.
        if isinstance(preds, tuple):
            batch_pred_instances, batch_pred_fields = preds
        else:
            batch_pred_instances = preds
            batch_pred_fields = None

        results = self.add_pred_to_datasample(batch_pred_instances,
                                              batch_pred_fields, data_samples)

        return results

    def add_pred_to_datasample(self, batch_pred_instances: InstanceList,
                               batch_pred_fields: Optional[PixelDataList],
                               batch_data_samples: SampleList) -> SampleList:
        """Add predictions into data samples.

        Keypoints are converted from model-input space to image space, then
        optionally reduced to ``test_cfg['output_keypoint_indices']``. GT bbox
        info is copied into the predictions, and the results are attached to
        the data samples in place.

        Args:
            batch_pred_instances (List[InstanceData]): The predicted instances
                of the input data batch
            batch_pred_fields (List[PixelData], optional): The predicted
                fields (e.g. heatmaps) of the input batch
            batch_data_samples (List[PoseDataSample]): The input data batch

        Returns:
            List[PoseDataSample]: The same ``batch_data_samples`` list (mutated
            in place), where the predictions are stored in the
            ``pred_instances`` field, and ``pred_fields`` when available, of
            each data sample.
        """
        assert len(batch_pred_instances) == len(batch_data_samples)
        if batch_pred_fields is None:
            # An empty list makes ``zip_longest`` pad ``pred_fields`` with
            # ``None`` for every sample.
            batch_pred_fields = []
        output_keypoint_indices = self.test_cfg.get('output_keypoint_indices',
                                                    None)

        for pred_instances, pred_fields, data_sample in zip_longest(
                batch_pred_instances, batch_pred_fields, batch_data_samples):

            gt_instances = data_sample.gt_instances

            # Convert keypoint coordinates from input space to image space:
            # ``keypoints / input_size`` is the position as a fraction of the
            # model input. Scaling by ``bbox_scales`` and adding the box's
            # top-left corner (``bbox_centers - 0.5 * bbox_scales``) places it
            # inside the bbox in the original image.
            # NOTE(review): the broadcast presumably assumes 2-D (x, y)
            # keypoints and a 2-element (w, h) ``input_size`` -- confirm
            # before using with keypoint dimension D > 2.
            bbox_centers = gt_instances.bbox_centers
            bbox_scales = gt_instances.bbox_scales
            input_size = data_sample.metainfo['input_size']

            pred_instances.keypoints = pred_instances.keypoints / input_size \
                * bbox_scales + bbox_centers - 0.5 * bbox_scales

            if output_keypoint_indices is not None:
                # Record K *before* subsetting so heatmap channels below can
                # be matched against the head's full keypoint count.
                num_keypoints = pred_instances.keypoints.shape[1]
                # Select output keypoints (axis 1) for every ``keypoint*``
                # field, e.g. ``keypoints`` and ``keypoint_scores``.
                for key, value in pred_instances.all_items():
                    if key.startswith('keypoint'):
                        pred_instances.set_field(
                            value[:, output_keypoint_indices], key)

            # add bbox information into pred_instances
            pred_instances.bboxes = gt_instances.bboxes
            pred_instances.bbox_scores = gt_instances.bbox_scores

            data_sample.pred_instances = pred_instances

            if pred_fields is not None:
                if output_keypoint_indices is not None:
                    # Select output heatmap channels with keypoint indices,
                    # only for fields whose channel count (first axis)
                    # matches num_keypoints. Other fields are left as is.
                    for key, value in pred_fields.all_items():
                        if value.shape[0] != num_keypoints:
                            continue
                        pred_fields.set_field(value[output_keypoint_indices],
                                              key)
                data_sample.pred_fields = pred_fields

        return batch_data_samples