# Copyright (c) OpenMMLab. All rights reserved.
from itertools import zip_longest
from typing import Optional

from torch import Tensor

from mmpose.registry import MODELS
from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType,
                                 OptMultiConfig, PixelDataList, SampleList)
from .base import BasePoseEstimator

@MODELS.register_module()
class TopdownPoseEstimator(BasePoseEstimator):
    """Base class for top-down pose estimators.

    Args:
        backbone (dict): The backbone config
        neck (dict, optional): The neck config. Defaults to ``None``
        head (dict, optional): The head config. Defaults to ``None``
        train_cfg (dict, optional): The runtime config for training process.
            Defaults to ``None``
        test_cfg (dict, optional): The runtime config for testing process.
            Defaults to ``None``
        data_preprocessor (dict, optional): The data preprocessing config to
            build the instance of :class:`BaseDataPreprocessor`. Defaults to
            ``None``
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to ``None``
        metainfo (dict): Meta information for dataset, such as keypoints
            definition and properties. If set, the metainfo of the input data
            batch will be overridden. For more details, please refer to
            https://mmpose.readthedocs.io/en/latest/user_guides/
            prepare_datasets.html#create-a-custom-dataset-info-
            config-file-for-the-dataset. Defaults to ``None``
    """
    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 metainfo: Optional[dict] = None):
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg,
            metainfo=metainfo)

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples.

        Returns:
            dict: A dictionary of losses.
        """
        feats = self.extract_feat(inputs)

        losses = dict()

        if self.with_head:
            losses.update(
                self.head.loss(feats, data_samples, train_cfg=self.train_cfg))

        return losses

    def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W)
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples

        Returns:
            list[:obj:`PoseDataSample`]: The pose estimation results of the
                input images. The return value is a list of `PoseDataSample`
                instances with ``pred_instances`` and ``pred_fields``
                (optional) fields, and ``pred_instances`` usually contains
                the following keys:

                - keypoints (Tensor): predicted keypoint coordinates in shape
                    (num_instances, K, D) where K is the keypoint number and
                    D is the keypoint dimension
                - keypoint_scores (Tensor): predicted keypoint scores in shape
                    (num_instances, K)
        """
        assert self.with_head, (
            'The model must have head to perform prediction.')

        if self.test_cfg.get('flip_test', False):
            # Test-time augmentation: extract features from both the original
            # and the horizontally flipped inputs; the head is expected to
            # fuse the two during decoding.
            _feats = self.extract_feat(inputs)
            _feats_flip = self.extract_feat(inputs.flip(-1))
            feats = [_feats, _feats_flip]
        else:
            feats = self.extract_feat(inputs)

        preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg)

        if isinstance(preds, tuple):
            batch_pred_instances, batch_pred_fields = preds
        else:
            # The head only returns instance-level predictions (no fields).
            batch_pred_instances = preds
            batch_pred_fields = None

        results = self.add_pred_to_datasample(batch_pred_instances,
                                              batch_pred_fields, data_samples)

        return results
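
    # Illustrative note (not part of the original file): the optional
    # test-time behaviours used above and in ``add_pred_to_datasample``
    # below are controlled through ``test_cfg``, e.g.
    #
    #   test_cfg=dict(
    #       flip_test=True,                     # horizontal-flip TTA
    #       output_keypoint_indices=[0, 5, 6])  # keep only these keypoints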

    def add_pred_to_datasample(self, batch_pred_instances: InstanceList,
                               batch_pred_fields: Optional[PixelDataList],
                               batch_data_samples: SampleList) -> SampleList:
        """Add predictions into data samples.

        Args:
            batch_pred_instances (List[InstanceData]): The predicted instances
                of the input data batch
            batch_pred_fields (List[PixelData], optional): The predicted
                fields (e.g. heatmaps) of the input batch
            batch_data_samples (List[PoseDataSample]): The input data batch

        Returns:
            List[PoseDataSample]: A list of data samples where the predictions
                are stored in the ``pred_instances`` field of each data
                sample.
        """
        assert len(batch_pred_instances) == len(batch_data_samples)
        if batch_pred_fields is None:
            batch_pred_fields = []
        output_keypoint_indices = self.test_cfg.get('output_keypoint_indices',
                                                    None)

        # ``batch_pred_fields`` may be empty, so ``zip_longest`` pads it with
        # ``None`` to keep it aligned with the other two lists.
        for pred_instances, pred_fields, data_sample in zip_longest(
                batch_pred_instances, batch_pred_fields, batch_data_samples):

            gt_instances = data_sample.gt_instances

            # convert keypoint coordinates from input space to image space
            bbox_centers = gt_instances.bbox_centers
            bbox_scales = gt_instances.bbox_scales
            input_size = data_sample.metainfo['input_size']

            pred_instances.keypoints = pred_instances.keypoints / input_size \
                * bbox_scales + bbox_centers - 0.5 * bbox_scales

            if output_keypoint_indices is not None:
                # select output keypoints with given indices
                num_keypoints = pred_instances.keypoints.shape[1]
                for key, value in pred_instances.all_items():
                    if key.startswith('keypoint'):
                        pred_instances.set_field(
                            value[:, output_keypoint_indices], key)

            # add bbox information into pred_instances
            pred_instances.bboxes = gt_instances.bboxes
            pred_instances.bbox_scores = gt_instances.bbox_scores

            data_sample.pred_instances = pred_instances

            if pred_fields is not None:
                if output_keypoint_indices is not None:
                    # select output heatmap channels with keypoint indices
                    # when the number of heatmap channels matches num_keypoints
                    for key, value in pred_fields.all_items():
                        if value.shape[0] != num_keypoints:
                            continue
                        pred_fields.set_field(value[output_keypoint_indices],
                                              key)
                data_sample.pred_fields = pred_fields

        return batch_data_samples
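

if __name__ == '__main__':
    # Minimal, self-contained sketch (not part of the original file): it
    # reproduces the input-space -> image-space keypoint mapping used in
    # ``add_pred_to_datasample`` above with made-up numbers, assuming a
    # 192x256 (w, h) model input and a single bounding box.
    import numpy as np

    input_size = np.array([192., 256.])      # (w, h) of the model input
    bbox_centers = np.array([[320., 240.]])  # bbox center in the image
    bbox_scales = np.array([[384., 512.]])   # bbox (w, h) in the image

    # A keypoint predicted at the center of the input space ...
    kpts_input = np.array([[[96., 128.]]])   # shape (num_instances, K, 2)

    # ... maps back to the bbox center in image space.
    kpts_image = kpts_input / input_size * bbox_scales \
        + bbox_centers - 0.5 * bbox_scales
    print(kpts_image)  # [[[320. 240.]]]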