# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
import torch.nn as nn | |
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer, | |
constant_init, normal_init) | |
from mmpose.models.builder import HEADS, build_loss | |
from mmpose.models.utils.ops import resize | |
class DeconvHead(nn.Module):
    """Simple deconv head.

    Selects/transforms the input feature map(s), upsamples them with a
    stack of deconv (transposed-conv) layers and produces heatmaps with
    a final conv layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_deconv_layers (int): Number of deconv layers.
            num_deconv_layers should >= 0. Note that 0 means
            no deconv layers.
        num_deconv_filters (list|tuple): Number of filters.
            If num_deconv_layers > 0, the length of num_deconv_filters
            should equal num_deconv_layers.
        num_deconv_kernels (list|tuple): Kernel sizes. If
            num_deconv_layers > 0, the length should equal
            num_deconv_layers. Each kernel must be 2, 3 or 4.
        extra (dict|None): Extra layer config. Recognized keys:
            ``final_conv_kernel`` (int): 0, 1 or 3. 0 makes the final
            layer an identity mapping; 1 and 3 set the final conv
            kernel size. ``num_conv_layers`` (int) and
            ``num_conv_kernels`` (list|tuple): optional intermediate
            Conv-BN-ReLU layers inserted before the final conv.
            Default: None.
        in_index (int|Sequence[int]): Input feature index. Default: 0
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            Default: None.

            - 'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concat together.
                Usually used in FCN head of HRNet.
            - 'multiple_select': Multiple feature maps will be bundle into
                a list and passed into decode head.
            - None: Only one select feature map is allowed.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        loss_keypoint (dict): Config for loss. Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=17,
                 num_deconv_layers=3,
                 num_deconv_filters=(256, 256, 256),
                 num_deconv_kernels=(4, 4, 4),
                 extra=None,
                 in_index=0,
                 input_transform=None,
                 align_corners=False,
                 loss_keypoint=None):
        super().__init__()

        self.in_channels = in_channels
        self.loss = build_loss(loss_keypoint)

        self._init_inputs(in_channels, in_index, input_transform)
        self.in_index = in_index
        self.align_corners = align_corners

        if extra is not None and not isinstance(extra, dict):
            raise TypeError('extra should be dict or None.')

        if num_deconv_layers > 0:
            self.deconv_layers = self._make_deconv_layer(
                num_deconv_layers,
                num_deconv_filters,
                num_deconv_kernels,
            )
        elif num_deconv_layers == 0:
            self.deconv_layers = nn.Identity()
        else:
            raise ValueError(
                f'num_deconv_layers ({num_deconv_layers}) should >= 0.')

        identity_final_layer = False
        if extra is not None and 'final_conv_kernel' in extra:
            assert extra['final_conv_kernel'] in [0, 1, 3]
            if extra['final_conv_kernel'] == 3:
                padding = 1
            elif extra['final_conv_kernel'] == 1:
                padding = 0
            else:
                # 0 for Identity mapping.
                identity_final_layer = True
            kernel_size = extra['final_conv_kernel']
        else:
            kernel_size = 1
            padding = 0

        if identity_final_layer:
            self.final_layer = nn.Identity()
        else:
            # After the deconv stack, channels equal the last filter count;
            # with no deconv layers the input channels pass through unchanged.
            conv_channels = num_deconv_filters[
                -1] if num_deconv_layers > 0 else self.in_channels

            layers = []
            if extra is not None:
                num_conv_layers = extra.get('num_conv_layers', 0)
                num_conv_kernels = extra.get('num_conv_kernels',
                                             [1] * num_conv_layers)

                # Optional intermediate Conv-BN-ReLU layers before the
                # final prediction conv.
                for i in range(num_conv_layers):
                    layers.append(
                        build_conv_layer(
                            dict(type='Conv2d'),
                            in_channels=conv_channels,
                            out_channels=conv_channels,
                            kernel_size=num_conv_kernels[i],
                            stride=1,
                            padding=(num_conv_kernels[i] - 1) // 2))
                    layers.append(
                        build_norm_layer(dict(type='BN'), conv_channels)[1])
                    layers.append(nn.ReLU(inplace=True))

            layers.append(
                build_conv_layer(
                    cfg=dict(type='Conv2d'),
                    in_channels=conv_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding))

            if len(layers) > 1:
                self.final_layer = nn.Sequential(*layers)
            else:
                self.final_layer = layers[0]

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only single feature map
        will be selected. So in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        list or tuple, with the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.

                - 'resize_concat': Multiple feature maps will be resize to the
                    same size as first one and than concat together.
                    Usually used in FCN head of HRNet.
                - 'multiple_select': Multiple feature maps will be bundle into
                    a list and passed into decode head.
                - None: Only one select feature map is allowed.
        """
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                # Selected maps are concatenated channel-wise.
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor] | Tensor): multi-level img features.

        Returns:
            Tensor | list[Tensor]: The transformed inputs.
        """
        if not isinstance(inputs, list):
            # A single tensor is passed through unchanged.
            return inputs

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Make deconv layers.

        Args:
            num_layers (int): Number of deconv layers.
            num_filters (list|tuple): Output channels per layer; its length
                must equal ``num_layers``.
            num_kernels (list|tuple): Kernel size per layer; its length
                must equal ``num_layers``.

        Returns:
            nn.Sequential: The stacked Deconv-BN-ReLU layers.

        Raises:
            ValueError: If the list lengths do not match ``num_layers``.
        """
        if num_layers != len(num_filters):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_filters({len(num_filters)})'
            raise ValueError(error_msg)
        if num_layers != len(num_kernels):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_kernels({len(num_kernels)})'
            raise ValueError(error_msg)

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers.append(
                build_upsample_layer(
                    dict(type='deconv'),
                    in_channels=self.in_channels,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=False))
            layers.append(nn.BatchNorm2d(planes))
            layers.append(nn.ReLU(inplace=True))
            # Next layer's input channels are this layer's output channels.
            self.in_channels = planes

        return nn.Sequential(*layers)

    @staticmethod
    def _get_deconv_cfg(deconv_kernel):
        """Get configurations for deconv layers.

        Note: this must be a staticmethod — it is called as
        ``self._get_deconv_cfg(...)`` with a single kernel argument;
        without the decorator the bound call would pass ``self`` as
        ``deconv_kernel`` and raise ``TypeError``.

        Args:
            deconv_kernel (int): Deconv kernel size; must be 2, 3 or 4.

        Returns:
            tuple[int, int, int]: ``(kernel_size, padding, output_padding)``
            so that each deconv layer exactly doubles the spatial size
            at stride 2.

        Raises:
            ValueError: If ``deconv_kernel`` is not 2, 3 or 4.
        """
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

        return deconv_kernel, padding, output_padding

    def get_loss(self, outputs, targets, masks):
        """Calculate bottom-up masked mse loss.

        Note:
            - batch_size: N
            - num_channels: C
            - heatmaps height: H
            - heatmaps weight: W

        Args:
            outputs (List(torch.Tensor[N,C,H,W])): Multi-scale outputs.
            targets (List(torch.Tensor[N,C,H,W])): Multi-scale targets.
            masks (List(torch.Tensor[N,H,W])): Masks of multi-scale targets.

        Returns:
            dict: ``{'loss': ...}`` with the loss summed over all scales.
        """
        losses = dict()

        # Sum the per-scale losses into a single 'loss' entry.
        for idx in range(len(targets)):
            if 'loss' not in losses:
                losses['loss'] = self.loss(outputs[idx], targets[idx],
                                           masks[idx])
            else:
                losses['loss'] += self.loss(outputs[idx], targets[idx],
                                            masks[idx])

        return losses

    def forward(self, x):
        """Forward function.

        Args:
            x (list[Tensor] | Tensor): Input feature map(s).

        Returns:
            list[Tensor]: A single-element list with the predicted heatmaps.
        """
        x = self._transform_inputs(x)
        final_outputs = []
        x = self.deconv_layers(x)
        y = self.final_layer(x)
        final_outputs.append(y)
        return final_outputs

    def init_weights(self):
        """Initialize model weights.

        Deconv and conv weights get a normal init (std=0.001); BatchNorm
        weights are set to 1.
        """
        for _, m in self.deconv_layers.named_modules():
            if isinstance(m, nn.ConvTranspose2d):
                normal_init(m, std=0.001)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
        for m in self.final_layer.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)