Spaces:
Paused
Paused
# Copyright (c) Facebook, Inc. and its affiliates. | |
import fvcore.nn.weight_init as weight_init | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from detectron2.config import CfgNode | |
from detectron2.layers import Conv2d | |
from .registry import ROI_DENSEPOSE_HEAD_REGISTRY | |
class DensePoseDeepLabHead(nn.Module):
    """
    DensePose head using DeepLabV3 model from
    "Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>.

    The head applies, in order: an ASPP module (atrous rates 6, 12, 56), an
    optional non-local block, and a stack of ``NUM_STACKED_CONVS`` convolutions,
    each followed by ReLU.

    Attributes set by the constructor:
        n_stacked_convs: number of stacked convolutions after ASPP / non-local.
        use_nonlocal: whether the non-local block is inserted after ASPP.
        n_out_channels: number of channels of the tensor returned by ``forward``.
    """

    def __init__(self, cfg: CfgNode, input_channels: int):
        """
        Args:
            cfg: config node; the values read live under
                ``MODEL.ROI_DENSEPOSE_HEAD`` (``CONV_HEAD_DIM``,
                ``CONV_HEAD_KERNEL``, ``NUM_STACKED_CONVS``, ``DEEPLAB.NORM``,
                ``DEEPLAB.NONLOCAL_ON``).
            input_channels: number of channels of the input feature map.
        """
        super().__init__()
        # fmt: off
        hidden_dim           = cfg.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM
        kernel_size          = cfg.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL
        norm                 = cfg.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NORM
        self.n_stacked_convs = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS
        self.use_nonlocal    = cfg.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NONLOCAL_ON
        # fmt: on
        pad_size = kernel_size // 2
        n_channels = input_channels

        # Attribute assignment already registers the submodule under that name,
        # so no separate add_module() call is needed. ASPP keeps the channel count.
        self.ASPP = ASPP(input_channels, [6, 12, 56], n_channels)  # 6, 12, 56

        if self.use_nonlocal:
            self.NLBlock = NONLocalBlock2D(input_channels, bn_layer=True)
        # weight_init.c2_msra_fill(self.ASPP)

        for i in range(self.n_stacked_convs):
            # Only "GN" yields a normalization layer; any other non-empty value
            # still disables the conv bias through `bias=not norm` below.
            norm_module = nn.GroupNorm(32, hidden_dim) if norm == "GN" else None
            layer = Conv2d(
                n_channels,
                hidden_dim,
                kernel_size,
                stride=1,
                padding=pad_size,
                bias=not norm,
                norm=norm_module,
            )
            weight_init.c2_msra_fill(layer)
            n_channels = hidden_dim
            self.add_module(self._get_layer_name(i), layer)

        # Report the real output width: with zero stacked convs the output still
        # has `input_channels` channels rather than `hidden_dim`.
        self.n_out_channels = n_channels
        # initialize_module_params(self)

    def forward(self, features):
        """
        Run ASPP, the optional non-local block and the stacked conv + ReLU layers.

        Args:
            features: input feature map; presumably (N, C, H, W) with
                C == input_channels -- TODO confirm against the caller.

        Returns:
            The output of the last applied stage.
        """
        x = self.ASPP(features)
        if self.use_nonlocal:
            x = self.NLBlock(x)
        for i in range(self.n_stacked_convs):
            conv = getattr(self, self._get_layer_name(i))
            x = F.relu(conv(x))
        return x

    def _get_layer_name(self, i: int) -> str:
        """Return the attribute name of the i-th (0-based) stacked conv layer."""
        return "body_conv_fcn{}".format(i + 1)
# Copied from
# https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/deeplabv3.py
# See https://arxiv.org/pdf/1706.05587.pdf for details
class ASPPConv(nn.Sequential):
    """
    One atrous branch: 3x3 dilated convolution -> GroupNorm(32 groups) -> ReLU.

    Padding equals the dilation, so the 3x3 convolution keeps the spatial size.
    """

    def __init__(self, in_channels, out_channels, dilation):
        """
        Args:
            in_channels: number of input channels.
            out_channels: number of output channels (GroupNorm uses 32 groups,
                so this must be divisible by 32).
            dilation: dilation rate of the 3x3 convolution, also used as padding.
        """
        atrous_conv = nn.Conv2d(
            in_channels,
            out_channels,
            3,
            padding=dilation,
            dilation=dilation,
            bias=False,
        )
        group_norm = nn.GroupNorm(32, out_channels)
        activation = nn.ReLU()
        super().__init__(atrous_conv, group_norm, activation)
class ASPPPooling(nn.Sequential):
    """
    Image-level pooling branch: global average pool -> 1x1 conv -> GroupNorm -> ReLU,
    with the result resized back to the input's spatial size.
    """

    def __init__(self, in_channels, out_channels):
        """
        Args:
            in_channels: number of input channels.
            out_channels: number of output channels (GroupNorm uses 32 groups,
                so this must be divisible by 32).
        """
        global_pool = nn.AdaptiveAvgPool2d(1)
        pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        group_norm = nn.GroupNorm(32, out_channels)
        activation = nn.ReLU()
        super().__init__(global_pool, pointwise, group_norm, activation)

    def forward(self, x):
        """Apply the pooled branch and bilinearly upsample to x's last two dims."""
        # Remember the incoming spatial size before it is pooled down to 1x1.
        spatial_size = x.shape[-2:]
        pooled = super().forward(x)
        return F.interpolate(
            pooled, size=spatial_size, mode="bilinear", align_corners=False
        )
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling.

    Runs the input through parallel branches -- a 1x1 conv, one ``ASPPConv`` per
    atrous rate, and an ``ASPPPooling`` branch -- concatenates the results along
    the channel dimension and projects them back to ``out_channels`` with a
    1x1 conv followed by ReLU.
    """

    def __init__(self, in_channels, atrous_rates, out_channels):
        """
        Args:
            in_channels: number of input channels.
            atrous_rates: iterable of dilation rates, one ``ASPPConv`` branch each.
                Any number of rates is accepted; the projection layer is sized
                from the resulting branch count.
            out_channels: number of channels of every branch and of the output
                (GroupNorm uses 32 groups, so this must be divisible by 32).
        """
        super().__init__()
        branches = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.GroupNorm(32, out_channels),
                nn.ReLU(),
            )
        ]
        branches.extend(
            ASPPConv(in_channels, out_channels, rate) for rate in atrous_rates
        )
        branches.append(ASPPPooling(in_channels, out_channels))
        self.convs = nn.ModuleList(branches)

        # The concatenated input has one `out_channels` slice per branch
        # (5 * out_channels for the usual three atrous rates).
        # Normalization and dropout after the projection are intentionally
        # disabled here (they were left commented out in the original code).
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply every branch to ``x``, concatenate on dim 1 and project."""
        branch_outputs = [branch(x) for branch in self.convs]
        return self.project(torch.cat(branch_outputs, dim=1))
# Copied from
# https://github.com/AlexHex7/Non-local_pytorch/blob/master/lib/non_local_embedded_gaussian.py
# See https://arxiv.org/abs/1711.07971 for details
class _NonLocalBlockND(nn.Module):
    """
    Non-local block (softmax-normalized pairwise attention) for 1D, 2D or 3D inputs.

    The block computes ``z = W(y) + x``. The last layer of ``W`` is initialized
    to zero, so a freshly constructed block returns its input unchanged.
    """

    def __init__(
        self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True
    ):
        """
        Args:
            in_channels: number of input (and output) channels.
            inter_channels: width of the embeddings; when None it defaults to
                ``in_channels // 2``, with a minimum of 1.
            dimension: 1, 2 or 3 -- selects Conv1d/2d/3d and the matching max pool.
            sub_sample: if True, max-pool the outputs of ``g`` and ``phi``.
            bn_layer: if True, ``W`` is conv + GroupNorm(32 groups), which requires
                ``in_channels`` divisible by 32; otherwise ``W`` is a plain conv.
        """
        super().__init__()
        assert dimension in (1, 2, 3)

        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            # Halve the channels, but never go down to zero.
            self.inter_channels = in_channels // 2 or 1

        # Per-dimension building blocks: conv class, pool class, pool kernel.
        conv_nd, pool_cls, pool_kernel = {
            1: (nn.Conv1d, nn.MaxPool1d, 2),
            2: (nn.Conv2d, nn.MaxPool2d, (2, 2)),
            3: (nn.Conv3d, nn.MaxPool3d, (1, 2, 2)),
        }[dimension]
        max_pool_layer = pool_cls(kernel_size=pool_kernel)

        def pointwise(c_in, c_out):
            # 1x1 convolution used for all embeddings and the output projection.
            return conv_nd(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=1,
                stride=1,
                padding=0,
            )

        # Submodules are created and registered in the order g, W, theta, phi so
        # that parameter initialization order and state-dict layout stay stable.
        g_conv = pointwise(self.in_channels, self.inter_channels)
        self.g = nn.Sequential(g_conv, max_pool_layer) if sub_sample else g_conv

        out_conv = pointwise(self.inter_channels, self.in_channels)
        if bn_layer:
            out_norm = nn.GroupNorm(32, self.in_channels)
            self.W = nn.Sequential(out_conv, out_norm)
            zero_init_layer = out_norm
        else:
            self.W = out_conv
            zero_init_layer = out_conv
        # Zero the final layer so the residual branch starts as a no-op.
        nn.init.constant_(zero_init_layer.weight, 0)
        nn.init.constant_(zero_init_layer.bias, 0)

        self.theta = pointwise(self.in_channels, self.inter_channels)

        phi_conv = pointwise(self.in_channels, self.inter_channels)
        # The (parameter-free) pooling layer instance is shared by g and phi.
        self.phi = nn.Sequential(phi_conv, max_pool_layer) if sub_sample else phi_conv

    def forward(self, x):
        """
        :param x: (b, c, *spatial), e.g. (b, c, t, h, w) for dimension=3
        :return: tensor of the same shape as x
        """
        b = x.size(0)
        c_inter = self.inter_channels

        # Values: (b, M, c_inter), M = number of (possibly sub-sampled) positions.
        values = self.g(x).view(b, c_inter, -1).transpose(1, 2)
        # Queries: (b, N, c_inter), N = number of input positions.
        queries = self.theta(x).view(b, c_inter, -1).transpose(1, 2)
        # Keys: (b, c_inter, M).
        keys = self.phi(x).view(b, c_inter, -1)

        # Pairwise similarities normalized over the key positions: (b, N, M).
        attention = F.softmax(torch.matmul(queries, keys), dim=-1)

        # Aggregate the values and restore the (b, c_inter, *spatial) layout.
        aggregated = torch.matmul(attention, values).transpose(1, 2).contiguous()
        aggregated = aggregated.view(b, c_inter, *x.shape[2:])

        # Residual connection around the projected attention output.
        return self.W(aggregated) + x
class NONLocalBlock2D(_NonLocalBlockND):
    """Non-local block fixed to 2D inputs (``dimension=2``)."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        """
        Args:
            in_channels: number of input (and output) channels.
            inter_channels: embedding width, forwarded unchanged to the base class.
            sub_sample: forwarded unchanged to the base class.
            bn_layer: forwarded unchanged to the base class.
        """
        base_options = dict(
            inter_channels=inter_channels,
            dimension=2,
            sub_sample=sub_sample,
            bn_layer=bn_layer,
        )
        super().__init__(in_channels, **base_options)