# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer
from mmengine.structures import PixelData
from torch import Tensor, nn
from mmpose.evaluation.functional import keypoint_pck_accuracy
from mmpose.models.utils.tta import flip_coordinates, flip_heatmaps
from mmpose.registry import KEYPOINT_CODECS, MODELS
from mmpose.utils.tensor_utils import to_numpy
from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList,
Predictions)
from .. import HeatmapHead
from ..base_head import BaseHead
OptIntSeq = Optional[Sequence[int]]
@MODELS.register_module()
class IntegralRegressionHead(BaseHead):
"""Top-down integral regression head introduced in `IPR`_ by Xiao et
al(2018). The head contains a differentiable spatial to numerical transform
(DSNT) layer that do soft-argmax operation on the predicted heatmaps to
regress the coordinates.
This head is used for algorithms that only supervise the coordinates.
Args:
in_channels (int | sequence[int]): Number of input channels
        in_featuremap_size (tuple[int, int]): Size of the input feature map
num_joints (int): Number of joints
        debias (bool): Whether to remove the bias of Integral Pose
            Regression. See `Debias`_ by Gu et al. (2021). Defaults to
            ``False``.
beta (float): A smoothing parameter in softmax. Defaults to ``1.0``.
deconv_out_channels (sequence[int]): The output channel number of each
deconv layer. Defaults to ``(256, 256, 256)``
deconv_kernel_sizes (sequence[int | tuple], optional): The kernel size
of each deconv layer. Each element should be either an integer for
both height and width dimensions, or a tuple of two integers for
            the height and the width dimension, respectively. Defaults to
            ``(4, 4, 4)``
conv_out_channels (sequence[int], optional): The output channel number
of each intermediate conv layer. ``None`` means no intermediate
conv layer between deconv layers and the final conv layer.
Defaults to ``None``
conv_kernel_sizes (sequence[int | tuple], optional): The kernel size
of each intermediate conv layer. Defaults to ``None``
final_layer (dict): Arguments of the final Conv2d layer.
Defaults to ``dict(kernel_size=1)``
loss (Config): Config for keypoint loss. Defaults to use
:class:`SmoothL1Loss`
decoder (Config, optional): The decoder config that controls decoding
keypoint coordinates from the network output. Defaults to ``None``
init_cfg (Config, optional): Config to control the initialization. See
:attr:`default_init_cfg` for default settings
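
    Example:
        A minimal usage sketch; the backbone channel number, feature map
        size and keypoint number below are illustrative only::

            >>> import torch
            >>> head = IntegralRegressionHead(
            ...     in_channels=32,
            ...     in_featuremap_size=(8, 8),
            ...     num_joints=17)
            >>> feats = (torch.rand(1, 32, 8, 8),)
            >>> coords, heatmaps = head.forward(feats)
            >>> coords.shape    # (B, K, 2), normalized to [0, 1)
            torch.Size([1, 17, 2])
            >>> heatmaps.shape  # 8 * 2**3 = 64 with the default deconv stack
            torch.Size([1, 17, 64, 64])
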
.. _`IPR`: https://arxiv.org/abs/1711.08229
.. _`Debias`:
"""
_version = 2
def __init__(self,
in_channels: Union[int, Sequence[int]],
in_featuremap_size: Tuple[int, int],
num_joints: int,
debias: bool = False,
beta: float = 1.0,
deconv_out_channels: OptIntSeq = (256, 256, 256),
deconv_kernel_sizes: OptIntSeq = (4, 4, 4),
conv_out_channels: OptIntSeq = None,
conv_kernel_sizes: OptIntSeq = None,
final_layer: dict = dict(kernel_size=1),
loss: ConfigType = dict(
type='SmoothL1Loss', use_target_weight=True),
decoder: OptConfigType = None,
init_cfg: OptConfigType = None):
if init_cfg is None:
init_cfg = self.default_init_cfg
super().__init__(init_cfg)
self.in_channels = in_channels
self.num_joints = num_joints
self.debias = debias
self.beta = beta
self.loss_module = MODELS.build(loss)
if decoder is not None:
self.decoder = KEYPOINT_CODECS.build(decoder)
else:
self.decoder = None
num_deconv = len(deconv_out_channels) if deconv_out_channels else 0
if num_deconv != 0:
self.heatmap_size = tuple(
[s * (2**num_deconv) for s in in_featuremap_size])
# deconv layers + 1x1 conv
self.simplebaseline_head = HeatmapHead(
in_channels=in_channels,
out_channels=num_joints,
deconv_out_channels=deconv_out_channels,
deconv_kernel_sizes=deconv_kernel_sizes,
conv_out_channels=conv_out_channels,
conv_kernel_sizes=conv_kernel_sizes,
final_layer=final_layer)
if final_layer is not None:
in_channels = num_joints
else:
in_channels = deconv_out_channels[-1]
else:
self.simplebaseline_head = None
if final_layer is not None:
cfg = dict(
type='Conv2d',
in_channels=in_channels,
out_channels=num_joints,
kernel_size=1)
cfg.update(final_layer)
self.final_layer = build_conv_layer(cfg)
else:
self.final_layer = None
self.heatmap_size = in_featuremap_size
if isinstance(in_channels, list):
raise ValueError(
f'{self.__class__.__name__} does not support selecting '
'multiple input features.')
W, H = self.heatmap_size
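        # 1D coordinate grids normalized to [0, 1), broadcast against the
        # heatmaps in `_linear_expectation` to compute the soft-argmax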
self.linspace_x = torch.arange(0.0, 1.0 * W, 1).reshape(1, 1, 1, W) / W
self.linspace_y = torch.arange(0.0, 1.0 * H, 1).reshape(1, 1, H, 1) / H
self.linspace_x = nn.Parameter(self.linspace_x, requires_grad=False)
self.linspace_y = nn.Parameter(self.linspace_y, requires_grad=False)
self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
def _linear_expectation(self, heatmaps: Tensor,
linspace: Tensor) -> Tensor:
"""Calculate linear expectation."""
B, N, _, _ = heatmaps.shape
heatmaps = heatmaps.mul(linspace).reshape(B, N, -1)
expectation = torch.sum(heatmaps, dim=2, keepdim=True)
return expectation
def _flat_softmax(self, featmaps: Tensor) -> Tensor:
"""Use Softmax to normalize the featmaps in depthwise."""
_, N, H, W = featmaps.shape
featmaps = featmaps.reshape(-1, N, H * W)
heatmaps = F.softmax(featmaps, dim=2)
return heatmaps.reshape(-1, N, H, W)
    def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
        """Forward the network. The input is the multi-scale feature maps
        and the output is the regressed coordinates and the normalized
        heatmaps.

        Args:
            feats (Tuple[Tensor]): Multi-scale feature maps.

        Returns:
            tuple:
            - coords (Tensor): Predicted coordinates in shape (B, K, 2),
              normalized to [0, 1) by the heatmap size.
            - heatmaps (Tensor): Normalized heatmaps in shape (B, K, H, W).
        """
if self.simplebaseline_head is None:
feats = feats[-1]
if self.final_layer is not None:
feats = self.final_layer(feats)
else:
feats = self.simplebaseline_head(feats)
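        # `beta` acts as an inverse softmax temperature: larger values yield
        # sharper normalized heatmaps before taking the expectation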
heatmaps = self._flat_softmax(feats * self.beta)
pred_x = self._linear_expectation(heatmaps, self.linspace_x)
pred_y = self._linear_expectation(heatmaps, self.linspace_y)
if self.debias:
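            # debias correction from Gu et al. (2021): `C` is the softmax
            # partition function of each heatmap; the affine rescaling
            # compensates for the shift that the softmax floor introduces
            # into the expectation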
B, N, H, W = feats.shape
C = feats.reshape(B, N, H * W).exp().sum(dim=2).reshape(B, N, 1)
pred_x = C / (C - 1) * (pred_x - 1 / (2 * C))
pred_y = C / (C - 1) * (pred_y - 1 / (2 * C))
coords = torch.cat([pred_x, pred_y], dim=-1)
return coords, heatmaps
def predict(self,
feats: Tuple[Tensor],
batch_data_samples: OptSampleList,
test_cfg: ConfigType = {}) -> Predictions:
"""Predict results from features.
Args:
feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
features (or multiple multi-stage features in TTA)
batch_data_samples (List[:obj:`PoseDataSample`]): The batch
data samples
test_cfg (dict): The runtime config for testing process. Defaults
to {}
Returns:
            Union[InstanceList, Tuple[InstanceList, PixelDataList]]: If
            ``test_cfg['output_heatmaps']==True``, return both pose and
            heatmap predictions; otherwise only return the pose prediction.
The pose prediction is a list of ``InstanceData``, each contains
the following fields:
- keypoints (np.ndarray): predicted keypoint coordinates in
shape (num_instances, K, D) where K is the keypoint number
and D is the keypoint dimension
- keypoint_scores (np.ndarray): predicted keypoint scores in
shape (num_instances, K)
The heatmap prediction is a list of ``PixelData``, each contains
the following fields:
- heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
"""
if test_cfg.get('flip_test', False):
# TTA: flip test -> feats = [orig, flipped]
assert isinstance(feats, list) and len(feats) == 2
flip_indices = batch_data_samples[0].metainfo['flip_indices']
input_size = batch_data_samples[0].metainfo['input_size']
_feats, _feats_flip = feats
_batch_coords, _batch_heatmaps = self.forward(_feats)
_batch_coords_flip, _batch_heatmaps_flip = self.forward(
_feats_flip)
_batch_coords_flip = flip_coordinates(
_batch_coords_flip,
flip_indices=flip_indices,
shift_coords=test_cfg.get('shift_coords', True),
input_size=input_size)
_batch_heatmaps_flip = flip_heatmaps(
_batch_heatmaps_flip,
flip_mode='heatmap',
flip_indices=flip_indices,
shift_heatmap=test_cfg.get('shift_heatmap', False))
batch_coords = (_batch_coords + _batch_coords_flip) * 0.5
batch_heatmaps = (_batch_heatmaps + _batch_heatmaps_flip) * 0.5
else:
batch_coords, batch_heatmaps = self.forward(feats) # (B, K, D)
batch_coords.unsqueeze_(dim=1) # (B, N, K, D)
preds = self.decode(batch_coords)
if test_cfg.get('output_heatmaps', False):
pred_fields = [
PixelData(heatmaps=hm) for hm in batch_heatmaps.detach()
]
return preds, pred_fields
else:
return preds
def loss(self,
inputs: Tuple[Tensor],
batch_data_samples: OptSampleList,
train_cfg: ConfigType = {}) -> dict:
"""Calculate losses from a batch of inputs and data samples."""
pred_coords, _ = self.forward(inputs)
keypoint_labels = torch.cat(
[d.gt_instance_labels.keypoint_labels for d in batch_data_samples])
keypoint_weights = torch.cat([
d.gt_instance_labels.keypoint_weights for d in batch_data_samples
])
# calculate losses
losses = dict()
# TODO: multi-loss calculation
loss = self.loss_module(pred_coords, keypoint_labels, keypoint_weights)
losses.update(loss_kpt=loss)
# calculate accuracy
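        # PCK is measured in the normalized coordinate space, so a unit
        # normalization factor is used together with a 0.05 threshold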
_, avg_acc, _ = keypoint_pck_accuracy(
pred=to_numpy(pred_coords),
gt=to_numpy(keypoint_labels),
mask=to_numpy(keypoint_weights) > 0,
thr=0.05,
norm_factor=np.ones((pred_coords.size(0), 2), dtype=np.float32))
acc_pose = torch.tensor(avg_acc, device=keypoint_labels.device)
losses.update(acc_pose=acc_pose)
return losses
@property
def default_init_cfg(self):
init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)]
return init_cfg
def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
**kwargs):
"""A hook function to load weights of deconv layers from
:class:`HeatmapHead` into `simplebaseline_head`.
The hook will be automatically registered during initialization.
"""
# convert old-version state dict
keys = list(state_dict.keys())
for _k in keys:
if not _k.startswith(prefix):
continue
v = state_dict.pop(_k)
            # strip the prefix (str.lstrip strips characters, not a prefix)
            k = _k[len(prefix):]
k_new = _k
k_parts = k.split('.')
if self.simplebaseline_head is not None:
                if k_parts[0] == 'deconv_layers':
                    # old flat deconv weights now live in simplebaseline_head
                    k_new = (
                        prefix + 'simplebaseline_head.deconv_layers.' +
                        '.'.join(k_parts[1:]))
elif k_parts[0] == 'final_layer':
k_new = prefix + 'simplebaseline_head.' + k
state_dict[k_new] = v