Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import List, Optional, Sequence, Union | |
import torch | |
from mmcv.cnn import build_conv_layer, build_upsample_layer | |
from mmengine.structures import PixelData | |
from torch import Tensor, nn | |
from mmpose.evaluation.functional import pose_pck_accuracy | |
from mmpose.models.utils.tta import flip_heatmaps | |
from mmpose.registry import KEYPOINT_CODECS, MODELS | |
from mmpose.utils.tensor_utils import to_numpy | |
from mmpose.utils.typing import (Features, MultiConfig, OptConfigType, | |
OptSampleList, Predictions) | |
from ..base_head import BaseHead | |
# Alias used for optional per-layer integer settings (e.g. the deconv output
# channels and kernel sizes accepted by ``CPMHead.__init__``).
OptIntSeq = Optional[Sequence[int]]
class CPMHead(BaseHead):
    """Multi-stage heatmap head introduced in `Convolutional Pose Machines`_ by
    Wei et al (2016) and used by `Stacked Hourglass Networks`_ by Newell et al
    (2016). The head consists of multiple branches, each of which has some
    deconv layers and a simple conv2d layer.

    Args:
        in_channels (int | Sequence[int]): Number of channels in the input
            feature maps.
        out_channels (int): Number of channels in the output heatmaps.
        num_stages (int): Number of stages.
        deconv_out_channels (Sequence[int], optional): The output channel
            number of each deconv layer. If empty or ``None``, no deconv
            layers are built and each stage uses ``nn.Identity`` instead.
            Defaults to ``None``
        deconv_kernel_sizes (Sequence[int | tuple], optional): The kernel size
            of each deconv layer. Must have the same length as
            ``deconv_out_channels`` when the latter is given. Only kernel
            sizes 2, 3 and 4 are supported. Defaults to ``None``
        final_layer (dict): Arguments of the final Conv2d layer. If ``None``,
            each stage uses ``nn.Identity`` as its final layer.
            Defaults to ``dict(kernel_size=1)``
        loss (Config | List[Config]): Config of the keypoint loss of different
            stages. A list must contain exactly ``num_stages`` configs; a
            single config is shared by all stages.
            Defaults to use :class:`KeypointMSELoss`.
        decoder (Config, optional): The decoder config that controls decoding
            keypoint coordinates from the network output. Defaults to ``None``
        init_cfg (Config, optional): Config to control the initialization. See
            :attr:`default_init_cfg` for default settings

    .. _`Convolutional Pose Machines`: https://arxiv.org/abs/1602.00134
    .. _`Stacked Hourglass Networks`: https://arxiv.org/abs/1603.06937
    """

    _version = 2

    def __init__(self,
                 in_channels: Union[int, Sequence[int]],
                 out_channels: int,
                 num_stages: int,
                 deconv_out_channels: OptIntSeq = None,
                 deconv_kernel_sizes: OptIntSeq = None,
                 final_layer: dict = dict(kernel_size=1),
                 loss: MultiConfig = dict(
                     type='KeypointMSELoss', use_target_weight=True),
                 decoder: OptConfigType = None,
                 init_cfg: OptConfigType = None):

        if init_cfg is None:
            # ``default_init_cfg`` is a property, so this yields the config
            # list itself rather than a bound method.
            init_cfg = self.default_init_cfg

        super().__init__(init_cfg)

        self.num_stages = num_stages
        self.in_channels = in_channels
        self.out_channels = out_channels

        # A list of loss configs means one loss per stage; a single config
        # is shared across all stages.
        if isinstance(loss, list):
            if len(loss) != num_stages:
                raise ValueError(
                    f'The length of loss_module({len(loss)}) did not match '
                    f'`num_stages`({num_stages})')
            self.loss_module = nn.ModuleList(
                MODELS.build(_loss) for _loss in loss)
        else:
            self.loss_module = MODELS.build(loss)

        if decoder is not None:
            self.decoder = KEYPOINT_CODECS.build(decoder)
        else:
            self.decoder = None

        # build multi-stage deconv layers
        self.multi_deconv_layers = nn.ModuleList([])
        if deconv_out_channels:
            if deconv_kernel_sizes is None or len(deconv_out_channels) != len(
                    deconv_kernel_sizes):
                raise ValueError(
                    '"deconv_out_channels" and "deconv_kernel_sizes" should '
                    'be integer sequences with the same length. Got '
                    f'mismatched lengths {deconv_out_channels} and '
                    f'{deconv_kernel_sizes}')

            for _ in range(self.num_stages):
                deconv_layers = self._make_deconv_layers(
                    in_channels=in_channels,
                    layer_out_channels=deconv_out_channels,
                    layer_kernel_sizes=deconv_kernel_sizes,
                )
                self.multi_deconv_layers.append(deconv_layers)
            # Updated only after every stage has been built: each stage's
            # deconv stack takes the original ``in_channels``, while the
            # final layers below take the last deconv output width.
            in_channels = deconv_out_channels[-1]
        else:
            for _ in range(self.num_stages):
                self.multi_deconv_layers.append(nn.Identity())

        # build multi-stage final layers
        self.multi_final_layers = nn.ModuleList([])
        if final_layer is not None:
            cfg = dict(
                type='Conv2d',
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1)
            # User-supplied arguments override the defaults above. ``cfg`` is
            # a fresh dict, so the ``final_layer`` argument is not mutated.
            cfg.update(final_layer)
            for _ in range(self.num_stages):
                self.multi_final_layers.append(build_conv_layer(cfg))
        else:
            for _ in range(self.num_stages):
                self.multi_final_layers.append(nn.Identity())

    @property
    def default_init_cfg(self):
        """list[dict]: Default initialization config.

        Normal init (std=0.001) for ``Conv2d``/``ConvTranspose2d`` layers and
        constant init (val=1) for ``BatchNorm2d`` layers.
        """
        init_cfg = [
            dict(
                type='Normal', layer=['Conv2d', 'ConvTranspose2d'], std=0.001),
            dict(type='Constant', layer='BatchNorm2d', val=1)
        ]
        return init_cfg

    def _make_deconv_layers(self, in_channels: int,
                            layer_out_channels: Sequence[int],
                            layer_kernel_sizes: Sequence[int]) -> nn.Module:
        """Create deconvolutional layers by given parameters.

        Each entry of ``layer_out_channels``/``layer_kernel_sizes`` yields a
        stride-2 upsample layer followed by ``BatchNorm2d`` and ``ReLU``.

        Args:
            in_channels (int): Number of input channels of the first layer.
            layer_out_channels (Sequence[int]): Output channels of each layer.
            layer_kernel_sizes (Sequence[int]): Kernel size of each layer.
                Only 2, 3 and 4 are supported.

        Returns:
            nn.Module: An ``nn.Sequential`` holding all created layers.

        Raises:
            ValueError: If a kernel size other than 2, 3 or 4 is given.
        """
        layers = []
        for out_channels, kernel_size in zip(layer_out_channels,
                                             layer_kernel_sizes):
            # With stride 2, these (padding, output_padding) pairs give an
            # output exactly twice the input size, assuming the 'deconv'
            # upsample type builds an ``nn.ConvTranspose2d``.
            if kernel_size == 4:
                padding = 1
                output_padding = 0
            elif kernel_size == 3:
                padding = 1
                output_padding = 1
            elif kernel_size == 2:
                padding = 0
                output_padding = 0
            else:
                raise ValueError(f'Unsupported kernel size {kernel_size} for '
                                 'deconvolutional layers in '
                                 f'{self.__class__.__name__}')
            cfg = dict(
                type='deconv',
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=2,
                padding=padding,
                output_padding=output_padding,
                bias=False)
            layers.append(build_upsample_layer(cfg))
            layers.append(nn.BatchNorm2d(num_features=out_channels))
            layers.append(nn.ReLU(inplace=True))
            # The next layer consumes this layer's output.
            in_channels = out_channels

        return nn.Sequential(*layers)

    def forward(self, feats: Sequence[Tensor]) -> List[Tensor]:
        """Forward the network. The input is multi-stage feature maps and the
        output is a list of heatmaps from multiple stages.

        Args:
            feats (Sequence[Tensor]): Multi-stage feature maps. Must contain
                exactly ``num_stages`` entries.

        Returns:
            List[Tensor]: A list of output heatmaps from multiple stages.
        """
        out = []
        assert len(feats) == self.num_stages, (
            f'The length of feature maps did not match the '
            f'`num_stages` in {self.__class__.__name__}')
        for i in range(self.num_stages):
            # Stage i has its own deconv stack and its own final layer.
            y = self.multi_deconv_layers[i](feats[i])
            y = self.multi_final_layers[i](y)
            out.append(y)

        return out

    def predict(self,
                feats: Features,
                batch_data_samples: OptSampleList,
                test_cfg: OptConfigType = None) -> Predictions:
        """Predict results from multi-stage feature maps.

        Only the heatmaps of the last stage are decoded.

        Args:
            feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
                features (or multiple multi-stage features in TTA)
            batch_data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples
            test_cfg (dict, optional): The runtime config for testing
                process. ``None`` is treated as an empty config. Keys read
                here are ``flip_test``, ``flip_mode``, ``shift_heatmap`` and
                ``output_heatmaps``. Defaults to ``None``

        Returns:
            Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
            ``test_cfg['output_heatmaps']==True``, return both pose and heatmap
            prediction; otherwise only return the pose prediction.

            The pose prediction is a list of ``InstanceData``, each contains
            the following fields:

                - keypoints (np.ndarray): predicted keypoint coordinates in
                    shape (num_instances, K, D) where K is the keypoint number
                    and D is the keypoint dimension
                - keypoint_scores (np.ndarray): predicted keypoint scores in
                    shape (num_instances, K)

            The heatmap prediction is a list of ``PixelData``, each contains
            the following fields:

                - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
        """
        if test_cfg is None:
            test_cfg = {}

        if test_cfg.get('flip_test', False):
            # TTA: flip test
            assert isinstance(feats, list) and len(feats) == 2
            flip_indices = batch_data_samples[0].metainfo['flip_indices']
            _feats, _feats_flip = feats
            _batch_heatmaps = self.forward(_feats)[-1]
            _batch_heatmaps_flip = flip_heatmaps(
                self.forward(_feats_flip)[-1],
                flip_mode=test_cfg.get('flip_mode', 'heatmap'),
                flip_indices=flip_indices,
                shift_heatmap=test_cfg.get('shift_heatmap', False))
            # Average the original prediction and the flipped-back one.
            batch_heatmaps = (_batch_heatmaps + _batch_heatmaps_flip) * 0.5
        else:
            multi_stage_heatmaps = self.forward(feats)
            batch_heatmaps = multi_stage_heatmaps[-1]

        # ``decode`` is not defined in this class; presumably inherited from
        # ``BaseHead`` and driven by ``self.decoder`` -- confirm there.
        preds = self.decode(batch_heatmaps)

        if test_cfg.get('output_heatmaps', False):
            pred_fields = [
                PixelData(heatmaps=hm) for hm in batch_heatmaps.detach()
            ]
            return preds, pred_fields
        else:
            return preds

    def loss(self,
             feats: Sequence[Tensor],
             batch_data_samples: OptSampleList,
             train_cfg: OptConfigType = None) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            feats (Sequence[Tensor]): Multi-stage feature maps.
            batch_data_samples (List[:obj:`PoseDataSample`]): The Data
                Samples. Each sample must provide ``gt_fields.heatmaps`` and
                ``gt_instance_labels.keypoint_weights``.
            train_cfg (Config, optional): The training config. Currently not
                read by this method. Defaults to ``None``

        Returns:
            dict: A dictionary of loss components, with ``loss_kpt`` (the sum
            of the per-stage keypoint losses) and ``acc_pose`` (the accuracy
            computed from the last stage's heatmaps).
        """
        multi_stage_pred_heatmaps = self.forward(feats)

        gt_heatmaps = torch.stack(
            [d.gt_fields.heatmaps for d in batch_data_samples])
        keypoint_weights = torch.cat([
            d.gt_instance_labels.keypoint_weights for d in batch_data_samples
        ])

        # Decided once: either one loss module per stage, or one shared.
        per_stage_loss = isinstance(self.loss_module, nn.ModuleList)

        # calculate losses over multiple stages
        losses = dict()
        for i, pred_heatmaps in enumerate(multi_stage_pred_heatmaps):
            loss_func = self.loss_module[i] if per_stage_loss \
                else self.loss_module

            # the `gt_heatmaps` and `keypoint_weights` used to calculate loss
            # for different stages are the same
            loss_i = loss_func(pred_heatmaps, gt_heatmaps, keypoint_weights)

            if 'loss_kpt' not in losses:
                losses['loss_kpt'] = loss_i
            else:
                # Out-of-place add so the first stage's loss tensor is not
                # modified in place.
                losses['loss_kpt'] = losses['loss_kpt'] + loss_i

        # calculate accuracy
        _, avg_acc, _ = pose_pck_accuracy(
            output=to_numpy(multi_stage_pred_heatmaps[-1]),
            target=to_numpy(gt_heatmaps),
            mask=to_numpy(keypoint_weights) > 0)

        acc_pose = torch.tensor(avg_acc, device=gt_heatmaps.device)
        losses.update(acc_pose=acc_pose)

        return losses