Spaces:
Build error
Build error
# Copyright (c) OpenMMLab. All rights reserved. | |
from abc import ABCMeta, abstractmethod | |
import numpy as np | |
import torch.nn as nn | |
from mmpose.core.evaluation.top_down_eval import keypoints_from_heatmaps | |
class TopdownHeatmapBaseHead(nn.Module):
    """Base class for top-down heatmap heads.

    All top-down heatmap heads should subclass it.
    Every subclass should override:

    - ``get_loss``: compute the loss.
    - ``get_accuracy``: compute the accuracy.
    - ``forward``: run the forward pass.
    - ``inference_model``: run inference.

    ``decode`` reads ``self.test_cfg`` through ``.get(key, default)``. This
    base class never assigns ``test_cfg``, so a subclass has to set it
    before ``decode`` is called.
    """

    # NOTE: this is the Python 2 spelling of a metaclass declaration. Under
    # Python 3 it is only a plain class attribute, so the ``abstractmethod``
    # markers below document intent but are not enforced at instantiation.
    # It is kept as-is so existing subclasses that skip an override still
    # instantiate.
    __metaclass__ = ABCMeta

    @abstractmethod
    def get_loss(self, **kwargs):
        """Gets the loss. Subclasses should override this."""

    @abstractmethod
    def get_accuracy(self, **kwargs):
        """Gets the accuracy. Subclasses should override this."""

    @abstractmethod
    def forward(self, **kwargs):
        """Forward function. Subclasses should override this."""

    @abstractmethod
    def inference_model(self, **kwargs):
        """Inference function. Subclasses should override this."""

    def decode(self, img_metas, output, **kwargs):
        """Decode keypoints from heatmaps.

        Args:
            img_metas (list(dict)): Information about data augmentation.
                By default this includes:

                - "image_file": path to the image file
                - "center": center of the bbox
                - "scale": scale of the bbox
                - "rotation": rotation of the bbox
                - "bbox_score": score of bbox

                "center", "scale" and "image_file" are required on every
                entry. "bbox_score" is optional (the score defaults to 1).
                "bbox_id" is collected from every entry only when the
                first entry has it.
            output (np.ndarray[N, K, H, W]): model predicted heatmaps.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            dict: A dict with the following keys.

            - "preds" (np.ndarray[N, K, 3], float32): the first two
              channels come from the decoded coordinates, the third
              from the decoded maximum values.
            - "boxes" (np.ndarray[N, 6], float32): center (2 values),
              scale (2 values), ``prod(scale * 200.0)`` and the score.
            - "image_paths" (list): the "image_file" entry of every meta.
            - "bbox_ids" (list | None): the "bbox_id" entry of every meta,
              or None when the first meta has no "bbox_id".
        """
        batch_size = len(img_metas)

        # Only the first sample decides whether bbox ids are collected.
        bbox_ids = [] if 'bbox_id' in img_metas[0] else None

        c = np.zeros((batch_size, 2), dtype=np.float32)
        s = np.zeros((batch_size, 2), dtype=np.float32)
        image_paths = []
        score = np.ones(batch_size)
        for i, meta in enumerate(img_metas):
            c[i, :] = meta['center']
            s[i, :] = meta['scale']
            image_paths.append(meta['image_file'])

            if 'bbox_score' in meta:
                # ``item()`` extracts the single element explicitly.
                # Assigning a 1-D size-1 array to a scalar slot is
                # deprecated since NumPy 1.25; ``item()`` still raises
                # ValueError when the score does not hold exactly one
                # element.
                score[i] = np.asarray(meta['bbox_score']).item()
            if bbox_ids is not None:
                bbox_ids.append(meta['bbox_id'])

        preds, maxvals = keypoints_from_heatmaps(
            output,
            c,
            s,
            unbiased=self.test_cfg.get('unbiased_decoding', False),
            post_process=self.test_cfg.get('post_process', 'default'),
            kernel=self.test_cfg.get('modulate_kernel', 11),
            valid_radius_factor=self.test_cfg.get('valid_radius_factor',
                                                  0.0546875),
            use_udp=self.test_cfg.get('use_udp', False),
            target_type=self.test_cfg.get('target_type', 'GaussianHeatmap'))

        all_preds = np.zeros((batch_size, preds.shape[1], 3),
                             dtype=np.float32)
        all_boxes = np.zeros((batch_size, 6), dtype=np.float32)
        all_preds[:, :, 0:2] = preds[:, :, 0:2]
        all_preds[:, :, 2:3] = maxvals
        all_boxes[:, 0:2] = c[:, 0:2]
        all_boxes[:, 2:4] = s[:, 0:2]
        all_boxes[:, 4] = np.prod(s * 200.0, axis=1)
        all_boxes[:, 5] = score

        return {
            'preds': all_preds,
            'boxes': all_boxes,
            'image_paths': image_paths,
            'bbox_ids': bbox_ids,
        }

    @staticmethod
    def _get_deconv_cfg(deconv_kernel):
        """Get configurations for deconv layers.

        Declared as a static method so it can be called both on the class
        and on an instance (``self._get_deconv_cfg(kernel)``).

        Args:
            deconv_kernel (int): Deconv kernel size; one of 4, 3 or 2.

        Returns:
            tuple: ``(deconv_kernel, padding, output_padding)``.

        Raises:
            ValueError: If ``deconv_kernel`` is not 4, 3 or 2.
        """
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

        return deconv_kernel, padding, output_padding