import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import CfgNode
from detectron2.layers import Conv2d

from .registry import ROI_DENSEPOSE_HEAD_REGISTRY


@ROI_DENSEPOSE_HEAD_REGISTRY.register()
class DensePoseDeepLabHead(nn.Module):
    """
    DensePose head using the DeepLabV3 model from
    "Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>.
    """

    def __init__(self, cfg: CfgNode, input_channels: int):
        super(DensePoseDeepLabHead, self).__init__()

        hidden_dim = cfg.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM
        kernel_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL
        norm = cfg.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NORM
        self.n_stacked_convs = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS
        self.use_nonlocal = cfg.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NONLOCAL_ON

        pad_size = kernel_size // 2
        n_channels = input_channels

        # ASPP over the incoming ROI features; the output keeps the input channel count
        self.ASPP = ASPP(input_channels, [6, 12, 56], n_channels)
        self.add_module("ASPP", self.ASPP)

        if self.use_nonlocal:
            self.NLBlock = NONLocalBlock2D(input_channels, bn_layer=True)
            self.add_module("NLBlock", self.NLBlock)

        # Stack of convolutions (optionally with GroupNorm) applied after the ASPP / non-local blocks
        for i in range(self.n_stacked_convs):
            norm_module = nn.GroupNorm(32, hidden_dim) if norm == "GN" else None
            layer = Conv2d(
                n_channels,
                hidden_dim,
                kernel_size,
                stride=1,
                padding=pad_size,
                bias=not norm,
                norm=norm_module,
            )
            weight_init.c2_msra_fill(layer)
            n_channels = hidden_dim
            layer_name = self._get_layer_name(i)
            self.add_module(layer_name, layer)
        self.n_out_channels = hidden_dim

    def forward(self, features):
        x0 = features
        x = self.ASPP(x0)
        if self.use_nonlocal:
            x = self.NLBlock(x)
        output = x
        for i in range(self.n_stacked_convs):
            layer_name = self._get_layer_name(i)
            x = getattr(self, layer_name)(x)
            x = F.relu(x)
            output = x
        return output

    def _get_layer_name(self, i: int):
        layer_name = "body_conv_fcn{}".format(i + 1)
        return layer_name
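

# The ASPP (Atrous Spatial Pyramid Pooling) modules below follow the DeepLabV3 design
# cited in the head's docstring (https://arxiv.org/abs/1706.05587): parallel 3x3 atrous
# convolutions at several dilation rates, a 1x1 branch, and a global-average-pooling
# branch, concatenated and projected back to out_channels.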
class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(
                in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False
            ),
            nn.GroupNorm(32, out_channels),
            nn.ReLU(),
        ]
        super(ASPPConv, self).__init__(*modules)


class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.GroupNorm(32, out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        size = x.shape[-2:]
        x = super(ASPPPooling, self).forward(x)
        return F.interpolate(x, size=size, mode="bilinear", align_corners=False)


class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates, out_channels):
        super(ASPP, self).__init__()
        modules = []
        # 1x1 convolution branch
        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.GroupNorm(32, out_channels),
                nn.ReLU(),
            )
        )

        # Three atrous convolution branches and a global pooling branch
        rate1, rate2, rate3 = tuple(atrous_rates)
        modules.append(ASPPConv(in_channels, out_channels, rate1))
        modules.append(ASPPConv(in_channels, out_channels, rate2))
        modules.append(ASPPConv(in_channels, out_channels, rate3))
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        # Project the concatenation of all five branches back to out_channels
        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.ReLU(),
        )

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)
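

# The blocks below implement a non-local (self-attention) block in the embedded
# Gaussian form of "Non-local Neural Networks" (https://arxiv.org/abs/1711.07971),
# with GroupNorm used in place of BatchNorm in the output projection W.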
class _NonLocalBlockND(nn.Module):
    def __init__(
        self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True
    ):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.GroupNorm
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.GroupNorm
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=2)
            bn = nn.GroupNorm

        self.g = conv_nd(
            in_channels=self.in_channels,
            out_channels=self.inter_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(
                    in_channels=self.inter_channels,
                    out_channels=self.in_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                ),
                bn(32, self.in_channels),
            )
            # Zero-init the output projection so the block starts as an identity mapping
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(
                in_channels=self.inter_channels,
                out_channels=self.in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(
            in_channels=self.in_channels,
            out_channels=self.inter_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.phi = conv_nd(
            in_channels=self.in_channels,
            out_channels=self.inter_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        if sub_sample:
            # Optionally subsample the key/value paths to reduce the attention cost
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        """
        :param x: (b, c, h, w) in the 2D case (more generally (b, c, t, h, w))
        :return: feature map of the same shape as x
        """

        batch_size = x.size(0)

        # Value path: (b, HW', inter_channels), possibly spatially subsampled
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # Query/key paths for the embedded Gaussian pairwise function
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        # Aggregate values with the attention weights and restore the spatial layout
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        # Residual connection
        z = W_y + x

        return z


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=2,
            sub_sample=sub_sample,
            bn_layer=bn_layer,
        )
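

# Minimal usage sketch (illustrative, not part of the original module). It builds the
# head from a hand-constructed CfgNode and runs a forward pass; the config values are
# assumptions chosen for illustration (and so that GroupNorm(32, C) divides evenly).
# In practice these keys come from the DensePose configs, and because of the relative
# import above this file is meant to be run as a module inside its package
# (e.g. `python -m ...`), not as a standalone script.
if __name__ == "__main__":
    cfg = CfgNode()
    cfg.MODEL = CfgNode()
    cfg.MODEL.ROI_DENSEPOSE_HEAD = CfgNode()
    cfg.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM = 512
    cfg.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL = 3
    cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS = 8
    cfg.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB = CfgNode()
    cfg.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NORM = "GN"
    cfg.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NONLOCAL_ON = True

    head = DensePoseDeepLabHead(cfg, input_channels=256)
    features = torch.rand(2, 256, 32, 32)
    out = head(features)
    print(out.shape)  # expected: torch.Size([2, 512, 32, 32])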