# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_upsample_layer, constant_init,
                      normal_init)

from mmpose.models.builder import build_loss
from ..backbones.resnet import BasicBlock
from ..builder import HEADS


@HEADS.register_module()
class AEHigherResolutionHead(nn.Module):
"""Associative embedding with higher resolution head. paper ref: Bowen | |
Cheng et al. "HigherHRNet: Scale-Aware Representation Learning for Bottom- | |
Up Human Pose Estimation". | |
Args: | |
in_channels (int): Number of input channels. | |
num_joints (int): Number of joints | |
tag_per_joint (bool): If tag_per_joint is True, | |
the dimension of tags equals to num_joints, | |
else the dimension of tags is 1. Default: True | |
extra (dict): Configs for extra conv layers. Default: None | |
num_deconv_layers (int): Number of deconv layers. | |
num_deconv_layers should >= 0. Note that 0 means | |
no deconv layers. | |
num_deconv_filters (list|tuple): Number of filters. | |
If num_deconv_layers > 0, the length of | |
num_deconv_kernels (list|tuple): Kernel sizes. | |
cat_output (list[bool]): Option to concat outputs. | |
with_ae_loss (list[bool]): Option to use ae loss. | |
loss_keypoint (dict): Config for loss. Default: None. | |
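
    Example:
        A minimal construction sketch. The channel, joint and loss
        settings below are illustrative assumptions (they mirror a
        typical HigherHRNet config), not values required by this head.

        >>> head = AEHigherResolutionHead(
        ...     in_channels=32,
        ...     num_joints=17,
        ...     num_deconv_layers=1,
        ...     num_deconv_filters=(32, ),
        ...     num_deconv_kernels=(4, ),
        ...     cat_output=[True],
        ...     with_ae_loss=[True, False],
        ...     loss_keypoint=dict(
        ...         type='MultiLossFactory',
        ...         num_joints=17,
        ...         num_stages=2,
        ...         ae_loss_type='exp',
        ...         with_ae_loss=[True, False],
        ...         push_loss_factor=[0.001, 0.001],
        ...         pull_loss_factor=[0.001, 0.001],
        ...         with_heatmaps_loss=[True, True],
        ...         heatmaps_loss_factor=[1.0, 1.0]))
        >>> head.init_weights()
        >>> outputs = head(torch.randn(1, 32, 128, 128))
        >>> [tuple(out.shape) for out in outputs]
        [(1, 34, 128, 128), (1, 17, 256, 256)]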
""" | |

    def __init__(self,
                 in_channels,
                 num_joints,
                 tag_per_joint=True,
                 extra=None,
                 num_deconv_layers=1,
                 num_deconv_filters=(32, ),
                 num_deconv_kernels=(4, ),
                 num_basic_blocks=4,
                 cat_output=None,
                 with_ae_loss=None,
                 loss_keypoint=None):
        super().__init__()

        self.loss = build_loss(loss_keypoint)
        dim_tag = num_joints if tag_per_joint else 1

        self.num_deconvs = num_deconv_layers
        self.cat_output = cat_output
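
        # When the associative embedding (AE) loss is enabled at a given
        # resolution, the prediction carries both the heatmap channels and
        # the tag-embedding channels, hence num_joints + dim_tag below.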
        final_layer_output_channels = []

        if with_ae_loss[0]:
            out_channels = num_joints + dim_tag
        else:
            out_channels = num_joints

        final_layer_output_channels.append(out_channels)
        for i in range(num_deconv_layers):
            if with_ae_loss[i + 1]:
                out_channels = num_joints + dim_tag
            else:
                out_channels = num_joints
            final_layer_output_channels.append(out_channels)

        deconv_layer_output_channels = []
        for i in range(num_deconv_layers):
            if with_ae_loss[i]:
                out_channels = num_joints + dim_tag
            else:
                out_channels = num_joints
            deconv_layer_output_channels.append(out_channels)

        self.final_layers = self._make_final_layers(
            in_channels, final_layer_output_channels, extra, num_deconv_layers,
            num_deconv_filters)
        self.deconv_layers = self._make_deconv_layers(
            in_channels, deconv_layer_output_channels, num_deconv_layers,
            num_deconv_filters, num_deconv_kernels, num_basic_blocks,
            cat_output)

    @staticmethod
    def _make_final_layers(in_channels, final_layer_output_channels, extra,
                           num_deconv_layers, num_deconv_filters):
"""Make final layers.""" | |
if extra is not None and 'final_conv_kernel' in extra: | |
assert extra['final_conv_kernel'] in [1, 3] | |
if extra['final_conv_kernel'] == 3: | |
padding = 1 | |
else: | |
padding = 0 | |
kernel_size = extra['final_conv_kernel'] | |
else: | |
kernel_size = 1 | |
padding = 0 | |
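
        # For both supported kernel sizes, padding = (kernel_size - 1) // 2,
        # so the final conv layers preserve the spatial size of their input.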
        final_layers = []
        final_layers.append(
            build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=in_channels,
                out_channels=final_layer_output_channels[0],
                kernel_size=kernel_size,
                stride=1,
                padding=padding))

        for i in range(num_deconv_layers):
            in_channels = num_deconv_filters[i]
            final_layers.append(
                build_conv_layer(
                    cfg=dict(type='Conv2d'),
                    in_channels=in_channels,
                    out_channels=final_layer_output_channels[i + 1],
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding))

        return nn.ModuleList(final_layers)

    def _make_deconv_layers(self, in_channels, deconv_layer_output_channels,
                            num_deconv_layers, num_deconv_filters,
                            num_deconv_kernels, num_basic_blocks, cat_output):
        """Make deconv layers."""
        deconv_layers = []
        for i in range(num_deconv_layers):
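            # When the previous prediction is concatenated back onto the
            # features (cat_output), the deconv input grows by that
            # prediction's channel count.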
            if cat_output[i]:
                in_channels += deconv_layer_output_channels[i]

            planes = num_deconv_filters[i]
            deconv_kernel, padding, output_padding = \
                self._get_deconv_cfg(num_deconv_kernels[i])

            layers = []
            layers.append(
                nn.Sequential(
                    build_upsample_layer(
                        dict(type='deconv'),
                        in_channels=in_channels,
                        out_channels=planes,
                        kernel_size=deconv_kernel,
                        stride=2,
                        padding=padding,
                        output_padding=output_padding,
                        bias=False), nn.BatchNorm2d(planes, momentum=0.1),
                    nn.ReLU(inplace=True)))
            for _ in range(num_basic_blocks):
                layers.append(nn.Sequential(BasicBlock(planes, planes)))
            deconv_layers.append(nn.Sequential(*layers))
            in_channels = planes

        return nn.ModuleList(deconv_layers)

    @staticmethod
    def _get_deconv_cfg(deconv_kernel):
        """Get configurations for deconv layers."""
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(
                f'Unsupported deconv kernel size ({deconv_kernel}); '
                'expected 2, 3 or 4.')

        return deconv_kernel, padding, output_padding

    def get_loss(self, outputs, targets, masks, joints):
        """Calculate bottom-up keypoint loss.

        Note:
            - batch_size: N
            - num_keypoints: K
            - num_outputs: O
            - heatmaps height: H
            - heatmaps width: W

        Args:
            outputs (list(torch.Tensor[N,K,H,W])): Multi-scale output heatmaps.
            targets (List(torch.Tensor[N,K,H,W])): Multi-scale target heatmaps.
            masks (List(torch.Tensor[N,H,W])): Masks of multi-scale target
                heatmaps.
            joints (List(torch.Tensor[N,M,K,2])): Joints of multi-scale target
                heatmaps for ae loss.
        """
        losses = dict()

        heatmaps_losses, push_losses, pull_losses = self.loss(
            outputs, targets, masks, joints)
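
        # The wrapped loss returns per-scale terms; accumulate them into
        # a single scalar per loss type so each is logged once.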
        for idx in range(len(targets)):
            if heatmaps_losses[idx] is not None:
                heatmaps_loss = heatmaps_losses[idx].mean(dim=0)
                if 'heatmap_loss' not in losses:
                    losses['heatmap_loss'] = heatmaps_loss
                else:
                    losses['heatmap_loss'] += heatmaps_loss

            if push_losses[idx] is not None:
                push_loss = push_losses[idx].mean(dim=0)
                if 'push_loss' not in losses:
                    losses['push_loss'] = push_loss
                else:
                    losses['push_loss'] += push_loss

            if pull_losses[idx] is not None:
                pull_loss = pull_losses[idx].mean(dim=0)
                if 'pull_loss' not in losses:
                    losses['pull_loss'] = pull_loss
                else:
                    losses['pull_loss'] += pull_loss

        return losses

    def forward(self, x):
        """Forward function."""
        if isinstance(x, list):
            x = x[0]

        final_outputs = []
        y = self.final_layers[0](x)
        final_outputs.append(y)

        for i in range(self.num_deconvs):
            if self.cat_output[i]:
                x = torch.cat((x, y), 1)

            x = self.deconv_layers[i](x)
            y = self.final_layers[i + 1](x)
            final_outputs.append(y)

        return final_outputs

    def init_weights(self):
        """Initialize model weights."""
        for _, m in self.deconv_layers.named_modules():
            if isinstance(m, nn.ConvTranspose2d):
                normal_init(m, std=0.001)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
        for _, m in self.final_layers.named_modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)