Spaces:
Runtime error
Runtime error
File size: 6,676 Bytes
cc0dd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Tuple, Union
import numpy as np
import torch
from torch import Tensor, nn
from mmpose.evaluation.functional import keypoint_pck_accuracy
from mmpose.models.utils.tta import flip_coordinates
from mmpose.registry import KEYPOINT_CODECS, MODELS
from mmpose.utils.tensor_utils import to_numpy
from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList,
Predictions)
from ..base_head import BaseHead
# Type alias: either ``None`` or a sequence of ints.
OptIntSeq = Optional[Sequence[int]]
@MODELS.register_module()
class RLEHead(BaseHead):
    """Top-down regression head introduced in `RLE`_ by Li et al(2021). The
    head is composed of fully-connected layers to predict the coordinates and
    sigma(the variance of the coordinates) together.

    For every joint the head outputs 4 values: the first two are used as the
    predicted coordinates and the last two as the sigmas (see :meth:`loss`
    and :meth:`predict`).

    Args:
        in_channels (int | sequence[int]): Number of input channels.
            NOTE(review): the value is passed directly to ``nn.Linear`` as
            its input size, so an ``int`` appears to be required in
            practice -- confirm before passing a sequence.
        num_joints (int): Number of joints
        loss (Config): Config for keypoint loss. Defaults to use
            :class:`RLELoss`
        decoder (Config, optional): The decoder config that controls decoding
            keypoint coordinates from the network output. Defaults to ``None``
        init_cfg (Config, optional): Config to control the initialization. See
            :attr:`default_init_cfg` for default settings

    .. _`RLE`: https://arxiv.org/abs/2107.11291
    """

    # State-dict format version. Checkpoints whose recorded version is lower
    # (or missing) are converted by ``_load_state_dict_pre_hook``.
    _version = 2

    def __init__(self,
                 in_channels: Union[int, Sequence[int]],
                 num_joints: int,
                 loss: ConfigType = dict(
                     type='RLELoss', use_target_weight=True),
                 decoder: OptConfigType = None,
                 init_cfg: OptConfigType = None):

        if init_cfg is None:
            init_cfg = self.default_init_cfg

        super().__init__(init_cfg)

        self.in_channels = in_channels
        self.num_joints = num_joints
        self.loss_module = MODELS.build(loss)
        if decoder is not None:
            self.decoder = KEYPOINT_CODECS.build(decoder)
        else:
            self.decoder = None

        # Define fully-connected layers: 4 outputs per joint
        # (2 coordinates + 2 sigmas).
        self.fc = nn.Linear(in_channels, self.num_joints * 4)

        # Register the hook to automatically convert old version state dicts
        self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)

    def forward(self, feats: Tuple[Tensor]) -> Tensor:
        """Forward the network. The input is multi scale feature maps and the
        output is the coordinates.

        Only the last feature map in ``feats`` is used. It is flattened from
        dim 1 onwards and fed to the fully-connected layer, so its flattened
        size must match the layer's input size.

        Args:
            feats (Tuple[Tensor]): Multi scale feature maps.

        Returns:
            Tensor: Raw predictions reshaped to ``(-1, num_joints, 4)``.
            Channels ``[..., :2]`` are the coordinates and ``[..., 2:]`` the
            sigmas (no sigmoid applied here).
        """
        x = feats[-1]

        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x.reshape(-1, self.num_joints, 4)

    def predict(self,
                feats: Tuple[Tensor],
                batch_data_samples: OptSampleList,
                test_cfg: ConfigType = {}) -> Predictions:
        """Predict results from outputs.

        Args:
            feats (Tuple[Tensor] | list): Features for :meth:`forward`. When
                ``test_cfg['flip_test']`` is truthy this must be a list of
                two items ``[orig, flipped]``.
            batch_data_samples (OptSampleList): Data samples of the batch.
                Only read in flip-test mode, where ``flip_indices`` and
                ``input_size`` are taken from the first sample's metainfo.
            test_cfg (ConfigType): Test-time options. Keys read here are
                ``flip_test`` (default ``False``) and ``shift_coords``
                (default ``True``). The default dict is never mutated.

        Returns:
            Predictions: The result of ``self.decode`` applied to the
            predicted coordinates/sigmas.
        """

        if test_cfg.get('flip_test', False):
            # TTA: flip test -> feats = [orig, flipped]
            assert isinstance(feats, list) and len(feats) == 2
            flip_indices = batch_data_samples[0].metainfo['flip_indices']
            input_size = batch_data_samples[0].metainfo['input_size']

            _feats, _feats_flip = feats

            _batch_coords = self.forward(_feats)
            # Squash the sigma channels into (0, 1).
            _batch_coords[..., 2:] = _batch_coords[..., 2:].sigmoid()

            # Map the predictions on the flipped input back to the original
            # orientation before averaging.
            _batch_coords_flip = flip_coordinates(
                self.forward(_feats_flip),
                flip_indices=flip_indices,
                shift_coords=test_cfg.get('shift_coords', True),
                input_size=input_size)
            _batch_coords_flip[..., 2:] = _batch_coords_flip[..., 2:].sigmoid()

            # Average the original and the flipped-back predictions.
            batch_coords = (_batch_coords + _batch_coords_flip) * 0.5
        else:
            batch_coords = self.forward(feats)  # (B, K, D)
            batch_coords[..., 2:] = batch_coords[..., 2:].sigmoid()

        batch_coords.unsqueeze_(dim=1)  # (B, N, K, D)
        preds = self.decode(batch_coords)

        return preds

    def loss(self,
             inputs: Tuple[Tensor],
             batch_data_samples: OptSampleList,
             train_cfg: ConfigType = {}) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tuple[Tensor]): Features passed to :meth:`forward`.
            batch_data_samples (OptSampleList): Data samples providing
                ``gt_instance_labels.keypoint_labels`` and
                ``gt_instance_labels.keypoint_weights``.
            train_cfg (ConfigType): Not used by this method.

        Returns:
            dict: ``loss_kpt`` (output of the loss module) and ``acc_pose``
            (average PCK accuracy at threshold 0.05, as a tensor).
        """

        pred_outputs = self.forward(inputs)

        # Gather the per-sample targets into batch tensors.
        keypoint_labels = torch.cat(
            [d.gt_instance_labels.keypoint_labels for d in batch_data_samples])
        keypoint_weights = torch.cat([
            d.gt_instance_labels.keypoint_weights for d in batch_data_samples
        ])

        # Split the 4 per-joint outputs into coordinates and sigmas.
        pred_coords = pred_outputs[:, :, :2]
        pred_sigma = pred_outputs[:, :, 2:4]

        # calculate losses
        losses = dict()
        loss = self.loss_module(pred_coords, pred_sigma, keypoint_labels,
                                keypoint_weights.unsqueeze(-1))

        losses.update(loss_kpt=loss)

        # calculate accuracy
        # A unit normalization factor is used, so the 0.05 threshold applies
        # directly in the units of the labels.
        _, avg_acc, _ = keypoint_pck_accuracy(
            pred=to_numpy(pred_coords),
            gt=to_numpy(keypoint_labels),
            mask=to_numpy(keypoint_weights) > 0,
            thr=0.05,
            norm_factor=np.ones((pred_coords.size(0), 2), dtype=np.float32))

        acc_pose = torch.tensor(avg_acc, device=keypoint_labels.device)
        losses.update(acc_pose=acc_pose)

        return losses

    def _load_state_dict_pre_hook(self, state_dict: dict, prefix: str,
                                  local_meta: dict, *args, **kwargs):
        """A hook function to convert an old-version state dict of this head
        (before MMPose v1.0.0) to the current format.

        In the old format the loss instance is stored under ``loss``; it is
        now stored under ``loss_module``. Keys of the form
        ``{prefix}loss.xxx`` are therefore renamed to
        ``{prefix}loss_module.xxx``. Keys that do not start with ``prefix``
        are left unchanged.

        The hook will be automatically registered during initialization.

        Args:
            state_dict (dict): The state dict being loaded; modified in place.
            prefix (str): Key prefix of this module inside ``state_dict``.
            local_meta (dict): Metadata of this module; its ``version`` entry
                decides whether conversion is needed.
        """

        version = local_meta.get('version', None)
        if version and version >= self._version:
            return

        # convert old-version state dict
        keys = list(state_dict.keys())
        for _k in keys:
            # Pop and re-insert every key so the original key order is kept.
            v = state_dict.pop(_k)
            k_new = _k

            # Strip the literal prefix. ``str.lstrip(prefix)`` must not be
            # used here: it removes any leading characters contained in
            # ``prefix`` (a character set), which can eat into the key's own
            # name and mangle keys that belong to other modules.
            if _k.startswith(prefix):
                k_parts = _k[len(prefix):].split('.')
                # In old version, "loss" includes the instances of loss,
                # now it should be renamed "loss_module"
                if k_parts[0] == 'loss':
                    # loss.xxx -> loss_module.xxx
                    k_new = prefix + 'loss_module.' + '.'.join(k_parts[1:])

            state_dict[k_new] = v

    @property
    def default_init_cfg(self):
        """list[dict]: Default initialization config: ``Normal`` init with
        ``std=0.01`` and ``bias=0`` for ``Linear`` layers."""
        init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)]
        return init_cfg
|