# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch.nn as nn
from mmcv.cnn import (ConvModule, build_conv_layer, constant_init,
                      kaiming_init)
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmpose.core import WeightNormClipHook
from ..builder import BACKBONES
from .base_backbone import BaseBackbone


class BasicTemporalBlock(nn.Module):
    """Basic block for VideoPose3D.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        mid_channels (int): The output channels of conv1. Default: 1024.
        kernel_size (int): Size of the convolving kernel. Default: 3.
        dilation (int): Spacing between kernel elements. Default: 3.
        dropout (float): Dropout rate. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications). Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use optimized TCN that is designed
            specifically for single-frame batching, i.e. where batches have
            input length = receptive field, and output length = 1. This
            implementation replaces dilated convolutions with strided
            convolutions to avoid generating unused intermediate results.
            Default: False.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=1024,
                 kernel_size=3,
                 dilation=3,
                 dropout=0.25,
                 causal=False,
                 residual=True,
                 use_stride_conv=False,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d')):
        # Protect mutable default arguments
        conv_cfg = copy.deepcopy(conv_cfg)
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.dropout = dropout
        self.causal = causal
        self.residual = residual
        self.use_stride_conv = use_stride_conv

        self.pad = (kernel_size - 1) * dilation // 2
        if use_stride_conv:
            self.stride = kernel_size
            self.causal_shift = kernel_size // 2 if causal else 0
            self.dilation = 1
        else:
            self.stride = 1
            self.causal_shift = kernel_size // 2 * dilation if causal else 0

        self.conv1 = nn.Sequential(
            ConvModule(
                in_channels,
                mid_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                dilation=self.dilation,
                bias='auto',
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg))
        self.conv2 = nn.Sequential(
            ConvModule(
                mid_channels,
                out_channels,
                kernel_size=1,
                bias='auto',
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg))

        if residual and in_channels != out_channels:
            self.short_cut = build_conv_layer(conv_cfg, in_channels,
                                              out_channels, 1)
        else:
            self.short_cut = None

        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x):
        """Forward function."""
        if self.use_stride_conv:
            assert self.causal_shift + self.kernel_size // 2 < x.shape[2]
        else:
            assert 0 <= self.pad + self.causal_shift < x.shape[2] - \
                self.pad + self.causal_shift <= x.shape[2]

        out = self.conv1(x)
        if self.dropout is not None:
            out = self.dropout(out)

        out = self.conv2(out)
        if self.dropout is not None:
            out = self.dropout(out)

        if self.residual:
            if self.use_stride_conv:
                res = x[:, :, self.causal_shift +
                        self.kernel_size // 2::self.kernel_size]
            else:
                res = x[:, :,
                        (self.pad + self.causal_shift):(x.shape[2] - self.pad +
                                                        self.causal_shift)]

            if self.short_cut is not None:
                res = self.short_cut(res)
            out = out + res

        return out
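

# A minimal shape sketch for ``BasicTemporalBlock`` (hypothetical usage, in
# the same doctest spirit as ``TCN`` below): with the default kernel_size=3
# and dilation=3, conv1 is unpadded and consumes (3 - 1) * 3 = 6 frames,
# while the residual branch is cropped by pad = (3 - 1) * 3 // 2 = 3 frames
# on each side to match:
#
#     block = BasicTemporalBlock(in_channels=1024, out_channels=1024)
#     block.eval()
#     out = block(torch.rand(1, 1024, 27))  # -> torch.Size([1, 1024, 21])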


@BACKBONES.register_module()
class TCN(BaseBackbone):
    """TCN backbone.

    Temporal Convolutional Networks.
    More details can be found in the
    `paper <https://arxiv.org/abs/1811.11742>`__ .

    Args:
        in_channels (int): Number of input channels, which equals
            num_keypoints * num_features.
        stem_channels (int): Number of feature channels. Default: 1024.
        num_blocks (int): Number of basic temporal convolutional blocks.
            Default: 2.
        kernel_sizes (Sequence[int]): Sizes of the convolving kernel of
            each basic block. Default: ``(3, 3, 3)``.
        dropout (float): Dropout rate. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications). Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use TCN backbone optimized for
            single-frame batching, i.e. where batches have input length =
            receptive field, and output length = 1. This implementation
            replaces dilated convolutions with strided convolutions to avoid
            generating unused intermediate results. The weights are
            interchangeable with the reference implementation. Default: False.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
        max_norm (float|None): If not None, the weight of convolution layers
            will be clipped to have a maximum norm of max_norm.

    Example:
        >>> from mmpose.models import TCN
        >>> import torch
        >>> self = TCN(in_channels=34)
        >>> self.eval()
        >>> inputs = torch.rand(1, 34, 243)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 1024, 235)
        (1, 1024, 217)
    """

    def __init__(self,
                 in_channels,
                 stem_channels=1024,
                 num_blocks=2,
                 kernel_sizes=(3, 3, 3),
                 dropout=0.25,
                 causal=False,
                 residual=True,
                 use_stride_conv=False,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 max_norm=None):
        # Protect mutable default arguments
        conv_cfg = copy.deepcopy(conv_cfg)
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        self.in_channels = in_channels
        self.stem_channels = stem_channels
        self.num_blocks = num_blocks
        self.kernel_sizes = kernel_sizes
        self.dropout = dropout
        self.causal = causal
        self.residual = residual
        self.use_stride_conv = use_stride_conv
        self.max_norm = max_norm

        assert num_blocks == len(kernel_sizes) - 1
        for ks in kernel_sizes:
            assert ks % 2 == 1, 'Only odd filter widths are supported.'
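
        # In the default (non-strided) mode, the temporal receptive field is
        # the product of all kernel sizes: with kernel_sizes=(3, 3, 3) it is
        # 3 * 3 * 3 = 27 frames. expand_conv trims 3 - 1 = 2 frames and the
        # two blocks (dilations 3 and 9) trim 6 and 18 frames, matching the
        # docstring example above (243 -> 235 -> 217).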
        self.expand_conv = ConvModule(
            in_channels,
            stem_channels,
            kernel_size=kernel_sizes[0],
            stride=kernel_sizes[0] if use_stride_conv else 1,
            bias='auto',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        dilation = kernel_sizes[0]
        self.tcn_blocks = nn.ModuleList()
        for i in range(1, num_blocks + 1):
            self.tcn_blocks.append(
                BasicTemporalBlock(
                    in_channels=stem_channels,
                    out_channels=stem_channels,
                    mid_channels=stem_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilation,
                    dropout=dropout,
                    causal=causal,
                    residual=residual,
                    use_stride_conv=use_stride_conv,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg))
            dilation *= kernel_sizes[i]

        if self.max_norm is not None:
            # Apply weight norm clip to conv layers
            weight_clip = WeightNormClipHook(self.max_norm)
            for module in self.modules():
                if isinstance(module, nn.modules.conv._ConvNd):
                    weight_clip.register(module)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x):
        """Forward function."""
        x = self.expand_conv(x)

        if self.dropout is not None:
            x = self.dropout(x)

        outs = []
        for i in range(self.num_blocks):
            x = self.tcn_blocks[i](x)
            outs.append(x)

        return tuple(outs)

    def init_weights(self, pretrained=None):
        """Initialize the weights."""
        super().init_weights(pretrained)
        if pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.modules.conv._ConvNd):
                    kaiming_init(m, mode='fan_in', nonlinearity='relu')
                elif isinstance(m, _BatchNorm):
                    constant_init(m, 1)
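

# A minimal, hypothetical smoke test (a sketch, not part of the mmpose API):
# in use_stride_conv mode the input length should equal the receptive field,
# and every layer downsamples the temporal axis by its kernel size, so the
# final block emits a single frame. Run with
# ``python -m mmpose.models.backbones.tcn`` so the relative imports resolve.
if __name__ == '__main__':
    import torch

    model = TCN(in_channels=34, use_stride_conv=True)
    model.eval()
    inputs = torch.rand(1, 34, 27)  # receptive field of (3, 3, 3) is 27
    for level_out in model(inputs):
        print(tuple(level_out.shape))  # (1, 1024, 3), then (1, 1024, 1)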