# Copyright (c) OpenMMLab. All rights reserved.
from itertools import zip_longest
from typing import List, Optional, Union

from mmengine.utils import is_list_of
from torch import Tensor

from mmpose.registry import MODELS
from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType,
                                 OptMultiConfig, PixelDataList, SampleList)
from .base import BasePoseEstimator


@MODELS.register_module()
class BottomupPoseEstimator(BasePoseEstimator):
"""Base class for bottom-up pose estimators.
Args:
backbone (dict): The backbone config
neck (dict, optional): The neck config. Defaults to ``None``
head (dict, optional): The head config. Defaults to ``None``
train_cfg (dict, optional): The runtime config for training process.
Defaults to ``None``
test_cfg (dict, optional): The runtime config for testing process.
Defaults to ``None``
data_preprocessor (dict, optional): The data preprocessing config to
build the instance of :class:`BaseDataPreprocessor`. Defaults to
``None``.
init_cfg (dict, optional): The config to control the initialization.
Defaults to ``None``
"""
def __init__(self,
backbone: ConfigType,
neck: OptConfigType = None,
head: OptConfigType = None,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
data_preprocessor: OptConfigType = None,
init_cfg: OptMultiConfig = None):
super().__init__(
backbone=backbone,
neck=neck,
head=head,
train_cfg=train_cfg,
test_cfg=test_cfg,
data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
inputs (Tensor): Inputs with shape (N, C, H, W).
data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples.

        Returns:
dict: A dictionary of losses.
"""
feats = self.extract_feat(inputs)
losses = dict()
if self.with_head:
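            # Delegate loss computation to the head, which consumes the
            # backbone (and optional neck) features together with the
            # ground-truth annotations carried by ``data_samples``.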
losses.update(
self.head.loss(feats, data_samples, train_cfg=self.train_cfg))
        return losses

    def predict(self, inputs: Union[Tensor, List[Tensor]],
data_samples: SampleList) -> SampleList:
"""Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
inputs (Tensor | List[Tensor]): Input image in tensor or image
pyramid as a list of tensors. Each tensor is in shape
[B, C, H, W]
data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples

        Returns:
list[:obj:`PoseDataSample`]: The pose estimation results of the
input images. The return value is `PoseDataSample` instances with
            ``pred_instances`` and ``pred_fields`` (optional) fields, and
            ``pred_instances`` usually contains the following keys:

                - keypoints (Tensor): predicted keypoint coordinates in shape
(num_instances, K, D) where K is the keypoint number and D
is the keypoint dimension
- keypoint_scores (Tensor): predicted keypoint scores in shape
(num_instances, K)
"""
assert self.with_head, (
            'The model must have head to perform prediction.')

        multiscale_test = self.test_cfg.get('multiscale_test', False)
        flip_test = self.test_cfg.get('flip_test', False)

        # enable multi-scale test
aug_scales = data_samples[0].metainfo.get('aug_scales', None)
if multiscale_test:
assert isinstance(aug_scales, list)
assert is_list_of(inputs, Tensor)
# `inputs` includes images in original and augmented scales
assert len(inputs) == len(aug_scales) + 1
else:
assert isinstance(inputs, Tensor)
# single-scale test
            inputs = [inputs]

        feats = []
for _inputs in inputs:
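            # Flip test: features are extracted from both the original and
            # the horizontally flipped images, and passed to the head as a
            # pair so it can merge the two sets of predictions when decoding.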
if flip_test:
_feats_orig = self.extract_feat(_inputs)
_feats_flip = self.extract_feat(_inputs.flip(-1))
_feats = [_feats_orig, _feats_flip]
else:
_feats = self.extract_feat(_inputs)
feats.append(_feats)
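
        # For single-scale test, unwrap the per-scale list so the head
        # receives features in the same form as in single-image inference.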
if not multiscale_test:
            feats = feats[0]

        preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg)
if isinstance(preds, tuple):
batch_pred_instances, batch_pred_fields = preds
else:
batch_pred_instances = preds
            batch_pred_fields = None

        results = self.add_pred_to_datasample(batch_pred_instances,
batch_pred_fields, data_samples)
        return results

    def add_pred_to_datasample(self, batch_pred_instances: InstanceList,
batch_pred_fields: Optional[PixelDataList],
batch_data_samples: SampleList) -> SampleList:
"""Add predictions into data samples.
Args:
batch_pred_instances (List[InstanceData]): The predicted instances
of the input data batch
batch_pred_fields (List[PixelData], optional): The predicted
fields (e.g. heatmaps) of the input batch
            batch_data_samples (List[PoseDataSample]): The input data batch

        Returns:
            List[PoseDataSample]: A list of data samples where the predictions
            are stored in the ``pred_instances`` field of each data sample.
            The length of the list equals the batch size.
"""
        assert len(batch_pred_instances) == len(batch_data_samples)

        if batch_pred_fields is None:
batch_pred_fields = []
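
        # When no dense fields were predicted, the empty list is padded
        # with ``None`` entries by ``zip_longest`` below.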
for pred_instances, pred_fields, data_sample in zip_longest(
                batch_pred_instances, batch_pred_fields, batch_data_samples):

            # convert keypoint coordinates from input space to image space
input_size = data_sample.metainfo['input_size']
input_center = data_sample.metainfo['input_center']
input_scale = data_sample.metainfo['input_scale']
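            # Normalize coordinates by the model input size, rescale to the
            # cropped region (input_scale), then translate so the region is
            # centered at input_center.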
pred_instances.keypoints = pred_instances.keypoints / input_size \
                * input_scale + input_center - 0.5 * input_scale

            data_sample.pred_instances = pred_instances

            if pred_fields is not None:
                data_sample.pred_fields = pred_fields

        return batch_data_samples
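

if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Illustrative usage sketch, NOT part of the original module: build a
    # headless estimator from a minimal placeholder config and run a dummy
    # feature extraction. Complete, validated model definitions live in
    # the official MMPose config files.
    # ------------------------------------------------------------------
    import torch

    from mmpose.utils import register_all_modules

    register_all_modules()  # populate the MODELS registry

    cfg = dict(
        type='BottomupPoseEstimator',
        backbone=dict(type='ResNet', depth=18),  # placeholder backbone
        test_cfg=dict(flip_test=False, multiscale_test=False))
    model = MODELS.build(cfg)

    # A dummy batch of two 256x256 RGB images; with no head configured,
    # only feature extraction is exercised here.
    feats = model.extract_feat(torch.randn(2, 3, 256, 256))
    print([tuple(f.shape) for f in feats])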