Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod | |
from typing import Tuple, Union | |
from mmengine.model import BaseModule | |
from mmengine.structures import InstanceData | |
from torch import Tensor | |
from mmpose.utils.tensor_utils import to_numpy | |
from mmpose.utils.typing import (Features, InstanceList, OptConfigType, | |
OptSampleList, Predictions) | |
class BaseHead(BaseModule, metaclass=ABCMeta):
    """Base head. A subclass should override :meth:`forward`,
    :meth:`predict` and :meth:`loss`.

    The three methods above are declared abstract, so a subclass that omits
    any of them cannot be instantiated. :meth:`decode` is a concrete helper
    that relies on a ``decoder`` attribute the subclass is expected to set.

    Args:
        init_cfg (dict, optional): The extra init config of layers.
            Defaults to None.
    """

    @abstractmethod
    def forward(self, feats: Tuple[Tensor]):
        """Forward the network."""

    # NOTE: the ``{}`` defaults below are kept for signature compatibility
    # with existing subclasses; implementations must treat them as read-only.
    @abstractmethod
    def predict(self,
                feats: Features,
                batch_data_samples: OptSampleList,
                test_cfg: OptConfigType = {}) -> Predictions:
        """Predict results from features."""

    @abstractmethod
    def loss(self,
             feats: Tuple[Tensor],
             batch_data_samples: OptSampleList,
             train_cfg: OptConfigType = {}) -> dict:
        """Calculate losses from a batch of inputs and data samples."""

    def decode(self, batch_outputs: Union[Tensor,
                                          Tuple[Tensor]]) -> InstanceList:
        """Decode keypoints from outputs.

        Args:
            batch_outputs (Tensor | Tuple[Tensor]): The network outputs of
                a data batch

        Returns:
            List[InstanceData]: A list of InstanceData, each contains the
            decoded pose information of the instances of one data sample.

        Raises:
            RuntimeError: If the head has no decoder, i.e. ``self.decoder``
                is ``None`` or was never assigned.
        """

        def _pack_and_call(args, func):
            # The decoder callables take positional arguments, so a single
            # (non-tuple) output is wrapped before being unpacked.
            if not isinstance(args, tuple):
                args = (args, )
            return func(*args)

        # ``getattr`` with a default turns a never-assigned ``decoder`` into
        # the same descriptive RuntimeError as an explicit ``None`` instead
        # of a bare AttributeError.
        decoder = getattr(self, 'decoder', None)
        if decoder is None:
            raise RuntimeError(
                f'The decoder has not been set in {self.__class__.__name__}. '
                'Please set the decoder configs in the init parameters to '
                'enable head methods `head.predict()` and `head.decode()`')

        if decoder.support_batch_decoding:
            # The decoder handles the whole batch in a single call.
            batch_keypoints, batch_scores = _pack_and_call(
                batch_outputs, decoder.batch_decode)
        else:
            # Decode sample by sample; ``to_numpy(..., unzip=True)`` is
            # presumed to yield one numpy output (or tuple of outputs) per
            # data sample -- TODO confirm against ``to_numpy``.
            batch_output_np = to_numpy(batch_outputs, unzip=True)
            batch_keypoints = []
            batch_scores = []
            for outputs in batch_output_np:
                keypoints, scores = _pack_and_call(outputs, decoder.decode)
                batch_keypoints.append(keypoints)
                batch_scores.append(scores)

        # One InstanceData per (keypoints, scores) pair.
        preds = [
            InstanceData(keypoints=keypoints, keypoint_scores=scores)
            for keypoints, scores in zip(batch_keypoints, batch_scores)
        ]

        return preds