HaMeR / mmpose /models /heads /interhand_3d_head.py
geopavlakos's picture
Initial commit
d7a991a
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
constant_init, normal_init)
from mmpose.core.evaluation.top_down_eval import (
keypoints_from_heatmaps3d, multilabel_classification_accuracy)
from mmpose.core.post_processing import flip_back
from mmpose.models.builder import build_loss
from mmpose.models.necks import GlobalAveragePooling
from ..builder import HEADS
class Heatmap3DHead(nn.Module):
    """Sub-module of Interhand3DHead that predicts 3D heatmaps.

    The head runs the input through zero or more deconv blocks
    (deconv + BatchNorm + ReLU) and a final layer, then splits the channel
    axis of the 2D result into ``(channels // depth_size, depth_size)``.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels. Must be divisible
            by ``depth_size``.
        depth_size (int): Number of depth discretization size
        num_deconv_layers (int): Number of deconv layers.
            num_deconv_layers should >= 0. Note that 0 means no deconv layers.
        num_deconv_filters (list|tuple): Number of filters, one per deconv
            layer.
        num_deconv_kernels (list|tuple): Kernel sizes, one per deconv layer.
            Only 2, 3 and 4 are supported.
        extra (dict): Configs for extra conv layers. Default: None.
            Keys read here: ``final_conv_kernel`` (0, 1 or 3; 0 replaces the
            final layer with an identity mapping), ``num_conv_layers`` and
            ``num_conv_kernels``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 depth_size=64,
                 num_deconv_layers=3,
                 num_deconv_filters=(256, 256, 256),
                 num_deconv_kernels=(4, 4, 4),
                 extra=None):
        super().__init__()

        assert out_channels % depth_size == 0

        self.depth_size = depth_size
        self.in_channels = in_channels

        if not (extra is None or isinstance(extra, dict)):
            raise TypeError('extra should be dict or None.')

        # ---- upsampling stage ----
        if num_deconv_layers > 0:
            self.deconv_layers = self._make_deconv_layer(
                num_deconv_layers, num_deconv_filters, num_deconv_kernels)
        elif num_deconv_layers == 0:
            self.deconv_layers = nn.Identity()
        else:
            raise ValueError(
                f'num_deconv_layers ({num_deconv_layers}) should >= 0.')

        # ---- settings of the last conv (1x1, no padding, by default) ----
        final_kernel, final_padding = 1, 0
        skip_final_conv = False
        if extra is not None and 'final_conv_kernel' in extra:
            final_kernel = extra['final_conv_kernel']
            assert final_kernel in [0, 1, 3]
            if final_kernel == 3:
                final_padding = 1
            elif final_kernel == 1:
                final_padding = 0
            else:
                # 0 for Identity mapping.
                skip_final_conv = True

        if skip_final_conv:
            self.final_layer = nn.Identity()
            return

        # channel width seen by the final stage
        conv_channels = self.in_channels
        if num_deconv_layers > 0:
            conv_channels = num_deconv_filters[-1]

        modules = []
        if extra is not None:
            # optional conv + BN + ReLU blocks that keep the channel width
            n_extra = extra.get('num_conv_layers', 0)
            extra_kernels = extra.get('num_conv_kernels', [1] * n_extra)
            for idx in range(n_extra):
                k = extra_kernels[idx]
                modules.append(
                    build_conv_layer(
                        dict(type='Conv2d'),
                        in_channels=conv_channels,
                        out_channels=conv_channels,
                        kernel_size=k,
                        stride=1,
                        padding=(k - 1) // 2))
                modules.append(
                    build_norm_layer(dict(type='BN'), conv_channels)[1])
                modules.append(nn.ReLU(inplace=True))

        modules.append(
            build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=conv_channels,
                out_channels=out_channels,
                kernel_size=final_kernel,
                stride=1,
                padding=final_padding))

        # a lone conv is stored directly rather than wrapped in a Sequential
        self.final_layer = (
            nn.Sequential(*modules) if len(modules) > 1 else modules[0])

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Build the stack of deconv + BatchNorm + ReLU blocks.

        ``self.in_channels`` is overwritten with the output width of every
        block as it is created, so it ends up equal to the last filter count.

        Raises:
            ValueError: If ``num_filters`` or ``num_kernels`` does not have
                ``num_layers`` entries, or a kernel size is unsupported.
        """
        if num_layers != len(num_filters):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_filters({len(num_filters)})'
            raise ValueError(error_msg)
        if num_layers != len(num_kernels):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_kernels({len(num_kernels)})'
            raise ValueError(error_msg)

        blocks = []
        for idx in range(num_layers):
            kernel, padding, output_padding = self._get_deconv_cfg(
                num_kernels[idx])
            planes = num_filters[idx]

            deconv = build_upsample_layer(
                dict(type='deconv'),
                in_channels=self.in_channels,
                out_channels=planes,
                kernel_size=kernel,
                stride=2,
                padding=padding,
                output_padding=output_padding,
                bias=False)
            blocks += [deconv, nn.BatchNorm2d(planes), nn.ReLU(inplace=True)]

            # the next block consumes what this one produced
            self.in_channels = planes

        return nn.Sequential(*blocks)

    @staticmethod
    def _get_deconv_cfg(deconv_kernel):
        """Return ``(kernel, padding, output_padding)`` for a deconv kernel.

        Raises:
            ValueError: If ``deconv_kernel`` is not 2, 3 or 4.
        """
        if deconv_kernel == 4:
            return deconv_kernel, 1, 0
        if deconv_kernel == 3:
            return deconv_kernel, 1, 1
        if deconv_kernel == 2:
            return deconv_kernel, 0, 0
        raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

    def forward(self, x):
        """Forward function."""
        feat = self.final_layer(self.deconv_layers(x))
        batch, channels, height, width = feat.shape
        # reshape the 2D heatmap to 3D heatmap
        return feat.reshape(batch, channels // self.depth_size,
                            self.depth_size, height, width)

    def init_weights(self):
        """Initialize model weights."""
        for module in self.deconv_layers.modules():
            if isinstance(module, nn.ConvTranspose2d):
                normal_init(module, std=0.001)
            elif isinstance(module, nn.BatchNorm2d):
                constant_init(module, 1)
        for module in self.final_layer.modules():
            if isinstance(module, nn.Conv2d):
                normal_init(module, std=0.001, bias=0)
            elif isinstance(module, nn.BatchNorm2d):
                constant_init(module, 1)
class Heatmap1DHead(nn.Module):
    """Sub-module of Interhand3DHead that regresses a scalar via a 1D heatmap.

    A small fully-connected network produces ``heatmap_size`` logits per
    sample; the prediction is the softmax-weighted mean of the bin indices.

    Args:
        in_channels (int): Number of input channels
        heatmap_size (int): Heatmap size
        hidden_dims (list|tuple): Number of feature dimension of FC layers.
    """

    def __init__(self, in_channels=2048, heatmap_size=64, hidden_dims=(512, )):
        super().__init__()

        self.in_channels = in_channels
        self.heatmap_size = heatmap_size

        widths = [in_channels, *hidden_dims, heatmap_size]
        self.fc = self._make_linear_layers(widths, relu_final=False)

    def soft_argmax_1d(self, heatmap1d):
        """Return the expected bin index of each row under a softmax."""
        probs = F.softmax(heatmap1d, 1)
        bin_index = torch.arange(
            self.heatmap_size, dtype=probs.dtype, device=probs.device)
        return (probs * bin_index[None, :]).sum(dim=1)

    def _make_linear_layers(self, feat_dims, relu_final=False):
        """Make linear layers.

        Every Linear layer is followed by a ReLU, except the last one, which
        only gets a ReLU when ``relu_final`` is set.
        """
        last = len(feat_dims) - 2
        stack = []
        for idx, (dim_in, dim_out) in enumerate(
                zip(feat_dims[:-1], feat_dims[1:])):
            stack.append(nn.Linear(dim_in, dim_out))
            if idx != last or relu_final:
                stack.append(nn.ReLU(inplace=True))
        return nn.Sequential(*stack)

    def forward(self, x):
        """Forward function."""
        return self.soft_argmax_1d(self.fc(x)).view(-1, 1)

    def init_weights(self):
        """Initialize model weights."""
        for module in self.fc.modules():
            if isinstance(module, nn.Linear):
                normal_init(module, mean=0, std=0.01, bias=0)
class MultilabelClassificationHead(nn.Module):
    """MultilabelClassificationHead is a sub-module of Interhand3DHead, and
    outputs hand type classification.

    A stack of fully-connected layers maps the input features to
    ``num_labels`` logits, which are squashed independently with a sigmoid.

    Args:
        in_channels (int): Number of input channels
        num_labels (int): Number of labels
        hidden_dims (list|tuple): Number of hidden dimension of FC layers.
    """

    def __init__(self, in_channels=2048, num_labels=2, hidden_dims=(512, )):
        super().__init__()

        self.in_channels = in_channels
        self.num_labels = num_labels
        # The attribute used to be stored only under this misspelled name;
        # keep it as an alias so existing readers of `num_labesl` still work.
        self.num_labesl = num_labels

        feature_dims = [in_channels, *hidden_dims, num_labels]
        self.fc = self._make_linear_layers(feature_dims, relu_final=False)

    def _make_linear_layers(self, feat_dims, relu_final=False):
        """Make linear layers.

        Args:
            feat_dims (list|tuple): Layer widths; consecutive pairs give the
                (in, out) sizes of each Linear layer.
            relu_final (bool): Whether to append a ReLU after the last
                Linear layer. Every other Linear layer always gets one.

        Returns:
            nn.Sequential: The assembled layers.
        """
        layers = []
        for i in range(len(feat_dims) - 1):
            layers.append(nn.Linear(feat_dims[i], feat_dims[i + 1]))
            if i < len(feat_dims) - 2 or \
                    (i == len(feat_dims) - 2 and relu_final):
                layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function.

        Returns:
            Tensor: Per-label scores in (0, 1) obtained by applying a
            sigmoid to the output of the FC layers.
        """
        labels = torch.sigmoid(self.fc(x))
        return labels

    def init_weights(self):
        """Initialize model weights."""
        for m in self.fc.modules():
            if isinstance(m, nn.Linear):
                normal_init(m, mean=0, std=0.01, bias=0)
@HEADS.register_module()
class Interhand3DHead(nn.Module):
    """Interhand 3D head of paper ref: Gyeongsik Moon. "InterHand2.6M: A
    Dataset and Baseline for 3D Interacting Hand Pose Estimation from a Single
    RGB Image".

    The head bundles two Heatmap3DHead instances (right and left hand,
    built from the same config), a Heatmap1DHead for the relative root
    depth and a MultilabelClassificationHead for the hand type.

    Args:
        keypoint_head_cfg (dict): Configs of Heatmap3DHead for hand
            keypoint estimation.
        root_head_cfg (dict): Configs of Heatmap1DHead for relative
            hand root depth estimation.
        hand_type_head_cfg (dict): Configs of MultilabelClassificationHead
            for hand type classification.
        loss_keypoint (dict): Config for keypoint loss. Default: None.
        loss_root_depth (dict): Config for relative root depth loss.
            Default: None.
        loss_hand_type (dict): Config for hand type classification
            loss. Default: None.
        train_cfg (dict): Training config; stored as ``self.train_cfg``.
            Default: None (treated as an empty dict).
        test_cfg (dict): Testing config. Keys read in this class:
            ``target_type`` (default 'GaussianHeatmap') and
            ``shift_heatmap`` (default False). Default: None (treated as
            an empty dict).
    """

    def __init__(self,
                 keypoint_head_cfg,
                 root_head_cfg,
                 hand_type_head_cfg,
                 loss_keypoint=None,
                 loss_root_depth=None,
                 loss_hand_type=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()

        # build sub-module heads
        self.right_hand_head = Heatmap3DHead(**keypoint_head_cfg)
        self.left_hand_head = Heatmap3DHead(**keypoint_head_cfg)
        self.root_head = Heatmap1DHead(**root_head_cfg)
        self.hand_type_head = MultilabelClassificationHead(
            **hand_type_head_cfg)
        self.neck = GlobalAveragePooling()

        # build losses
        self.keypoint_loss = build_loss(loss_keypoint)
        self.root_depth_loss = build_loss(loss_root_depth)
        self.hand_type_loss = build_loss(loss_hand_type)

        self.train_cfg = {} if train_cfg is None else train_cfg
        self.test_cfg = {} if test_cfg is None else test_cfg
        self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap')

    def init_weights(self):
        """Initialize the weights of all four sub-heads."""
        self.left_hand_head.init_weights()
        self.right_hand_head.init_weights()
        self.root_head.init_weights()
        self.hand_type_head.init_weights()

    def get_loss(self, output, target, target_weight):
        """Calculate loss for hand keypoint heatmaps, relative root depth and
        hand type.

        Args:
            output (list[Tensor]): a list of outputs from multiple heads.
            target (list[Tensor]): a list of targets for multiple heads.
            target_weight (list[Tensor]): a list of targets weight for
                multiple heads.

        Returns:
            dict: Losses under the keys ``hand_loss``, ``rel_root_loss`` and
            ``hand_type_loss``.
        """
        losses = dict()

        # hand keypoint loss
        assert not isinstance(self.keypoint_loss, nn.Sequential)
        out, tar, tar_weight = output[0], target[0], target_weight[0]
        assert tar.dim() == 5 and tar_weight.dim() == 3
        losses['hand_loss'] = self.keypoint_loss(out, tar, tar_weight)

        # relative root depth loss
        assert not isinstance(self.root_depth_loss, nn.Sequential)
        out, tar, tar_weight = output[1], target[1], target_weight[1]
        assert tar.dim() == 2 and tar_weight.dim() == 2
        losses['rel_root_loss'] = self.root_depth_loss(out, tar, tar_weight)

        # hand type loss
        assert not isinstance(self.hand_type_loss, nn.Sequential)
        out, tar, tar_weight = output[2], target[2], target_weight[2]
        assert tar.dim() == 2 and tar_weight.dim() in [1, 2]
        losses['hand_type_loss'] = self.hand_type_loss(out, tar, tar_weight)

        return losses

    def get_accuracy(self, output, target, target_weight):
        """Calculate accuracy for hand type.

        Args:
            output (list[Tensor]): a list of outputs from multiple heads.
            target (list[Tensor]): a list of targets for multiple heads.
            target_weight (list[Tensor]): a list of targets weight for
                multiple heads.

        Returns:
            dict: ``{'acc_classification': float}`` computed from the third
            entry (hand type) of each input list.
        """
        accuracy = dict()
        avg_acc = multilabel_classification_accuracy(
            output[2].detach().cpu().numpy(),
            target[2].detach().cpu().numpy(),
            target_weight[2].detach().cpu().numpy(),
        )
        accuracy['acc_classification'] = float(avg_acc)
        return accuracy

    def forward(self, x):
        """Forward function.

        Returns:
            list[Tensor]: ``[keypoint heatmaps, relative root depth,
            hand type]``. The right-hand and left-hand 3D heatmaps are
            concatenated along dim 1 (right hand first); the other two
            outputs are computed from the globally pooled features.
        """
        outputs = []
        outputs.append(
            torch.cat([self.right_hand_head(x),
                       self.left_hand_head(x)], dim=1))
        x = self.neck(x)
        outputs.append(self.root_head(x))
        outputs.append(self.hand_type_head(x))
        return outputs

    def inference_model(self, x, flip_pairs=None):
        """Inference function.

        Returns:
            output (list[np.ndarray]): list of output hand keypoint
            heatmaps, relative root depth and hand type.

        Args:
            x (torch.Tensor[N,K,H,W]): Input features.
            flip_pairs (None | list[tuple]):
                Pairs of keypoints which are mirrored. When given, the
                outputs are treated as predictions on a flipped input and
                are flipped back.
        """
        output = self.forward(x)

        if flip_pairs is not None:
            # flip 3D heatmap
            heatmap_3d = output[0]
            N, K, D, H, W = heatmap_3d.shape
            # reshape 3D heatmap to 2D heatmap
            heatmap_3d = heatmap_3d.reshape(N, K * D, H, W)
            # 2D heatmap flip
            heatmap_3d_flipped_back = flip_back(
                heatmap_3d.detach().cpu().numpy(),
                flip_pairs,
                target_type=self.target_type)
            # reshape back to 3D heatmap
            heatmap_3d_flipped_back = heatmap_3d_flipped_back.reshape(
                N, K, D, H, W)
            # feature is not aligned, shift flipped heatmap for higher accuracy
            if self.test_cfg.get('shift_heatmap', False):
                heatmap_3d_flipped_back[...,
                                        1:] = heatmap_3d_flipped_back[..., :-1]
            output[0] = heatmap_3d_flipped_back

            # flip relative hand root depth
            output[1] = -output[1].detach().cpu().numpy()

            # flip hand type: swap the first two label columns
            hand_type = output[2].detach().cpu().numpy()
            hand_type_flipped_back = hand_type.copy()
            hand_type_flipped_back[:, 0] = hand_type[:, 1]
            hand_type_flipped_back[:, 1] = hand_type[:, 0]
            output[2] = hand_type_flipped_back
        else:
            output = [out.detach().cpu().numpy() for out in output]

        return output

    def decode(self, img_metas, output, **kwargs):
        """Decode hand keypoint, relative root depth and hand type.

        Args:
            img_metas (list(dict)): Information about data augmentation
                By default this includes:

                - "image_file": path to the image file
                - "center": center of the bbox
                - "scale": scale of the bbox
                - "rotation": rotation of the bbox
                - "bbox_score": score of bbox
                - "heatmap3d_depth_bound": depth bound of hand keypoint
                    3D heatmap
                - "root_depth_bound": depth bound of relative root depth
                    1D heatmap
            output (list[np.ndarray]): model predicted 3D heatmaps, relative
                root depth and hand type.

        Returns:
            dict: Decoded results with the keys ``boxes``, ``image_paths``,
            ``bbox_ids``, ``preds``, ``rel_root_depth`` and ``hand_type``.
            ``rel_root_depth`` has one row per sample.
        """
        batch_size = len(img_metas)
        result = {}

        heatmap3d_depth_bound = np.ones(batch_size, dtype=np.float32)
        root_depth_bound = np.ones(batch_size, dtype=np.float32)
        center = np.zeros((batch_size, 2), dtype=np.float32)
        scale = np.zeros((batch_size, 2), dtype=np.float32)
        image_paths = []
        score = np.ones(batch_size, dtype=np.float32)
        # bbox ids are only collected when the first sample carries one
        if 'bbox_id' in img_metas[0]:
            bbox_ids = []
        else:
            bbox_ids = None

        for i, meta in enumerate(img_metas):
            heatmap3d_depth_bound[i] = meta['heatmap3d_depth_bound']
            root_depth_bound[i] = meta['root_depth_bound']
            center[i, :] = meta['center']
            scale[i, :] = meta['scale']
            image_paths.append(meta['image_file'])

            if 'bbox_score' in meta:
                # .item() extracts the single score explicitly; assigning a
                # size-1 ndarray to a scalar slot is deprecated in recent
                # NumPy. It still raises ValueError for multi-element input.
                score[i] = np.asarray(meta['bbox_score']).item()
            if bbox_ids is not None:
                bbox_ids.append(meta['bbox_id'])

        all_boxes = np.zeros((batch_size, 6), dtype=np.float32)
        all_boxes[:, 0:2] = center[:, 0:2]
        all_boxes[:, 2:4] = scale[:, 0:2]
        # scale is defined as: bbox_size / 200.0, so we
        # need multiply 200.0 to get bbox size
        all_boxes[:, 4] = np.prod(scale * 200.0, axis=1)
        all_boxes[:, 5] = score
        result['boxes'] = all_boxes
        result['image_paths'] = image_paths
        result['bbox_ids'] = bbox_ids

        # decode 3D heatmaps of hand keypoints
        heatmap3d = output[0]
        preds, maxvals = keypoints_from_heatmaps3d(heatmap3d, center, scale)
        keypoints_3d = np.zeros((batch_size, preds.shape[1], 4),
                                dtype=np.float32)
        keypoints_3d[:, :, 0:3] = preds[:, :, 0:3]
        keypoints_3d[:, :, 3:4] = maxvals
        # transform keypoint depth to camera space
        keypoints_3d[:, :, 2] = \
            (keypoints_3d[:, :, 2] / self.right_hand_head.depth_size - 0.5) \
            * heatmap3d_depth_bound[:, np.newaxis]

        result['preds'] = keypoints_3d

        # decode relative hand root depth
        # transform relative root depth to camera space
        # The root head emits one value per sample as an (N, 1) column.
        # Multiplying that by the (N,) bound vector would broadcast to
        # (N, N) for N > 1, so the bound is applied per row instead.
        rel_root_depth = np.asarray(output[1]).reshape(batch_size, -1)
        result['rel_root_depth'] = \
            (rel_root_depth / self.root_head.heatmap_size - 0.5) \
            * root_depth_bound[:, np.newaxis]

        # decode hand type
        result['hand_type'] = output[2] > 0.5
        return result