Spaces:
Build error
Build error
File size: 9,137 Bytes
d7a991a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_upsample_layer, constant_init,
normal_init)
from mmpose.models.builder import build_loss
from ..backbones.resnet import BasicBlock
from ..builder import HEADS
@HEADS.register_module()
class AEHigherResolutionHead(nn.Module):
    """Associative embedding with higher resolution head. paper ref: Bowen
    Cheng et al. "HigherHRNet: Scale-Aware Representation Learning for Bottom-
    Up Human Pose Estimation".

    The head owns ``num_deconv_layers + 1`` final conv layers: one applied
    directly to the input feature map and one applied after every deconv
    stage. ``forward`` returns one prediction per final layer.

    Args:
        in_channels (int): Number of input channels.
        num_joints (int): Number of joints
        tag_per_joint (bool): If tag_per_joint is True,
            the dimension of tags equals to num_joints,
            else the dimension of tags is 1. Default: True
        extra (dict): Configs for extra conv layers. Only the optional key
            ``final_conv_kernel`` (1 or 3) is read. Default: None
        num_deconv_layers (int): Number of deconv layers.
            num_deconv_layers should >= 0. Note that 0 means
            no deconv layers.
        num_deconv_filters (list|tuple): Number of filters of each deconv
            stage. Entry ``i`` is read for every ``i`` in
            ``range(num_deconv_layers)``, so at least ``num_deconv_layers``
            entries are required.
        num_deconv_kernels (list|tuple): Kernel sizes of each deconv stage
            (2, 3 or 4). At least ``num_deconv_layers`` entries are required.
        num_basic_blocks (int): Number of ``BasicBlock`` modules appended
            after the deconv of every stage. Default: 4
        cat_output (list[bool]): Option to concat outputs. Entry ``i`` tells
            whether the previous prediction is concatenated to the features
            fed into deconv stage ``i``. May be None only when there is no
            deconv stage.
        with_ae_loss (list[bool]): Option to use ae loss. Entry ``i`` tells
            whether the ``i``-th final layer also predicts tag channels; at
            least ``num_deconv_layers + 1`` entries are required.
        loss_keypoint (dict): Config for loss. Default: None.

    Raises:
        TypeError: If ``with_ae_loss`` is None, or if ``cat_output`` is None
            while ``num_deconv_layers > 0``.
    """

    def __init__(self,
                 in_channels,
                 num_joints,
                 tag_per_joint=True,
                 extra=None,
                 num_deconv_layers=1,
                 num_deconv_filters=(32, ),
                 num_deconv_kernels=(4, ),
                 num_basic_blocks=4,
                 cat_output=None,
                 with_ae_loss=None,
                 loss_keypoint=None):
        super().__init__()

        self.loss = build_loss(loss_keypoint)

        # Both options default to None but are indexed below; fail with an
        # actionable message instead of "'NoneType' is not subscriptable".
        if with_ae_loss is None:
            raise TypeError(
                'with_ae_loss must be a sequence of at least '
                f'{max(num_deconv_layers, 0) + 1} bools, but got None.')
        if cat_output is None and num_deconv_layers > 0:
            raise TypeError(
                'cat_output must be a sequence of at least '
                f'{num_deconv_layers} bools when num_deconv_layers > 0, '
                'but got None.')

        dim_tag = num_joints if tag_per_joint else 1

        self.num_deconvs = num_deconv_layers
        self.cat_output = cat_output

        def _num_outputs(use_ae_loss):
            # Heatmap channels, plus tag channels when AE loss is enabled.
            return num_joints + dim_tag if use_ae_loss else num_joints

        # Output channels of every final layer (input scale + each deconv).
        final_layer_output_channels = [_num_outputs(with_ae_loss[0])]
        final_layer_output_channels.extend(
            _num_outputs(with_ae_loss[i + 1])
            for i in range(num_deconv_layers))

        # Channels of the prediction that may be concatenated in front of
        # deconv stage i, i.e. the output of the i-th final layer.
        deconv_layer_output_channels = [
            _num_outputs(with_ae_loss[i]) for i in range(num_deconv_layers)
        ]

        self.final_layers = self._make_final_layers(
            in_channels, final_layer_output_channels, extra, num_deconv_layers,
            num_deconv_filters)
        self.deconv_layers = self._make_deconv_layers(
            in_channels, deconv_layer_output_channels, num_deconv_layers,
            num_deconv_filters, num_deconv_kernels, num_basic_blocks,
            cat_output)

    @staticmethod
    def _make_final_layers(in_channels, final_layer_output_channels, extra,
                           num_deconv_layers, num_deconv_filters):
        """Make final layers.

        Builds one stride-1 ``Conv2d`` for the input feature map followed by
        one per deconv stage. The kernel size is ``extra['final_conv_kernel']``
        (1 or 3) when given, otherwise 1; a 3x3 kernel uses padding 1.

        Returns:
            nn.ModuleList: ``num_deconv_layers + 1`` conv layers.
        """
        kernel_size = 1
        if extra is not None and 'final_conv_kernel' in extra:
            assert extra['final_conv_kernel'] in [1, 3]
            kernel_size = extra['final_conv_kernel']
        padding = 1 if kernel_size == 3 else 0

        def _final_conv(layer_in_channels, layer_out_channels):
            return build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=layer_in_channels,
                out_channels=layer_out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=padding)

        final_layers = [
            _final_conv(in_channels, final_layer_output_channels[0])
        ]
        for i in range(num_deconv_layers):
            # The i-th deconv stage outputs num_deconv_filters[i] channels.
            final_layers.append(
                _final_conv(num_deconv_filters[i],
                            final_layer_output_channels[i + 1]))

        return nn.ModuleList(final_layers)

    def _make_deconv_layers(self, in_channels, deconv_layer_output_channels,
                            num_deconv_layers, num_deconv_filters,
                            num_deconv_kernels, num_basic_blocks, cat_output):
        """Make deconv layers.

        Every stage is ``Sequential(Sequential(deconv, BN, ReLU),
        Sequential(BasicBlock), ...)`` with ``num_basic_blocks`` blocks. The
        nesting is kept as is so that parameter names in the state dict do
        not change.

        Returns:
            nn.ModuleList: ``num_deconv_layers`` upsampling stages.
        """
        deconv_layers = []
        for i in range(num_deconv_layers):
            if cat_output[i]:
                # forward() concatenates the previous prediction to the
                # features, so the deconv sees the extra channels.
                in_channels += deconv_layer_output_channels[i]

            planes = num_deconv_filters[i]
            deconv_kernel, padding, output_padding = \
                self._get_deconv_cfg(num_deconv_kernels[i])

            upsample = nn.Sequential(
                build_upsample_layer(
                    dict(type='deconv'),
                    in_channels=in_channels,
                    out_channels=planes,
                    kernel_size=deconv_kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=False), nn.BatchNorm2d(planes, momentum=0.1),
                nn.ReLU(inplace=True))

            layers = [upsample]
            for _ in range(num_basic_blocks):
                layers.append(nn.Sequential(BasicBlock(planes, planes), ))

            deconv_layers.append(nn.Sequential(*layers))
            in_channels = planes

        return nn.ModuleList(deconv_layers)

    @staticmethod
    def _get_deconv_cfg(deconv_kernel):
        """Get configurations for deconv layers.

        Args:
            deconv_kernel (int): Kernel size, one of 2, 3 or 4.

        Returns:
            tuple: ``(deconv_kernel, padding, output_padding)``. With the
            stride of 2 used by ``_make_deconv_layers`` these values give an
            output of exactly twice the input size for a standard transposed
            convolution.

        Raises:
            ValueError: If the kernel size is not supported.
        """
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

        return deconv_kernel, padding, output_padding

    def get_loss(self, outputs, targets, masks, joints):
        """Calculate bottom-up keypoint loss.

        Note:
            - batch_size: N
            - num_keypoints: K
            - num_outputs: O
            - heatmaps height: H
            - heatmaps width: W

        Args:
            outputs (list(torch.Tensor[N,K,H,W])): Multi-scale output heatmaps.
            targets (List(torch.Tensor[N,K,H,W])): Multi-scale target heatmaps.
            masks (List(torch.Tensor[N,H,W])): Masks of multi-scale target
                heatmaps
            joints (List(torch.Tensor[N,M,K,2])): Joints of multi-scale target
                heatmaps for ae loss

        Returns:
            dict: Losses summed over the scales, under the keys
            ``'heatmap_loss'``, ``'push_loss'`` and ``'pull_loss'``. A key is
            present only if the loss module returned a non-None value for it
            on at least one scale.
        """
        losses = dict()

        heatmaps_losses, push_losses, pull_losses = self.loss(
            outputs, targets, masks, joints)

        per_scale_losses = (
            ('heatmap_loss', heatmaps_losses),
            ('push_loss', push_losses),
            ('pull_loss', pull_losses),
        )

        for idx in range(len(targets)):
            for name, values in per_scale_losses:
                value = values[idx]
                if value is None:
                    # This loss term is disabled for the current scale.
                    continue
                # Average over the first dim (per-sample losses per the
                # shapes documented above).
                value = value.mean(dim=0)
                # Out-of-place add: never mutates a stored tensor.
                losses[name] = losses[name] + value \
                    if name in losses else value

        return losses

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor | list | tuple): Input features. For a list or
                tuple only the first element is used.

        Returns:
            list[torch.Tensor]: One prediction per final layer, i.e. the
            prediction on the input feature map followed by the prediction
            after every deconv stage.
        """
        if isinstance(x, (list, tuple)):
            x = x[0]

        y = self.final_layers[0](x)
        final_outputs = [y]

        for i in range(self.num_deconvs):
            if self.cat_output[i]:
                # Feed the previous prediction to the next stage as well.
                x = torch.cat((x, y), 1)

            x = self.deconv_layers[i](x)
            y = self.final_layers[i + 1](x)
            final_outputs.append(y)

        return final_outputs

    def init_weights(self):
        """Initialize model weights.

        Transposed convs and final convs are drawn from a normal distribution
        with std 0.001; BatchNorm layers of the deconv stages are set to 1.
        Other modules inside the deconv stages keep their current weights.
        """
        for m in self.deconv_layers.modules():
            if isinstance(m, nn.ConvTranspose2d):
                normal_init(m, std=0.001)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
        for m in self.final_layers.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)
|