import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp

import mmcv
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
                      build_conv_layer, build_norm_layer, constant_init,
                      normal_init)
from torch.nn.modules.batchnorm import _BatchNorm

from mmpose.utils import get_root_logger
from mmpose.models.registry import BACKBONES
from mmpose.models.backbones.resnet import BasicBlock, Bottleneck
# `channel_shuffle` is defined locally below, so only `load_checkpoint` is
# imported here to avoid shadowing the import with the local definition.
from mmpose.models.backbones.utils import load_checkpoint


def channel_shuffle(x, groups):
    """Channel Shuffle operation.

    This function enables cross-group information flow for multiple groups
    convolution layers.

    Args:
        x (Tensor): The input tensor.
        groups (int): The number of groups to divide the input tensor in the
            channel dimension.

    Returns:
        Tensor: The output tensor after channel shuffle operation.
    """
    batch_size, num_channels, height, width = x.size()
    assert (num_channels % groups == 0), ('num_channels should be '
                                          'divisible by groups')
    channels_per_group = num_channels // groups

    x = x.view(batch_size, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(batch_size, groups * channels_per_group, height, width)

    return x


class SpatialWeighting(nn.Module):
    """Squeeze-and-excitation style weighting: channels are re-weighted from
    globally average-pooled spatial information."""

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
        super().__init__()
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=int(channels / ratio),
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[0])
        self.conv2 = ConvModule(
            in_channels=int(channels / ratio),
            out_channels=channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x):
        out = self.global_avgpool(x)
        out = self.conv1(out)
        out = self.conv2(out)
        return x * out


class CrossResolutionWeighting(nn.Module):
    """Channel weighting computed jointly across all resolution branches:
    every branch is pooled to the smallest resolution, the pooled features are
    squeezed and expanded, and the resulting weights are upsampled back to
    each branch."""

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
        super().__init__()
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)
        self.channels = channels
        total_channel = sum(channels)
        self.conv1 = ConvModule(
            in_channels=total_channel,
            out_channels=int(total_channel / ratio),
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[0])
        self.conv2 = ConvModule(
            in_channels=int(total_channel / ratio),
            out_channels=total_channel,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x):
        mini_size = x[-1].size()[-2:]
        out = [F.adaptive_avg_pool2d(s, mini_size) for s in x[:-1]] + [x[-1]]
        out = torch.cat(out, dim=1)
        out = self.conv1(out)
        out = self.conv2(out)
        out = torch.split(out, self.channels, dim=1)
        out = [
            s * F.interpolate(a, size=s.size()[-2:], mode='nearest')
            for s, a in zip(x, out)
        ]
        return out

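
# Illustrative sketch (not part of the original file): a minimal shape
# walk-through of CrossResolutionWeighting with two hypothetical branches.
# The channel counts, spatial sizes and reduction ratio are assumptions made
# for the example only; call the function manually to run the check.
def _demo_cross_resolution_weighting():
    # Two branches: high resolution with few channels, low resolution with
    # more channels, mirroring how Lite-HRNet stages are laid out.
    branches = [torch.rand(1, 8, 16, 16), torch.rand(1, 16, 8, 8)]
    crw = CrossResolutionWeighting(channels=[8, 16], ratio=4)
    outs = crw(branches)
    # The weighting is multiplicative, so every branch keeps its shape.
    assert [tuple(o.shape) for o in outs] == [(1, 8, 16, 16), (1, 16, 8, 8)]
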

class ConditionalChannelWeighting(nn.Module):
    """Lite-HRNet building block: each branch is split channel-wise, one half
    is processed with cross-resolution weighting, a depthwise convolution and
    spatial weighting, and the halves are re-joined with a channel shuffle."""

    def __init__(self,
                 in_channels,
                 stride,
                 reduce_ratio,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.stride = stride
        assert stride in [1, 2]

        branch_channels = [channel // 2 for channel in in_channels]

        self.cross_resolution_weighting = CrossResolutionWeighting(
            branch_channels,
            ratio=reduce_ratio,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        self.depthwise_convs = nn.ModuleList([
            ConvModule(
                channel,
                channel,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=channel,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None) for channel in branch_channels
        ])

        self.spatial_weighting = nn.ModuleList([
            SpatialWeighting(channels=channel, ratio=4)
            for channel in branch_channels
        ])

    def forward(self, x):

        def _inner_forward(x):
            x = [s.chunk(2, dim=1) for s in x]
            x1 = [s[0] for s in x]
            x2 = [s[1] for s in x]

            x2 = self.cross_resolution_weighting(x2)
            x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)]
            x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)]

            out = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(x1, x2)]
            out = [channel_shuffle(s, 2) for s in out]

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out


class Stem(nn.Module):
    """Stem of Lite-HRNet: a stride-2 convolution followed by a
    ShuffleNetV2-style split/merge block, reducing the input resolution by a
    factor of 4 overall."""

    def __init__(self,
                 in_channels,
                 stem_channels,
                 out_channels,
                 expand_ratio,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 with_cp=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp

        self.conv1 = ConvModule(
            in_channels=in_channels,
            out_channels=stem_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='ReLU'))

        mid_channels = int(round(stem_channels * expand_ratio))
        branch_channels = stem_channels // 2
        if stem_channels == self.out_channels:
            inc_channels = self.out_channels - branch_channels
        else:
            inc_channels = self.out_channels - stem_channels

        self.branch1 = nn.Sequential(
            ConvModule(
                branch_channels,
                branch_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=branch_channels,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            ConvModule(
                branch_channels,
                inc_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU')),
        )

        self.expand_conv = ConvModule(
            branch_channels,
            mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))
        self.depthwise_conv = ConvModule(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=mid_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.linear_conv = ConvModule(
            mid_channels,
            branch_channels
            if stem_channels == self.out_channels else stem_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))

    def forward(self, x):

        def _inner_forward(x):
            x = self.conv1(x)
            x1, x2 = x.chunk(2, dim=1)

            x2 = self.expand_conv(x2)
            x2 = self.depthwise_conv(x2)
            x2 = self.linear_conv(x2)

            out = torch.cat((self.branch1(x1), x2), dim=1)
            out = channel_shuffle(out, 2)

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out


class IterativeHead(nn.Module):
    """Iterative aggregation head: low-resolution features are progressively
    upsampled, added to the next higher-resolution branch and refined with
    depthwise separable convolutions."""

    def __init__(self, in_channels, conv_cfg=None, norm_cfg=dict(type='BN')):
        super().__init__()
        projects = []
        num_branches = len(in_channels)
        self.in_channels = in_channels[::-1]

        for i in range(num_branches):
            if i != num_branches - 1:
                projects.append(
                    DepthwiseSeparableConvModule(
                        in_channels=self.in_channels[i],
                        out_channels=self.in_channels[i + 1],
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        norm_cfg=norm_cfg,
                        act_cfg=dict(type='ReLU'),
                        dw_act_cfg=None,
                        pw_act_cfg=dict(type='ReLU')))
            else:
                projects.append(
                    DepthwiseSeparableConvModule(
                        in_channels=self.in_channels[i],
                        out_channels=self.in_channels[i],
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        norm_cfg=norm_cfg,
                        act_cfg=dict(type='ReLU'),
                        dw_act_cfg=None,
                        pw_act_cfg=dict(type='ReLU')))
        self.projects = nn.ModuleList(projects)

    def forward(self, x):
        x = x[::-1]

        y = []
        last_x = None
        for i, s in enumerate(x):
            if last_x is not None:
                last_x = F.interpolate(
                    last_x,
                    size=s.size()[-2:],
                    mode='bilinear',
                    align_corners=True)
                s = s + last_x
            s = self.projects[i](s)
            y.append(s)
            last_x = s

        return y[::-1]

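
# Illustrative sketch (not part of the original file): the Stem halves the
# resolution twice (a stride-2 conv, then stride-2 depthwise branches), so a
# 64x64 input comes out at 16x16. The channel counts below are assumptions
# chosen for the example; call the function manually to run the check.
def _demo_stem():
    stem = Stem(in_channels=3, stem_channels=32, out_channels=32,
                expand_ratio=1)
    out = stem(torch.rand(1, 3, 64, 64))
    assert tuple(out.shape) == (1, 32, 16, 16)
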

class ShuffleUnit(nn.Module):
    """InvertedResidual block for ShuffleNetV2 backbone.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): Stride of the 3x3 convolution layer. Default: 1
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        super().__init__()
        self.stride = stride
        self.with_cp = with_cp

        branch_features = out_channels // 2
        if self.stride == 1:
            assert in_channels == branch_features * 2, (
                f'in_channels ({in_channels}) should equal to '
                f'branch_features * 2 ({branch_features * 2}) '
                'when stride is 1')

        if in_channels != branch_features * 2:
            assert self.stride != 1, (
                f'stride ({self.stride}) should not equal 1 when '
                f'in_channels != branch_features * 2')

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                ConvModule(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=self.stride,
                    padding=1,
                    groups=in_channels,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=None),
                ConvModule(
                    in_channels,
                    branch_features,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg),
            )

        self.branch2 = nn.Sequential(
            ConvModule(
                in_channels if (self.stride > 1) else branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=branch_features,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))

    def forward(self, x):

        def _inner_forward(x):
            if self.stride > 1:
                out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
            else:
                x1, x2 = x.chunk(2, dim=1)
                out = torch.cat((x1, self.branch2(x2)), dim=1)

            out = channel_shuffle(out, 2)

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out

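
# Illustrative sketch (not part of the original file): ShuffleUnit with
# stride=1 keeps the input shape, while stride=2 halves the spatial size and
# lets the output channel count differ from the input. All sizes are
# assumptions made for the example; call the function manually to run it.
def _demo_shuffle_unit():
    x = torch.rand(1, 32, 16, 16)
    keep = ShuffleUnit(32, 32, stride=1)
    down = ShuffleUnit(32, 64, stride=2)
    assert tuple(keep(x).shape) == (1, 32, 16, 16)
    assert tuple(down(x).shape) == (1, 64, 8, 8)
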

class LiteHRModule(nn.Module):
    """A single Lite-HRNet module holding several parallel branches and an
    optional fuse layer that exchanges information across resolutions.

    ``module_type`` selects between ``'LITE'`` (conditional channel weighting
    blocks shared across branches) and ``'NAIVE'`` (ShuffleNetV2 units applied
    per branch)."""

    def __init__(
        self,
        num_branches,
        num_blocks,
        in_channels,
        reduce_ratio,
        module_type,
        multiscale_output=False,
        with_fuse=True,
        conv_cfg=None,
        norm_cfg=dict(type='BN'),
        with_cp=False,
    ):
        super().__init__()
        self._check_branches(num_branches, in_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches

        self.module_type = module_type
        self.multiscale_output = multiscale_output
        self.with_fuse = with_fuse
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp

        if self.module_type == 'LITE':
            self.layers = self._make_weighting_blocks(num_blocks,
                                                      reduce_ratio)
        elif self.module_type == 'NAIVE':
            self.layers = self._make_naive_branches(num_branches, num_blocks)
        if self.with_fuse:
            self.fuse_layers = self._make_fuse_layers()
            self.relu = nn.ReLU()

    def _check_branches(self, num_branches, in_channels):
        """Check input to avoid ValueError."""
        if num_branches != len(in_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                f'!= NUM_INCHANNELS({len(in_channels)})'
            raise ValueError(error_msg)

    def _make_weighting_blocks(self, num_blocks, reduce_ratio, stride=1):
        layers = []
        for i in range(num_blocks):
            layers.append(
                ConditionalChannelWeighting(
                    self.in_channels,
                    stride=stride,
                    reduce_ratio=reduce_ratio,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    with_cp=self.with_cp))

        return nn.Sequential(*layers)

    def _make_one_branch(self, branch_index, num_blocks, stride=1):
        """Make one branch."""
        layers = []
        layers.append(
            ShuffleUnit(
                self.in_channels[branch_index],
                self.in_channels[branch_index],
                stride=stride,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type='ReLU'),
                with_cp=self.with_cp))
        for i in range(1, num_blocks):
            layers.append(
                ShuffleUnit(
                    self.in_channels[branch_index],
                    self.in_channels[branch_index],
                    stride=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=dict(type='ReLU'),
                    with_cp=self.with_cp))

        return nn.Sequential(*layers)

    def _make_naive_branches(self, num_branches, num_blocks):
        """Make branches."""
        branches = []
        for i in range(num_branches):
            branches.append(self._make_one_branch(i, num_blocks))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Make fuse layer."""
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        num_out_branches = num_branches if self.multiscale_output else 1
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels[j],
                                in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             in_channels[i])[1],
                            nn.Upsample(
                                scale_factor=2**(j - i), mode='nearest')))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        groups=in_channels[j],
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[i],
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[i])[1]))
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        groups=in_channels[j],
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    nn.ReLU(inplace=True)))
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function."""
        if self.num_branches == 1:
            return [self.layers[0](x[0])]

        if self.module_type == 'LITE':
            out = self.layers(x)
        elif self.module_type == 'NAIVE':
            for i in range(self.num_branches):
                x[i] = self.layers[i](x[i])
            out = x

        if self.with_fuse:
            out_fuse = []
            for i in range(len(self.fuse_layers)):
                y = out[0] if i == 0 else self.fuse_layers[i][0](out[0])
                for j in range(self.num_branches):
                    if i == j:
                        y += out[j]
                    else:
                        y += self.fuse_layers[i][j](out[j])
                out_fuse.append(self.relu(y))
            out = out_fuse
        elif not self.multiscale_output:
            out = [out[0]]

        return out

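
# Illustrative sketch (not part of the original file): a two-branch 'LITE'
# module keeps the per-branch shapes while fusing information across
# resolutions. All sizes below are assumptions made for the example; call the
# function manually to run the check.
def _demo_lite_hr_module():
    module = LiteHRModule(
        num_branches=2,
        num_blocks=2,
        in_channels=[40, 80],
        reduce_ratio=8,
        module_type='LITE',
        multiscale_output=True,
        with_fuse=True)
    feats = [torch.rand(1, 40, 16, 16), torch.rand(1, 80, 8, 8)]
    outs = module(feats)
    assert [tuple(o.shape) for o in outs] == [(1, 40, 16, 16), (1, 80, 8, 8)]
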

@BACKBONES.register_module()
class LiteHRNet(nn.Module):
    """Lite-HRNet backbone.

    `Lite-HRNet: A Lightweight High-Resolution Network
    <https://arxiv.org/abs/2104.06403>`_

    Args:
        extra (dict): Detailed configuration for each stage of HRNet.
        in_channels (int): Number of input image channels. Default: 3.
        conv_cfg (dict): Dictionary to construct and config conv layer.
        norm_cfg (dict): Dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
        zero_init_residual (bool): Whether to use zero init for last norm
            layer in resblocks to let them behave as identity.

    Example:
        >>> from mmpose.models import LiteHRNet
        >>> import torch
        >>> extra = dict(
        >>>     stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
        >>>     num_stages=3,
        >>>     stages_spec=dict(
        >>>         num_modules=(2, 4, 2),
        >>>         num_branches=(2, 3, 4),
        >>>         num_blocks=(2, 2, 2),
        >>>         module_type=('LITE', 'LITE', 'LITE'),
        >>>         with_fuse=(True, True, True),
        >>>         reduce_ratios=(8, 8, 8),
        >>>         num_channels=(
        >>>             (40, 80),
        >>>             (40, 80, 160),
        >>>             (40, 80, 160, 320),
        >>>         )),
        >>>     with_head=False)
        >>> self = LiteHRNet(extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 40, 8, 8)
    """

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=False):
        super().__init__()
        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        self.stem = Stem(
            in_channels,
            stem_channels=self.extra['stem']['stem_channels'],
            out_channels=self.extra['stem']['out_channels'],
            expand_ratio=self.extra['stem']['expand_ratio'],
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)

        self.num_stages = self.extra['num_stages']
        self.stages_spec = self.extra['stages_spec']

        num_channels_last = [
            self.stem.out_channels,
        ]
        for i in range(self.num_stages):
            num_channels = list(self.stages_spec['num_channels'][i])
            setattr(
                self, 'transition{}'.format(i),
                self._make_transition_layer(num_channels_last, num_channels))

            stage, num_channels_last = self._make_stage(
                self.stages_spec, i, num_channels, multiscale_output=True)
            setattr(self, 'stage{}'.format(i), stage)

        self.with_head = self.extra['with_head']
        if self.with_head:
            self.head_layer = IterativeHead(
                in_channels=num_channels_last,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
            )
    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Make transition layer."""
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_pre_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                groups=num_channels_pre_layer[i],
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_pre_layer[i])[1],
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_cur_layer[i])[1],
                            nn.ReLU()))
                else:
                    transition_layers.append(None)
            else:
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    out_channels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                in_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                groups=in_channels,
                                bias=False),
                            build_norm_layer(self.norm_cfg, in_channels)[1],
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU()))
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)

    def _make_stage(self,
                    stages_spec,
                    stage_index,
                    in_channels,
                    multiscale_output=True):
        num_modules = stages_spec['num_modules'][stage_index]
        num_branches = stages_spec['num_branches'][stage_index]
        num_blocks = stages_spec['num_blocks'][stage_index]
        reduce_ratio = stages_spec['reduce_ratios'][stage_index]
        with_fuse = stages_spec['with_fuse'][stage_index]
        module_type = stages_spec['module_type'][stage_index]

        modules = []
        for i in range(num_modules):
            # multiscale_output is only used by the last module in the stage
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True

            modules.append(
                LiteHRModule(
                    num_branches,
                    num_blocks,
                    in_channels,
                    reduce_ratio,
                    module_type,
                    multiscale_output=reset_multiscale_output,
                    with_fuse=with_fuse,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    with_cp=self.with_cp))
            in_channels = modules[-1].in_channels

        return nn.Sequential(*modules), in_channels

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.001)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function."""
        x = self.stem(x)

        y_list = [x]
        for i in range(self.num_stages):
            x_list = []
            transition = getattr(self, 'transition{}'.format(i))
            for j in range(self.stages_spec['num_branches'][i]):
                if transition[j]:
                    if j >= len(y_list):
                        x_list.append(transition[j](y_list[-1]))
                    else:
                        x_list.append(transition[j](y_list[j]))
                else:
                    x_list.append(y_list[j])
            y_list = getattr(self, 'stage{}'.format(i))(x_list)

        x = y_list
        if self.with_head:
            x = self.head_layer(x)

        return [x[0]]

    def train(self, mode=True):
        """Convert the model into training mode."""
        super().train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
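
# Illustrative end-to-end sketch (not part of the original file): build a
# Lite-HRNet-style backbone and run a forward pass. The config values and the
# input size below are assumptions chosen for the example only.
if __name__ == '__main__':
    extra = dict(
        stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
        num_stages=3,
        stages_spec=dict(
            num_modules=(2, 4, 2),
            num_branches=(2, 3, 4),
            num_blocks=(2, 2, 2),
            module_type=('LITE', 'LITE', 'LITE'),
            with_fuse=(True, True, True),
            reduce_ratios=(8, 8, 8),
            num_channels=(
                (40, 80),
                (40, 80, 160),
                (40, 80, 160, 320),
            )),
        with_head=True)
    model = LiteHRNet(extra, in_channels=3)
    model.init_weights()
    model.eval()
    with torch.no_grad():
        outputs = model(torch.rand(1, 3, 256, 192))
    # With the iterative head enabled, a single high-resolution feature map
    # at 1/4 of the input size is returned: (1, 40, 64, 48).
    print(tuple(outputs[0].shape))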