Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
from typing import Optional, Sequence, Tuple, Union | |
import torch | |
from mmengine.dist import get_dist_info | |
from mmengine.structures import PixelData | |
from torch import Tensor, nn | |
from mmpose.codecs.utils import get_simcc_normalized | |
from mmpose.evaluation.functional import simcc_pck_accuracy | |
from mmpose.models.utils.rtmcc_block import RTMCCBlock, ScaleNorm | |
from mmpose.models.utils.tta import flip_vectors | |
from mmpose.registry import KEYPOINT_CODECS, MODELS | |
from mmpose.utils.tensor_utils import to_numpy | |
from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType, | |
OptSampleList) | |
from ..base_head import BaseHead | |
OptIntSeq = Optional[Sequence[int]] | |
class RTMCCHead(BaseHead): | |
"""Top-down head introduced in RTMPose (2023). The head is composed of a | |
large-kernel convolutional layer, a fully-connected layer and a Gated | |
Attention Unit to generate 1d representation from low-resolution feature | |
maps. | |
Args: | |
in_channels (int | sequence[int]): Number of channels in the input | |
feature map. | |
out_channels (int): Number of channels in the output heatmap. | |
input_size (tuple): Size of input image in shape [w, h]. | |
in_featuremap_size (int | sequence[int]): Size of input feature map. | |
simcc_split_ratio (float): Split ratio of pixels. | |
Default: 2.0. | |
final_layer_kernel_size (int): Kernel size of the convolutional layer. | |
Default: 1. | |
gau_cfg (Config): Config dict for the Gated Attention Unit. | |
Default: dict( | |
hidden_dims=256, | |
s=128, | |
expansion_factor=2, | |
dropout_rate=0., | |
drop_path=0., | |
act_fn='ReLU', | |
use_rel_bias=False, | |
pos_enc=False). | |
loss (Config): Config of the keypoint loss. Defaults to use | |
:class:`KLDiscretLoss` | |
decoder (Config, optional): The decoder config that controls decoding | |
keypoint coordinates from the network output. Defaults to ``None`` | |
init_cfg (Config, optional): Config to control the initialization. See | |
:attr:`default_init_cfg` for default settings | |
""" | |
def __init__( | |
self, | |
in_channels: Union[int, Sequence[int]], | |
out_channels: int, | |
input_size: Tuple[int, int], | |
in_featuremap_size: Tuple[int, int], | |
simcc_split_ratio: float = 2.0, | |
final_layer_kernel_size: int = 1, | |
gau_cfg: ConfigType = dict( | |
hidden_dims=256, | |
s=128, | |
expansion_factor=2, | |
dropout_rate=0., | |
drop_path=0., | |
act_fn='ReLU', | |
use_rel_bias=False, | |
pos_enc=False), | |
loss: ConfigType = dict(type='KLDiscretLoss', use_target_weight=True), | |
decoder: OptConfigType = None, | |
init_cfg: OptConfigType = None, | |
): | |
if init_cfg is None: | |
init_cfg = self.default_init_cfg | |
super().__init__(init_cfg) | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.input_size = input_size | |
self.in_featuremap_size = in_featuremap_size | |
self.simcc_split_ratio = simcc_split_ratio | |
self.loss_module = MODELS.build(loss) | |
if decoder is not None: | |
self.decoder = KEYPOINT_CODECS.build(decoder) | |
else: | |
self.decoder = None | |
if isinstance(in_channels, (tuple, list)): | |
raise ValueError( | |
f'{self.__class__.__name__} does not support selecting ' | |
'multiple input features.') | |
# Define SimCC layers | |
flatten_dims = self.in_featuremap_size[0] * self.in_featuremap_size[1] | |
self.final_layer = nn.Conv2d( | |
in_channels, | |
out_channels, | |
kernel_size=final_layer_kernel_size, | |
stride=1, | |
padding=final_layer_kernel_size // 2) | |
self.mlp = nn.Sequential( | |
ScaleNorm(flatten_dims), | |
nn.Linear(flatten_dims, gau_cfg['hidden_dims'], bias=False)) | |
W = int(self.input_size[0] * self.simcc_split_ratio) | |
H = int(self.input_size[1] * self.simcc_split_ratio) | |
self.gau = RTMCCBlock( | |
self.out_channels, | |
gau_cfg['hidden_dims'], | |
gau_cfg['hidden_dims'], | |
s=gau_cfg['s'], | |
expansion_factor=gau_cfg['expansion_factor'], | |
dropout_rate=gau_cfg['dropout_rate'], | |
drop_path=gau_cfg['drop_path'], | |
attn_type='self-attn', | |
act_fn=gau_cfg['act_fn'], | |
use_rel_bias=gau_cfg['use_rel_bias'], | |
pos_enc=gau_cfg['pos_enc']) | |
self.cls_x = nn.Linear(gau_cfg['hidden_dims'], W, bias=False) | |
self.cls_y = nn.Linear(gau_cfg['hidden_dims'], H, bias=False) | |
def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor, Tensor]: | |
"""Forward the network. | |
The input is multi scale feature maps and the | |
output is the heatmap. | |
Args: | |
feats (Tuple[Tensor]): Multi scale feature maps. | |
Returns: | |
pred_x (Tensor): 1d representation of x. | |
pred_y (Tensor): 1d representation of y. | |
""" | |
feats = feats[-1] | |
feats = self.final_layer(feats) # -> B, K, H, W | |
# flatten the output heatmap | |
feats = torch.flatten(feats, 2) | |
feats = self.mlp(feats) # -> B, K, hidden | |
feats = self.gau(feats) | |
pred_x = self.cls_x(feats) | |
pred_y = self.cls_y(feats) | |
return pred_x, pred_y | |
def predict( | |
self, | |
feats: Tuple[Tensor], | |
batch_data_samples: OptSampleList, | |
test_cfg: OptConfigType = {}, | |
) -> InstanceList: | |
"""Predict results from features. | |
Args: | |
feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage | |
features (or multiple multi-stage features in TTA) | |
batch_data_samples (List[:obj:`PoseDataSample`]): The batch | |
data samples | |
test_cfg (dict): The runtime config for testing process. Defaults | |
to {} | |
Returns: | |
List[InstanceData]: The pose predictions, each contains | |
the following fields: | |
- keypoints (np.ndarray): predicted keypoint coordinates in | |
shape (num_instances, K, D) where K is the keypoint number | |
and D is the keypoint dimension | |
- keypoint_scores (np.ndarray): predicted keypoint scores in | |
shape (num_instances, K) | |
- keypoint_x_labels (np.ndarray, optional): The predicted 1-D | |
intensity distribution in the x direction | |
- keypoint_y_labels (np.ndarray, optional): The predicted 1-D | |
intensity distribution in the y direction | |
""" | |
if test_cfg.get('flip_test', False): | |
# TTA: flip test -> feats = [orig, flipped] | |
assert isinstance(feats, list) and len(feats) == 2 | |
flip_indices = batch_data_samples[0].metainfo['flip_indices'] | |
_feats, _feats_flip = feats | |
_batch_pred_x, _batch_pred_y = self.forward(_feats) | |
_batch_pred_x_flip, _batch_pred_y_flip = self.forward(_feats_flip) | |
_batch_pred_x_flip, _batch_pred_y_flip = flip_vectors( | |
_batch_pred_x_flip, | |
_batch_pred_y_flip, | |
flip_indices=flip_indices) | |
batch_pred_x = (_batch_pred_x + _batch_pred_x_flip) * 0.5 | |
batch_pred_y = (_batch_pred_y + _batch_pred_y_flip) * 0.5 | |
else: | |
batch_pred_x, batch_pred_y = self.forward(feats) | |
preds = self.decode((batch_pred_x, batch_pred_y)) | |
if test_cfg.get('output_heatmaps', False): | |
rank, _ = get_dist_info() | |
if rank == 0: | |
warnings.warn('The predicted simcc values are normalized for ' | |
'visualization. This may cause discrepancy ' | |
'between the keypoint scores and the 1D heatmaps' | |
'.') | |
# normalize the predicted 1d distribution | |
batch_pred_x = get_simcc_normalized(batch_pred_x) | |
batch_pred_y = get_simcc_normalized(batch_pred_y) | |
B, K, _ = batch_pred_x.shape | |
# B, K, Wx -> B, K, Wx, 1 | |
x = batch_pred_x.reshape(B, K, 1, -1) | |
# B, K, Wy -> B, K, 1, Wy | |
y = batch_pred_y.reshape(B, K, -1, 1) | |
# B, K, Wx, Wy | |
batch_heatmaps = torch.matmul(y, x) | |
pred_fields = [ | |
PixelData(heatmaps=hm) for hm in batch_heatmaps.detach() | |
] | |
for pred_instances, pred_x, pred_y in zip(preds, | |
to_numpy(batch_pred_x), | |
to_numpy(batch_pred_y)): | |
pred_instances.keypoint_x_labels = pred_x[None] | |
pred_instances.keypoint_y_labels = pred_y[None] | |
return preds, pred_fields | |
else: | |
return preds | |
def loss( | |
self, | |
feats: Tuple[Tensor], | |
batch_data_samples: OptSampleList, | |
train_cfg: OptConfigType = {}, | |
) -> dict: | |
"""Calculate losses from a batch of inputs and data samples.""" | |
pred_x, pred_y = self.forward(feats) | |
gt_x = torch.cat([ | |
d.gt_instance_labels.keypoint_x_labels for d in batch_data_samples | |
], | |
dim=0) | |
gt_y = torch.cat([ | |
d.gt_instance_labels.keypoint_y_labels for d in batch_data_samples | |
], | |
dim=0) | |
keypoint_weights = torch.cat( | |
[ | |
d.gt_instance_labels.keypoint_weights | |
for d in batch_data_samples | |
], | |
dim=0, | |
) | |
pred_simcc = (pred_x, pred_y) | |
gt_simcc = (gt_x, gt_y) | |
# calculate losses | |
losses = dict() | |
loss = self.loss_module(pred_simcc, gt_simcc, keypoint_weights) | |
losses.update(loss_kpt=loss) | |
# calculate accuracy | |
_, avg_acc, _ = simcc_pck_accuracy( | |
output=to_numpy(pred_simcc), | |
target=to_numpy(gt_simcc), | |
simcc_split_ratio=self.simcc_split_ratio, | |
mask=to_numpy(keypoint_weights) > 0, | |
) | |
acc_pose = torch.tensor(avg_acc, device=gt_x.device) | |
losses.update(acc_pose=acc_pose) | |
return losses | |
def default_init_cfg(self): | |
init_cfg = [ | |
dict(type='Normal', layer=['Conv2d'], std=0.001), | |
dict(type='Constant', layer='BatchNorm2d', val=1), | |
dict(type='Normal', layer=['Linear'], std=0.01, bias=0), | |
] | |
return init_cfg | |