# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch.nn as nn
from mmcv.cnn import build_conv_layer, constant_init, kaiming_init
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmpose.core import (WeightNormClipHook, compute_similarity_transform,
                         fliplr_regression)
from mmpose.models.builder import HEADS, build_loss


@HEADS.register_module()
class TemporalRegressionHead(nn.Module):
    """Regression head of VideoPose3D.

    "3D human pose estimation in video with temporal convolutions and
    semi-supervised training", CVPR'2019.

    Args:
        in_channels (int): Number of input channels.
        num_joints (int): Number of joints.
        loss_keypoint (dict): Config for keypoint loss. Default: None.
        max_norm (float|None): If not None, the weight of convolution layers
            will be clipped to have a maximum norm of max_norm.
        is_trajectory (bool): If the model only predicts the root joint
            position, this arg should be set to True. In this case, traj_loss
            will be calculated. Otherwise, it should be set to False.
            Default: False.
        train_cfg (dict|None): Config for training. Default: None.
        test_cfg (dict|None): Config for testing. Default: None.
    """

    def __init__(self,
                 in_channels,
                 num_joints,
                 max_norm=None,
                 loss_keypoint=None,
                 is_trajectory=False,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()

        self.in_channels = in_channels
        self.num_joints = num_joints
        self.max_norm = max_norm
        self.loss = build_loss(loss_keypoint)
        self.is_trajectory = is_trajectory
        if self.is_trajectory:
            assert self.num_joints == 1

        self.train_cfg = {} if train_cfg is None else train_cfg
        self.test_cfg = {} if test_cfg is None else test_cfg

        # A single 1x1 Conv1d maps the temporal feature to num_joints * 3
        # coordinate values.
        self.conv = build_conv_layer(
            dict(type='Conv1d'), in_channels, num_joints * 3, 1)

        if self.max_norm is not None:
            # Apply weight norm clip to conv layers
            weight_clip = WeightNormClipHook(self.max_norm)
            for module in self.modules():
                if isinstance(module, nn.modules.conv._ConvNd):
                    weight_clip.register(module)

    @staticmethod
    def _transform_inputs(x):
        """Transform inputs for decoder.

        Args:
            x (tuple or list of Tensor | Tensor): multi-level features.

        Returns:
            Tensor: The transformed inputs.
        """
        if not isinstance(x, (list, tuple)):
            return x

        assert len(x) > 0

        # return the top-level feature of the 1D feature pyramid
        return x[-1]

    def forward(self, x):
        """Forward function."""
        x = self._transform_inputs(x)

        assert x.ndim == 3 and x.shape[2] == 1, f'Invalid shape {x.shape}'
        output = self.conv(x)
        N = output.shape[0]
        return output.reshape(N, self.num_joints, 3)

    def get_loss(self, output, target, target_weight):
        """Calculate keypoint loss.

        Note:
            - batch_size: N
            - num_keypoints: K

        Args:
            output (torch.Tensor[N, K, 3]): Output keypoints.
            target (torch.Tensor[N, K, 3]): Target keypoints.
            target_weight (torch.Tensor[N, K, 3]):
                Weights across different joint types. If self.is_trajectory
                is True and target_weight is None, target_weight will be set
                inversely proportional to joint depth.
        """
        losses = dict()
        assert not isinstance(self.loss, nn.Sequential)

        # trajectory model
        if self.is_trajectory:
            if target.dim() == 2:
                target.unsqueeze_(1)

            if target_weight is None:
                target_weight = (1 / target[:, :, 2:]).expand(target.shape)
            assert target.dim() == 3 and target_weight.dim() == 3
            losses['traj_loss'] = self.loss(output, target, target_weight)
        # pose model
        else:
            if target_weight is None:
                target_weight = target.new_ones(target.shape)
            assert target.dim() == 3 and target_weight.dim() == 3
            losses['reg_loss'] = self.loss(output, target, target_weight)

        return losses

    def get_accuracy(self, output, target, target_weight, metas):
        """Calculate accuracy for keypoint loss.

        Note:
            - batch_size: N
            - num_keypoints: K

        Args:
            output (torch.Tensor[N, K, 3]): Output keypoints.
            target (torch.Tensor[N, K, 3]): Target keypoints.
            target_weight (torch.Tensor[N, K, 3]):
                Weights across different joint types.
            metas (list(dict)): Information about data augmentation including:

                - target_image_path (str): Optional, path to the image file.
                - target_mean (float): Optional, normalization parameter of
                  the target pose.
                - target_std (float): Optional, normalization parameter of
                  the target pose.
                - root_position (np.ndarray[3,1]): Optional, global position
                  of the root joint.
                - root_position_index (int): Optional, original index of the
                  root joint before root-centering.
        """
        accuracy = dict()

        N = output.shape[0]
        output_ = output.detach().cpu().numpy()
        target_ = target.detach().cpu().numpy()

        # Denormalize the predicted pose
        if 'target_mean' in metas[0] and 'target_std' in metas[0]:
            target_mean = np.stack([m['target_mean'] for m in metas])
            target_std = np.stack([m['target_std'] for m in metas])
            output_ = self._denormalize_joints(output_, target_mean,
                                               target_std)
            target_ = self._denormalize_joints(target_, target_mean,
                                               target_std)

        # Restore global position
        if self.test_cfg.get('restore_global_position', False):
            root_pos = np.stack([m['root_position'] for m in metas])
            root_idx = metas[0].get('root_position_index', None)
            output_ = self._restore_global_position(output_, root_pos,
                                                    root_idx)
            target_ = self._restore_global_position(target_, root_pos,
                                                    root_idx)

        # Get target weight
        if target_weight is None:
            target_weight_ = np.ones_like(target_)
        else:
            target_weight_ = target_weight.detach().cpu().numpy()
            if self.test_cfg.get('restore_global_position', False):
                root_idx = metas[0].get('root_position_index', None)
                root_weight = metas[0].get('root_joint_weight', 1.0)
                target_weight_ = self._restore_root_target_weight(
                    target_weight_, root_weight, root_idx)

        # MPJPE: mean per-joint position error in the restored coordinates
        mpjpe = np.mean(
            np.linalg.norm((output_ - target_) * target_weight_, axis=-1))

        # P-MPJPE: error after rigid (Procrustes) alignment to the target
        transformed_output = np.zeros_like(output_)
        for i in range(N):
            transformed_output[i, :, :] = compute_similarity_transform(
                output_[i, :, :], target_[i, :, :])
        p_mpjpe = np.mean(
            np.linalg.norm(
                (transformed_output - target_) * target_weight_, axis=-1))

        accuracy['mpjpe'] = output.new_tensor(mpjpe)
        accuracy['p_mpjpe'] = output.new_tensor(p_mpjpe)

        return accuracy

    def inference_model(self, x, flip_pairs=None):
        """Inference function.

        Args:
            x (torch.Tensor[N, C, 1]): Input features.
            flip_pairs (None | list[tuple()]):
                Pairs of keypoints which are mirrored.

        Returns:
            output_regression (np.ndarray): Output regression.
        """
        output = self.forward(x)

        if flip_pairs is not None:
            output_regression = fliplr_regression(
                output.detach().cpu().numpy(),
                flip_pairs,
                center_mode='static',
                center_x=0)
        else:
            output_regression = output.detach().cpu().numpy()
        return output_regression

    def decode(self, metas, output):
        """Decode the keypoints from output regression.

        Args:
            metas (list(dict)): Information about data augmentation including:

                - target_image_path (str): Optional, path to the image file.
                - target_mean (float): Optional, normalization parameter of
                  the target pose.
                - target_std (float): Optional, normalization parameter of
                  the target pose.
                - root_position (np.ndarray[3,1]): Optional, global position
                  of the root joint.
                - root_position_index (int): Optional, original index of the
                  root joint before root-centering.
            output (np.ndarray[N, K, 3]): Predicted regression vector.
        """
        # Denormalize the predicted pose
        if 'target_mean' in metas[0] and 'target_std' in metas[0]:
            target_mean = np.stack([m['target_mean'] for m in metas])
            target_std = np.stack([m['target_std'] for m in metas])
            output = self._denormalize_joints(output, target_mean, target_std)

        # Restore global position
        if self.test_cfg.get('restore_global_position', False):
            root_pos = np.stack([m['root_position'] for m in metas])
            root_idx = metas[0].get('root_position_index', None)
            output = self._restore_global_position(output, root_pos, root_idx)

        target_image_paths = [m.get('target_image_path', None) for m in metas]
        result = {'preds': output, 'target_image_paths': target_image_paths}

        return result

    @staticmethod
    def _denormalize_joints(x, mean, std):
        """Denormalize joint coordinates with given statistics mean and std.

        Args:
            x (np.ndarray[N, K, 3]): Normalized joint coordinates.
            mean (np.ndarray[K, 3]): Mean value.
            std (np.ndarray[K, 3]): Std value.
        """
        assert x.ndim == 3
        assert x.shape == mean.shape == std.shape

        return x * std + mean

    @staticmethod
    def _restore_global_position(x, root_pos, root_idx=None):
        """Restore global position of the root-centered joints.

        Args:
            x (np.ndarray[N, K, 3]): Root-centered joint coordinates.
            root_pos (np.ndarray[N, 1, 3]): The global position of the
                root joint.
            root_idx (int|None): If not None, the root joint will be inserted
                back to the pose at the given index.
        """
        x = x + root_pos

        if root_idx is not None:
            x = np.insert(x, root_idx, root_pos.squeeze(1), axis=1)
        return x

    @staticmethod
    def _restore_root_target_weight(target_weight, root_weight,
                                    root_idx=None):
        """Restore the target weight of the root joint after the restoration
        of the global position.

        Args:
            target_weight (np.ndarray[N, K, 1]): Target weight of relativized
                joints.
            root_weight (float): The target weight value of the root joint.
            root_idx (int|None): If not None, the root joint weight will be
                inserted back to the target weight at the given index.
        """
        if root_idx is not None:
            root_weight = np.full(
                target_weight.shape[0], root_weight,
                dtype=target_weight.dtype)
            target_weight = np.insert(
                target_weight, root_idx, root_weight[:, None], axis=1)
        return target_weight

    def init_weights(self):
        """Initialize the weights."""
        for m in self.modules():
            if isinstance(m, nn.modules.conv._ConvNd):
                kaiming_init(m, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, _BatchNorm):
                constant_init(m, 1)
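

# The block below is a minimal usage sketch, not part of the original module.
# It assumes mmpose/mmcv are installed, that 'MPJPELoss' is a loss type
# registered in mmpose, and uses channel/joint numbers typical of the
# VideoPose3D configs (in_channels=1024, num_joints=17); adjust as needed.
if __name__ == '__main__':
    import torch

    head = TemporalRegressionHead(
        in_channels=1024,
        num_joints=17,
        loss_keypoint=dict(type='MPJPELoss'),  # assumed loss config
        test_cfg=dict(restore_global_position=False))
    head.init_weights()

    # The temporal encoder is expected to output an [N, C, 1] feature map.
    feats = torch.randn(2, 1024, 1)
    preds = head(feats)  # -> [N, K, 3] root-centered 3D keypoints
    print(preds.shape)

    # Compute the regression loss against a dummy target.
    target = torch.randn(2, 17, 3)
    losses = head.get_loss(preds, target, target_weight=None)
    print(losses['reg_loss'])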