|
|
|
|
|
|
|
"""Video models.""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import copy |
|
|
|
import slowfast.utils.weight_init_helper as init_helper |
|
from slowfast.models.batchnorm_helper import get_norm |
|
|
|
from . import head_helper, resnet_helper, stem_helper |
|
from .build import MODEL_REGISTRY |
|
|
|
|
|
# Number of residual blocks in each of the four residual stages
# (s2, s3, s4, s5), keyed by cfg.RESNET.DEPTH.
_MODEL_STAGE_DEPTH = {
    18: (2, 2, 2, 2),
    50: (3, 4, 6, 3),
    101: (3, 4, 23, 3),
}

# Temporal kernel sizes, keyed by cfg.MODEL.ARCH.
#   * Outer list: five stage entries. Index 0 is the stem (s1) and indices
#     1..4 are the residual stages s2..s5.
#   * Stage entry: one list per pathway (two pathways only for "slowfast").
#     For s2..s5 the whole stage entry is passed to ResStage as
#     ``temp_kernel_sizes``.
#   * For the stem, each pathway's list is extended with [7, 7] to form the
#     3D conv kernel, and its first element sets the temporal padding.
_TEMPORAL_KERNEL_BASIS = {
    "c2d": [[[1]], [[1]], [[1]], [[1]], [[1]]],
    "c2d_nopool": [[[1]], [[1]], [[1]], [[1]], [[1]]],
    "i3d": [[[5]], [[3]], [[3, 1]], [[3, 1]], [[1, 3]]],
    "r3d_18": [[[3]], [[3]], [[3, 1]], [[3, 1]], [[1, 3]]],
    "i3d_nopool": [[[5]], [[3]], [[3, 1]], [[3, 1]], [[1, 3]]],
    "slow": [[[1]], [[1]], [[1]], [[3]], [[3]]],
    "slowfast": [
        [[1], [5]],
        [[1], [3]],
        [[1], [3]],
        [[3], [3]],
        [[3], [3]],
    ],
}

# Kernel and stride of the max-pool applied to each pathway after stage s2,
# keyed by cfg.MODEL.ARCH (one [T, H, W] triple per pathway). A [2, 1, 1]
# pool halves the temporal length. The "*_nopool", "slow" and "slowfast"
# variants use an identity [1, 1, 1] pool.
_POOL1 = {
    "c2d": [[2, 1, 1]],
    "c2d_nopool": [[1, 1, 1]],
    "i3d": [[2, 1, 1]],
    "r3d_18": [[2, 1, 1]],
    "i3d_nopool": [[1, 1, 1]],
    "slow": [[1, 1, 1]],
    "slowfast": [[1, 1, 1], [1, 1, 1]],
}
|
|
|
|
|
|
|
|
|
class FuseFastToSlow(nn.Module):
    """
    Lateral connection that injects Fast-pathway information into the Slow
    pathway.

    The Fast features go through a temporally strided Conv3d, then the norm
    layer, then a ReLU. The result is concatenated onto the Slow features
    along dim 1 (channels). The Fast features are returned unchanged, so the
    output keeps the ``[slow, fast]`` ordering.
    """

    def __init__(
        self,
        dim_in,
        fusion_conv_channel_ratio,
        fusion_kernel,
        alpha,
        eps=1e-5,
        bn_mmt=0.1,
        inplace_relu=True,
        norm_module=nn.BatchNorm3d,
    ):
        """
        Args:
            dim_in (int): channel count of the Fast-pathway input.
            fusion_conv_channel_ratio (int): the fused tensor has
                ``dim_in * fusion_conv_channel_ratio`` channels.
            fusion_kernel (int): temporal kernel size of the fusion conv.
                The spatial kernel is 1x1 and the temporal padding is
                ``fusion_kernel // 2``.
            alpha (int): temporal stride of the fusion conv, i.e. the frame
                rate ratio between the Fast and Slow pathways.
            eps (float): epsilon for the normalization layer.
            bn_mmt (float): momentum for the normalization layer. BN momentum
                in PyTorch = 1 - BN momentum in Caffe2.
            inplace_relu (bool): if True, the ReLU overwrites its input
                instead of allocating a new tensor.
            norm_module (nn.Module): normalization layer class, constructed
                with ``num_features``, ``eps`` and ``momentum``. Defaults to
                nn.BatchNorm3d.
        """
        super().__init__()
        fused_channels = dim_in * fusion_conv_channel_ratio
        # Attribute names and order are kept stable for checkpoint loading.
        self.conv_f2s = nn.Conv3d(
            dim_in,
            fused_channels,
            kernel_size=[fusion_kernel, 1, 1],
            stride=[alpha, 1, 1],
            padding=[fusion_kernel // 2, 0, 0],
            bias=False,
        )
        self.bn = norm_module(
            num_features=fused_channels, eps=eps, momentum=bn_mmt
        )
        self.relu = nn.ReLU(inplace=inplace_relu)

    def forward(self, x):
        """
        Args:
            x (list[Tensor]): ``[slow, fast]`` feature tensors.

        Returns:
            list[Tensor]: ``[slow concatenated with fused fast, fast]``.
        """
        slow, fast = x[0], x[1]
        lateral = self.relu(self.bn(self.conv_f2s(fast)))
        return [torch.cat([slow, lateral], 1), fast]
|
|
|
|
|
@MODEL_REGISTRY.register()
class SlowFast(nn.Module):
    """
    Two-pathway SlowFast video network.

    Pathway 0 is the Slow pathway and pathway 1 is the Fast pathway. The Fast
    pathway has ``1 / cfg.SLOWFAST.BETA_INV`` of the Slow channel widths.
    After the stem and after each of the first three residual stages, a
    ``FuseFastToSlow`` lateral connection concatenates temporally strided
    (by ``cfg.SLOWFAST.ALPHA``) Fast features onto the Slow features.

    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf
    """

    def __init__(self, cfg):
        """
        Subclasses overriding ``__init__`` should accept the same argument.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        super(SlowFast, self).__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        self.num_pathways = 2
        self._construct_network(cfg)
        init_helper.init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _construct_network(self, cfg):
        """
        Build the stem, residual stages s2..s5, lateral fusions, post-s2
        pools and head.

        Submodules are registered in this order, and the names are relied on
        by checkpoints: s1, s1_fuse, s2, s2_fuse, pathway0_pool,
        pathway1_pool, s3, s3_fuse, s4, s4_fuse, s5, head.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        arch = cfg.MODEL.ARCH
        assert arch in _POOL1.keys()
        pool_size = _POOL1[arch]
        assert len({len(pool_size), self.num_pathways}) == 1
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()
        blocks_per_stage = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]

        beta_inv = cfg.SLOWFAST.BETA_INV
        groups = cfg.RESNET.NUM_GROUPS
        width = cfg.RESNET.WIDTH_PER_GROUP
        bottleneck = groups * width
        # Each fusion adds (fast channels * FUSION_CONV_CHANNEL_RATIO) =
        # slow_channels // out_dim_ratio channels to the Slow pathway.
        out_dim_ratio = beta_inv // cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO
        temp_kernel = _TEMPORAL_KERNEL_BASIS[arch]

        def make_fuse(fast_channels):
            # Lateral Fast->Slow connection sized for `fast_channels` inputs.
            return FuseFastToSlow(
                fast_channels,
                cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO,
                cfg.SLOWFAST.FUSION_KERNEL_SZ,
                cfg.SLOWFAST.ALPHA,
                norm_module=self.norm_module,
            )

        stem_kernels = temp_kernel[0]
        self.s1 = stem_helper.VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[width, width // beta_inv],
            kernel=[stem_kernels[p] + [7, 7] for p in range(2)],
            stride=[[1, 2, 2]] * 2,
            padding=[[stem_kernels[p][0] // 2, 3, 3] for p in range(2)],
            norm_module=self.norm_module,
        )
        self.s1_fuse = make_fuse(width // beta_inv)

        # Stage k (k = 0..3 for s2..s5) doubles the widths of the previous
        # stage. The Slow input additionally carries the fused Fast channels.
        for idx, name in enumerate(("s2", "s3", "s4", "s5")):
            scale = 2 ** idx
            slow_in = width if idx == 0 else width * 2 * scale
            slow_out = width * 4 * scale
            inner = bottleneck * scale
            stage = resnet_helper.ResStage(
                dim_in=[slow_in + slow_in // out_dim_ratio, slow_in // beta_inv],
                dim_out=[slow_out, slow_out // beta_inv],
                dim_inner=[inner, inner // beta_inv],
                temp_kernel_sizes=temp_kernel[idx + 1],
                stride=cfg.RESNET.SPATIAL_STRIDES[idx],
                num_blocks=[blocks_per_stage[idx]] * 2,
                num_groups=[groups] * 2,
                num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[idx],
                nonlocal_inds=cfg.NONLOCAL.LOCATION[idx],
                nonlocal_group=cfg.NONLOCAL.GROUP[idx],
                nonlocal_pool=cfg.NONLOCAL.POOL[idx],
                instantiation=cfg.NONLOCAL.INSTANTIATION,
                trans_func_name=cfg.RESNET.TRANS_FUNC,
                dilation=cfg.RESNET.SPATIAL_DILATIONS[idx],
                norm_module=self.norm_module,
            )
            self.add_module(name, stage)
            if name != "s5":
                self.add_module(name + "_fuse", make_fuse(slow_out // beta_inv))
            if name == "s2":
                # Per-pathway max-pool applied between s2 and s3.
                for pathway in range(self.num_pathways):
                    self.add_module(
                        "pathway{}_pool".format(pathway),
                        nn.MaxPool3d(
                            kernel_size=pool_size[pathway],
                            stride=pool_size[pathway],
                            padding=[0, 0, 0],
                        ),
                    )

        if cfg.DETECTION.ENABLE:
            raise NotImplementedError

        if cfg.MULTIGRID.SHORT_CYCLE:
            head_pool = [None, None]
        else:
            # Assumes a total spatial downsampling of 32 through the network.
            spatial = cfg.DATA.CROP_SIZE // 32
            head_pool = [
                [
                    cfg.DATA.NUM_FRAMES
                    // cfg.SLOWFAST.ALPHA
                    // pool_size[0][0],
                    spatial // pool_size[0][1],
                    spatial // pool_size[0][2],
                ],
                [
                    cfg.DATA.NUM_FRAMES // pool_size[1][0],
                    spatial // pool_size[1][1],
                    spatial // pool_size[1][2],
                ],
            ]
        self.head = head_helper.ResNetBasicHead(
            dim_in=[width * 32, width * 32 // beta_inv],
            num_classes=cfg.MODEL.NUM_CLASSES,
            pool_size=head_pool,
            dropout_rate=cfg.MODEL.DROPOUT_RATE,
            act_func=cfg.MODEL.HEAD_ACT,
        )

    def forward(self, x, bboxes=None):
        """
        Args:
            x (list[Tensor]): ``[slow_input, fast_input]``. Presumably 5D
                (N, C, T, H, W) clips; confirm against the data loader.
            bboxes: only forwarded to the head when detection is enabled.
                Detection currently raises NotImplementedError at
                construction time.

        Returns:
            The output of ``self.head``.
        """
        x = self.s1_fuse(self.s1(x))
        x = self.s2_fuse(self.s2(x))
        for pathway in range(self.num_pathways):
            pooler = getattr(self, "pathway{}_pool".format(pathway))
            x[pathway] = pooler(x[pathway])
        for name in ("s3", "s4"):
            x = getattr(self, name + "_fuse")(getattr(self, name)(x))
        x = self.s5(x)
        if self.enable_detection:
            return self.head(x, bboxes)
        return self.head(x)
|
|
|
|
|
@MODEL_REGISTRY.register()
class ResNet(nn.Module):
    """
    ResNet model builder. It builds a ResNet like network backbone without
    lateral connection (C2D, I3D, Slow).

    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf

    Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
    "Non-local neural networks."
    https://arxiv.org/pdf/1711.07971.pdf
    """

    def __init__(self, cfg):
        """
        The `__init__` method of any subclass should also contain these
        arguments.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        super(ResNet, self).__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        self.num_pathways = 1
        self._construct_network(cfg)
        init_helper.init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _construct_network(self, cfg):
        """
        Builds a single pathway ResNet model: stem s1, residual stages s2..s5,
        a max-pool after s2, and a classification head.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        assert cfg.MODEL.ARCH in _POOL1.keys()
        pool_size = _POOL1[cfg.MODEL.ARCH]
        assert len({len(pool_size), self.num_pathways}) == 1
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()

        (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]

        num_groups = cfg.RESNET.NUM_GROUPS
        width_per_group = cfg.RESNET.WIDTH_PER_GROUP
        dim_inner = num_groups * width_per_group
        # (A leftover debug `print(dim_inner)` was removed here.)

        temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]

        self.s1 = stem_helper.VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[width_per_group],
            kernel=[temp_kernel[0][0] + [7, 7]],
            stride=[[1, 2, 2]],
            padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
            norm_module=self.norm_module,
        )

        self.s2 = resnet_helper.ResStage(
            dim_in=[width_per_group],
            dim_out=[width_per_group * 4],
            dim_inner=[dim_inner],
            temp_kernel_sizes=temp_kernel[1],
            stride=cfg.RESNET.SPATIAL_STRIDES[0],
            num_blocks=[d2],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
            nonlocal_group=cfg.NONLOCAL.GROUP[0],
            nonlocal_pool=cfg.NONLOCAL.POOL[0],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
            norm_module=self.norm_module,
        )

        # Per-pathway max-pool applied between s2 and s3 (see _POOL1).
        for pathway in range(self.num_pathways):
            pool = nn.MaxPool3d(
                kernel_size=pool_size[pathway],
                stride=pool_size[pathway],
                padding=[0, 0, 0],
            )
            self.add_module("pathway{}_pool".format(pathway), pool)

        self.s3 = resnet_helper.ResStage(
            dim_in=[width_per_group * 4],
            dim_out=[width_per_group * 8],
            dim_inner=[dim_inner * 2],
            temp_kernel_sizes=temp_kernel[2],
            stride=cfg.RESNET.SPATIAL_STRIDES[1],
            num_blocks=[d3],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
            nonlocal_group=cfg.NONLOCAL.GROUP[1],
            nonlocal_pool=cfg.NONLOCAL.POOL[1],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
            norm_module=self.norm_module,
        )

        self.s4 = resnet_helper.ResStage(
            dim_in=[width_per_group * 8],
            dim_out=[width_per_group * 16],
            dim_inner=[dim_inner * 4],
            temp_kernel_sizes=temp_kernel[3],
            stride=cfg.RESNET.SPATIAL_STRIDES[2],
            num_blocks=[d4],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
            nonlocal_group=cfg.NONLOCAL.GROUP[2],
            nonlocal_pool=cfg.NONLOCAL.POOL[2],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
            norm_module=self.norm_module,
        )

        self.s5 = resnet_helper.ResStage(
            dim_in=[width_per_group * 16],
            dim_out=[width_per_group * 32],
            dim_inner=[dim_inner * 8],
            temp_kernel_sizes=temp_kernel[4],
            stride=cfg.RESNET.SPATIAL_STRIDES[3],
            num_blocks=[d5],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
            nonlocal_group=cfg.NONLOCAL.GROUP[3],
            nonlocal_pool=cfg.NONLOCAL.POOL[3],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
            norm_module=self.norm_module,
        )

        if self.enable_detection:
            raise NotImplementedError
        else:
            if cfg.MULTIGRID.SHORT_CYCLE:
                # pool_size carries one entry per pathway. This model has a
                # single pathway, so it must be length 1 to match dim_in
                # (previously [None, None], which described two pathways).
                head_pool_size = [None]
            else:
                # Assumes a total spatial downsampling of 32 through the net.
                head_pool_size = [
                    [
                        cfg.DATA.NUM_FRAMES // pool_size[0][0],
                        cfg.DATA.CROP_SIZE // 32 // pool_size[0][1],
                        cfg.DATA.CROP_SIZE // 32 // pool_size[0][2],
                    ]
                ]
            self.head = head_helper.ResNetBasicHead(
                dim_in=[width_per_group * 32],
                num_classes=cfg.MODEL.NUM_CLASSES,
                pool_size=head_pool_size,
                dropout_rate=cfg.MODEL.DROPOUT_RATE,
                act_func=cfg.MODEL.HEAD_ACT,
            )

    def forward(self, x, return_feat=False, bboxes=None):
        """
        Args:
            x (list[Tensor]): one input per pathway (a single element here).
                Presumably 5D (N, C, T, H, W) clips; confirm against the data
                loader.
            return_feat (bool): if True, return the s5 output (a list,
                one entry per pathway) instead of running the head.
            bboxes: only forwarded to the head when detection is enabled.
                Detection currently raises NotImplementedError at
                construction time.

        Returns:
            The s5 features if ``return_feat`` is True, otherwise the output
            of ``self.head``.
        """
        x = self.s1(x)
        x = self.s2(x)
        for pathway in range(self.num_pathways):
            pool = getattr(self, "pathway{}_pool".format(pathway))
            x[pathway] = pool(x[pathway])
        x = self.s3(x)
        x = self.s4(x)
        feat = self.s5(x)
        if return_feat:
            return feat
        if self.enable_detection:
            x = self.head(feat, bboxes)
        else:
            x = self.head(feat)
        return x
|
|
|
@MODEL_REGISTRY.register()
class ResNetVar(nn.Module):
    """
    Single-pathway ResNet video backbone (C2D, I3D, Slow variants selected by
    ``cfg.MODEL.ARCH``). The head is always built with ``pool_size=[None]``
    rather than a size derived from ``cfg.DATA.NUM_FRAMES`` and
    ``cfg.DATA.CROP_SIZE``. That presumably selects adaptive global pooling
    in ResNetBasicHead; confirm in head_helper.

    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf

    Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
    "Non-local neural networks."
    https://arxiv.org/pdf/1711.07971.pdf
    """

    def __init__(self, cfg):
        """
        Subclasses overriding ``__init__`` should accept the same argument.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        super(ResNetVar, self).__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        self.num_pathways = 1
        self._construct_network(cfg)
        init_helper.init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _construct_network(self, cfg):
        """
        Build the stem, residual stages s2..s5, the post-s2 pool and the head.

        Submodules are registered in this order, and the names are relied on
        by checkpoints: s1, s2, pathway0_pool, s3, s4, s5, head.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        arch = cfg.MODEL.ARCH
        assert arch in _POOL1.keys()
        pool_size = _POOL1[arch]
        assert len({len(pool_size), self.num_pathways}) == 1
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()
        blocks_per_stage = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]

        groups = cfg.RESNET.NUM_GROUPS
        width = cfg.RESNET.WIDTH_PER_GROUP
        bottleneck = groups * width
        temp_kernel = _TEMPORAL_KERNEL_BASIS[arch]

        stem_t = temp_kernel[0][0]
        self.s1 = stem_helper.VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[width],
            kernel=[stem_t + [7, 7]],
            stride=[[1, 2, 2]],
            padding=[[stem_t[0] // 2, 3, 3]],
            norm_module=self.norm_module,
        )

        # Stage k (k = 0..3 for s2..s5): in = w, 4w, 8w, 16w;
        # out = 4w, 8w, 16w, 32w; inner = bottleneck * 2**k.
        for idx, name in enumerate(("s2", "s3", "s4", "s5")):
            scale = 2 ** idx
            stage = resnet_helper.ResStage(
                dim_in=[width if idx == 0 else width * 2 * scale],
                dim_out=[width * 4 * scale],
                dim_inner=[bottleneck * scale],
                temp_kernel_sizes=temp_kernel[idx + 1],
                stride=cfg.RESNET.SPATIAL_STRIDES[idx],
                num_blocks=[blocks_per_stage[idx]],
                num_groups=[groups],
                num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[idx],
                nonlocal_inds=cfg.NONLOCAL.LOCATION[idx],
                nonlocal_group=cfg.NONLOCAL.GROUP[idx],
                nonlocal_pool=cfg.NONLOCAL.POOL[idx],
                instantiation=cfg.NONLOCAL.INSTANTIATION,
                trans_func_name=cfg.RESNET.TRANS_FUNC,
                stride_1x1=cfg.RESNET.STRIDE_1X1,
                inplace_relu=cfg.RESNET.INPLACE_RELU,
                dilation=cfg.RESNET.SPATIAL_DILATIONS[idx],
                norm_module=self.norm_module,
            )
            self.add_module(name, stage)
            if idx == 0:
                # Per-pathway max-pool applied between s2 and s3.
                for pathway in range(self.num_pathways):
                    self.add_module(
                        "pathway{}_pool".format(pathway),
                        nn.MaxPool3d(
                            kernel_size=pool_size[pathway],
                            stride=pool_size[pathway],
                            padding=[0, 0, 0],
                        ),
                    )

        if self.enable_detection:
            raise NotImplementedError
        self.head = head_helper.ResNetBasicHead(
            dim_in=[width * 32],
            num_classes=cfg.MODEL.NUM_CLASSES,
            pool_size=[None],
            dropout_rate=cfg.MODEL.DROPOUT_RATE,
            act_func=cfg.MODEL.HEAD_ACT,
        )

    def forward(self, x, bboxes=None):
        """
        Args:
            x (list[Tensor]): one input per pathway (a single element here).
                Presumably 5D (N, C, T, H, W) clips; confirm against the data
                loader.
            bboxes: only forwarded to the head when detection is enabled.
                Detection currently raises NotImplementedError at
                construction time.

        Returns:
            The output of ``self.head``.
        """
        x = self.s2(self.s1(x))
        for pathway in range(self.num_pathways):
            pooler = getattr(self, "pathway{}_pool".format(pathway))
            x[pathway] = pooler(x[pathway])
        x = self.s5(self.s4(self.s3(x)))
        if self.enable_detection:
            return self.head(x, bboxes)
        return self.head(x)
|
|
|
@MODEL_REGISTRY.register()
class ResNetBase(nn.Module):
    """
    ResNet model builder. It builds a ResNet like network backbone without
    lateral connection (C2D, I3D, Slow). The head is always built with
    ``pool_size=[None]`` (one entry for the single pathway).

    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf

    Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
    "Non-local neural networks."
    https://arxiv.org/pdf/1711.07971.pdf
    """

    def __init__(self, cfg):
        """
        The `__init__` method of any subclass should also contain these
        arguments.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        super(ResNetBase, self).__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        self.num_pathways = 1
        self._construct_network(cfg)
        init_helper.init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _construct_network(self, cfg):
        """
        Builds a single pathway ResNet model: stem s1, residual stages s2..s5,
        a max-pool after s2, and a classification head.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        assert cfg.MODEL.ARCH in _POOL1.keys()
        pool_size = _POOL1[cfg.MODEL.ARCH]
        assert len({len(pool_size), self.num_pathways}) == 1
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()

        (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]

        num_groups = cfg.RESNET.NUM_GROUPS
        width_per_group = cfg.RESNET.WIDTH_PER_GROUP
        dim_inner = num_groups * width_per_group

        temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]

        self.s1 = stem_helper.VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[width_per_group],
            kernel=[temp_kernel[0][0] + [7, 7]],
            stride=[[1, 2, 2]],
            padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
            norm_module=self.norm_module,
        )

        self.s2 = resnet_helper.ResStage(
            dim_in=[width_per_group],
            dim_out=[width_per_group * 4],
            dim_inner=[dim_inner],
            temp_kernel_sizes=temp_kernel[1],
            stride=cfg.RESNET.SPATIAL_STRIDES[0],
            num_blocks=[d2],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
            nonlocal_group=cfg.NONLOCAL.GROUP[0],
            nonlocal_pool=cfg.NONLOCAL.POOL[0],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
            norm_module=self.norm_module,
        )

        # Per-pathway max-pool applied between s2 and s3 (see _POOL1).
        for pathway in range(self.num_pathways):
            pool = nn.MaxPool3d(
                kernel_size=pool_size[pathway],
                stride=pool_size[pathway],
                padding=[0, 0, 0],
            )
            self.add_module("pathway{}_pool".format(pathway), pool)

        self.s3 = resnet_helper.ResStage(
            dim_in=[width_per_group * 4],
            dim_out=[width_per_group * 8],
            dim_inner=[dim_inner * 2],
            temp_kernel_sizes=temp_kernel[2],
            stride=cfg.RESNET.SPATIAL_STRIDES[1],
            num_blocks=[d3],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
            nonlocal_group=cfg.NONLOCAL.GROUP[1],
            nonlocal_pool=cfg.NONLOCAL.POOL[1],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
            norm_module=self.norm_module,
        )

        self.s4 = resnet_helper.ResStage(
            dim_in=[width_per_group * 8],
            dim_out=[width_per_group * 16],
            dim_inner=[dim_inner * 4],
            temp_kernel_sizes=temp_kernel[3],
            stride=cfg.RESNET.SPATIAL_STRIDES[2],
            num_blocks=[d4],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
            nonlocal_group=cfg.NONLOCAL.GROUP[2],
            nonlocal_pool=cfg.NONLOCAL.POOL[2],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
            norm_module=self.norm_module,
        )

        self.s5 = resnet_helper.ResStage(
            dim_in=[width_per_group * 16],
            dim_out=[width_per_group * 32],
            dim_inner=[dim_inner * 8],
            temp_kernel_sizes=temp_kernel[4],
            stride=cfg.RESNET.SPATIAL_STRIDES[3],
            num_blocks=[d5],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
            nonlocal_group=cfg.NONLOCAL.GROUP[3],
            nonlocal_pool=cfg.NONLOCAL.POOL[3],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
            norm_module=self.norm_module,
        )

        if self.enable_detection:
            raise NotImplementedError
        else:
            # pool_size carries one entry per pathway and must match dim_in.
            # It was previously [None, None] under MULTIGRID.SHORT_CYCLE,
            # which described two pathways for this single-pathway model.
            self.head = head_helper.ResNetBasicHead(
                dim_in=[width_per_group * 32],
                num_classes=cfg.MODEL.NUM_CLASSES,
                pool_size=[None],
                dropout_rate=cfg.MODEL.DROPOUT_RATE,
                act_func=cfg.MODEL.HEAD_ACT,
            )

    def forward(self, x, bboxes=None):
        """
        Args:
            x (list[Tensor]): one input per pathway (a single element here).
                Presumably 5D (N, C, T, H, W) clips; confirm against the data
                loader.
            bboxes: only forwarded to the head when detection is enabled.
                Detection currently raises NotImplementedError at
                construction time.

        Returns:
            The output of ``self.head``.
        """
        x = self.s1(x)
        x = self.s2(x)
        for pathway in range(self.num_pathways):
            pool = getattr(self, "pathway{}_pool".format(pathway))
            x[pathway] = pool(x[pathway])
        x = self.s3(x)
        x = self.s4(x)
        x = self.s5(x)
        if self.enable_detection:
            x = self.head(x, bboxes)
        else:
            x = self.head(x)
        return x
|
|
|
|
|
@MODEL_REGISTRY.register() |
|
class ResNetFreeze(nn.Module): |
|
""" |
|
ResNet model builder. It builds a ResNet like network backbone without |
|
lateral connection (C2D, I3D, Slow). |
|
|
|
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. |
|
"SlowFast networks for video recognition." |
|
https://arxiv.org/pdf/1812.03982.pdf |
|
|
|
Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. |
|
"Non-local neural networks." |
|
https://arxiv.org/pdf/1711.07971.pdf |
|
""" |
|
|
|
def __init__(self, cfg): |
|
""" |
|
The `__init__` method of any subclass should also contain these |
|
arguments. |
|
|
|
Args: |
|
cfg (CfgNode): model building configs, details are in the |
|
comments of the config file. |
|
""" |
|
super(ResNetFreeze, self).__init__() |
|
self.norm_module = get_norm(cfg) |
|
self.enable_detection = cfg.DETECTION.ENABLE |
|
self.num_pathways = 1 |
|
self._construct_network(cfg) |
|
init_helper.init_weights( |
|
self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN |
|
) |
|
|
|
def _construct_network(self, cfg): |
|
""" |
|
Builds a single pathway ResNet model. |
|
|
|
Args: |
|
cfg (CfgNode): model building configs, details are in the |
|
comments of the config file. |
|
""" |
|
assert cfg.MODEL.ARCH in _POOL1.keys() |
|
pool_size = _POOL1[cfg.MODEL.ARCH] |
|
assert len({len(pool_size), self.num_pathways}) == 1 |
|
assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys() |
|
|
|
(d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH] |
|
|
|
num_groups = cfg.RESNET.NUM_GROUPS |
|
width_per_group = cfg.RESNET.WIDTH_PER_GROUP |
|
dim_inner = num_groups * width_per_group |
|
|
|
temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH] |
|
|
|
self.s1 = stem_helper.VideoModelStem( |
|
dim_in=cfg.DATA.INPUT_CHANNEL_NUM, |
|
dim_out=[width_per_group], |
|
kernel=[temp_kernel[0][0] + [7, 7]], |
|
stride=[[1, 2, 2]], |
|
padding=[[temp_kernel[0][0][0] // 2, 3, 3]], |
|
norm_module=self.norm_module, |
|
) |
|
|
|
self.s2 = resnet_helper.ResStage( |
|
dim_in=[width_per_group], |
|
dim_out=[width_per_group * 4], |
|
dim_inner=[dim_inner], |
|
temp_kernel_sizes=temp_kernel[1], |
|
stride=cfg.RESNET.SPATIAL_STRIDES[0], |
|
num_blocks=[d2], |
|
num_groups=[num_groups], |
|
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0], |
|
nonlocal_inds=cfg.NONLOCAL.LOCATION[0], |
|
nonlocal_group=cfg.NONLOCAL.GROUP[0], |
|
nonlocal_pool=cfg.NONLOCAL.POOL[0], |
|
instantiation=cfg.NONLOCAL.INSTANTIATION, |
|
trans_func_name=cfg.RESNET.TRANS_FUNC, |
|
stride_1x1=cfg.RESNET.STRIDE_1X1, |
|
inplace_relu=cfg.RESNET.INPLACE_RELU, |
|
dilation=cfg.RESNET.SPATIAL_DILATIONS[0], |
|
norm_module=self.norm_module, |
|
) |
|
|
|
for pathway in range(self.num_pathways): |
|
pool = nn.MaxPool3d( |
|
kernel_size=pool_size[pathway], |
|
stride=pool_size[pathway], |
|
padding=[0, 0, 0], |
|
) |
|
self.add_module("pathway{}_pool".format(pathway), pool) |
|
|
|
self.s3 = resnet_helper.ResStage( |
|
dim_in=[width_per_group * 4], |
|
dim_out=[width_per_group * 8], |
|
dim_inner=[dim_inner * 2], |
|
temp_kernel_sizes=temp_kernel[2], |
|
stride=cfg.RESNET.SPATIAL_STRIDES[1], |
|
num_blocks=[d3], |
|
num_groups=[num_groups], |
|
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1], |
|
nonlocal_inds=cfg.NONLOCAL.LOCATION[1], |
|
nonlocal_group=cfg.NONLOCAL.GROUP[1], |
|
nonlocal_pool=cfg.NONLOCAL.POOL[1], |
|
instantiation=cfg.NONLOCAL.INSTANTIATION, |
|
trans_func_name=cfg.RESNET.TRANS_FUNC, |
|
stride_1x1=cfg.RESNET.STRIDE_1X1, |
|
inplace_relu=cfg.RESNET.INPLACE_RELU, |
|
dilation=cfg.RESNET.SPATIAL_DILATIONS[1], |
|
norm_module=self.norm_module, |
|
) |
|
|
|
self.s4 = resnet_helper.ResStage( |
|
dim_in=[width_per_group * 8], |
|
dim_out=[width_per_group * 16], |
|
dim_inner=[dim_inner * 4], |
|
temp_kernel_sizes=temp_kernel[3], |
|
stride=cfg.RESNET.SPATIAL_STRIDES[2], |
|
num_blocks=[d4], |
|
num_groups=[num_groups], |
|
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2], |
|
nonlocal_inds=cfg.NONLOCAL.LOCATION[2], |
|
nonlocal_group=cfg.NONLOCAL.GROUP[2], |
|
nonlocal_pool=cfg.NONLOCAL.POOL[2], |
|
instantiation=cfg.NONLOCAL.INSTANTIATION, |
|
trans_func_name=cfg.RESNET.TRANS_FUNC, |
|
stride_1x1=cfg.RESNET.STRIDE_1X1, |
|
inplace_relu=cfg.RESNET.INPLACE_RELU, |
|
dilation=cfg.RESNET.SPATIAL_DILATIONS[2], |
|
norm_module=self.norm_module, |
|
) |
|
|
|
self.s5 = resnet_helper.ResStage( |
|
dim_in=[width_per_group * 16], |
|
dim_out=[width_per_group * 32], |
|
dim_inner=[dim_inner * 8], |
|
temp_kernel_sizes=temp_kernel[4], |
|
stride=cfg.RESNET.SPATIAL_STRIDES[3], |
|
num_blocks=[d5], |
|
num_groups=[num_groups], |
|
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3], |
|
nonlocal_inds=cfg.NONLOCAL.LOCATION[3], |
|
nonlocal_group=cfg.NONLOCAL.GROUP[3], |
|
nonlocal_pool=cfg.NONLOCAL.POOL[3], |
|
instantiation=cfg.NONLOCAL.INSTANTIATION, |
|
trans_func_name=cfg.RESNET.TRANS_FUNC, |
|
stride_1x1=cfg.RESNET.STRIDE_1X1, |
|
inplace_relu=cfg.RESNET.INPLACE_RELU, |
|
dilation=cfg.RESNET.SPATIAL_DILATIONS[3], |
|
norm_module=self.norm_module, |
|
) |
|
|
|
if self.enable_detection: |
|
raise NotImplementedError |
|
else: |
|
self.head = head_helper.ResNetBasicHead( |
|
dim_in=[width_per_group * 32], |
|
num_classes=cfg.MODEL.NUM_CLASSES, |
|
pool_size=[None,None] |
|
if cfg.MULTIGRID.SHORT_CYCLE |
|
else [ |
|
None |
|
], |
|
dropout_rate=cfg.MODEL.DROPOUT_RATE, |
|
act_func=cfg.MODEL.HEAD_ACT, |
|
) |
|
|
|
def forward(self, x, freeze_backbone=False):
    """
    Run the five ResNet stages and then the head.

    Args:
        x (list[Tensor]): input handed to the stem ``s1``.
        freeze_backbone (bool): when True, every tensor produced by ``s5``
            is detached before reaching the head, so gradients stop there.
    Returns:
        Whatever ``self.head`` returns for the (possibly detached) features.
    """
    assert isinstance(freeze_backbone, bool)
    feats = x
    for stage in (self.s1, self.s2, self.s3, self.s4, self.s5):
        feats = stage(feats)
    if freeze_backbone:
        # Cut the autograd graph: only the head receives gradients.
        feats = [tensor.detach() for tensor in feats]
    return self.head(feats)
|
|
|
|
|
|
|
import torch.nn.functional as F |
|
from .unet_helper import DecoderBlock,LightDecoderBlock,ResDecoderBlock |
|
|
|
|
|
@MODEL_REGISTRY.register()
class ResUNet(nn.Module):
    """
    U-Net style video model on a single-pathway ResNet encoder.

    The encoder is the stem ``s1`` plus ResNet stages ``s2``-``s4`` (there
    is no ``s5``). For every label in ``self.labels`` an independent decoder
    branch (``t4_<label>``, ``t3_<label>``, ``conv1x1_<label>``) upsamples
    the fused s3/s4 features, using skip connections from ``s2`` and ``s1``.
    Each branch predicts a one-channel sigmoid map at the resolution of the
    ``s1`` output. A 1->1 linear head with sigmoid turns the spatially
    averaged maps into one score per (label, time step).

    Encoder references:
    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf

    Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
    "Non-local neural networks."
    https://arxiv.org/pdf/1711.07971.pdf
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        super(ResUNet, self).__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        self.enable_jitter = cfg.JITTER.ENABLE
        self.num_pathways = 1
        # Train and test crops must match.
        assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
        self.image_size = cfg.DATA.TRAIN_CROP_SIZE
        self.clip_size = cfg.DATA.NUM_FRAMES
        self._construct_network(cfg)
        init_helper.init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _make_res_stage(self, cfg, idx, dim_in, dim_out, dim_inner,
                        temp_kernel_sizes, num_blocks, num_groups):
        """
        Build one single-pathway ResNet stage.

        Args:
            cfg (CfgNode): model building configs.
            idx (int): index into the per-stage cfg lists (0 -> s2,
                1 -> s3, 2 -> s4).
            dim_in, dim_out, dim_inner (int): channel widths of the stage.
            temp_kernel_sizes (list): temporal kernel sizes for the stage.
            num_blocks (int): number of residual blocks.
            num_groups (int): ``num_groups`` forwarded to ResStage.
        Returns:
            resnet_helper.ResStage: the constructed stage.
        """
        return resnet_helper.ResStage(
            dim_in=[dim_in],
            dim_out=[dim_out],
            dim_inner=[dim_inner],
            temp_kernel_sizes=temp_kernel_sizes,
            stride=cfg.RESNET.SPATIAL_STRIDES[idx],
            num_blocks=[num_blocks],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[idx],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[idx],
            nonlocal_group=cfg.NONLOCAL.GROUP[idx],
            nonlocal_pool=cfg.NONLOCAL.POOL[idx],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[idx],
            norm_module=self.norm_module,
        )

    def _construct_network(self, cfg):
        """
        Build the ResNet encoder (s1-s4) and the per-label decoder branches.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        assert cfg.MODEL.ARCH in _POOL1.keys()
        pool_size = _POOL1[cfg.MODEL.ARCH]
        self.cfg = cfg
        assert len({len(pool_size), self.num_pathways}) == 1
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()

        # The encoder stops after s4, so the s5 depth is not needed.
        (d2, d3, d4, _) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]

        num_groups = cfg.RESNET.NUM_GROUPS
        width_per_group = cfg.RESNET.WIDTH_PER_GROUP
        dim_inner = num_groups * width_per_group
        temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]

        # Modules are created in the same order as before so that random
        # initialisation and registration order are unchanged.
        self.s1 = stem_helper.VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[width_per_group],
            kernel=[temp_kernel[0][0] + [7, 7]],
            stride=[[1, 2, 2]],
            padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
            norm_module=self.norm_module,
        )
        self.s2 = self._make_res_stage(
            cfg, 0, dim_in=width_per_group, dim_out=width_per_group * 4,
            dim_inner=dim_inner, temp_kernel_sizes=temp_kernel[1],
            num_blocks=d2, num_groups=num_groups,
        )
        # Registered but never applied by forward(): no temporal pooling
        # happens between s2 and s3 in this model.
        for pathway in range(self.num_pathways):
            pool = nn.MaxPool3d(
                kernel_size=pool_size[pathway],
                stride=pool_size[pathway],
                padding=[0, 0, 0],
            )
            self.add_module("pathway{}_pool".format(pathway), pool)
        self.s3 = self._make_res_stage(
            cfg, 1, dim_in=width_per_group * 4, dim_out=width_per_group * 8,
            dim_inner=dim_inner * 2, temp_kernel_sizes=temp_kernel[2],
            num_blocks=d3, num_groups=num_groups,
        )
        self.s4 = self._make_res_stage(
            cfg, 2, dim_in=width_per_group * 8, dim_out=width_per_group * 16,
            dim_inner=dim_inner * 4, temp_kernel_sizes=temp_kernel[3],
            num_blocks=d4, num_groups=num_groups,
        )

        self.labels = ["rotate", "light"]
        # NOTE(review): DecoderBlock arguments are read here as
        # (upsampled channels, skip channels, output channels), which matches
        # the concatenations in forward_branch. Confirm against unet_helper.
        # The t3 output width and the conv1x1 input width share one value.
        # Previously conv1x1 assumed width_per_group * 4, which only equals
        # the hard-coded 256 when WIDTH_PER_GROUP == 64.
        t3_out = 256
        self.dual_define(
            "t4",
            self.labels,
            DecoderBlock(
                width_per_group * 16, width_per_group * 8, width_per_group * 8
            ),
        )
        self.dual_define(
            "t3",
            self.labels,
            DecoderBlock(width_per_group * 8, width_per_group * 4, t3_out),
        )
        self.dual_define(
            "conv1x1",
            self.labels,
            nn.Sequential(
                # t3 output concatenated with the s1 features.
                nn.Conv3d(
                    t3_out + width_per_group, 1,
                    kernel_size=(1, 1, 1), stride=1, padding=0,
                ),
                nn.Sigmoid(),
            ),
        )
        self.linear = nn.Sequential(nn.Linear(1, 1), nn.Sigmoid())

    def forward_plus(self, x, y, net):
        """Return ``[net(x)[0] + y[0]]``; not used by forward()."""
        return [net(x)[0] + y[0]]

    def dual_define(self, name, labels, net):
        """
        Register a separate deep copy of ``net`` as ``<name>_<label>`` for
        every label, so the branches do not share parameters.
        """
        for label in labels:
            self.add_module(f"{name}_{label}", copy.deepcopy(net))

    def upsample(self, x, dims=("space",)):
        """
        Nearest-neighbour 2x resize of ``x[0]``.

        Args:
            x (list[Tensor]): tensors whose dims 2-4 are (T, H, W); only
                ``x[0]`` is used.
            dims (iterable of str): ``"space"`` doubles H and W, ``"time"``
                doubles T. Defaults to spatial-only upsampling.
        Returns:
            list[Tensor]: one-element list with the resized tensor.
        """
        t, h, w = x[0].shape[2:5]
        if "space" in dims:
            h, w = 2 * h, 2 * w
        if "time" in dims:
            t = 2 * t
        return [F.interpolate(x[0], (t, h, w))]

    def concat(self, x, y):
        """Concatenate ``x[0]`` and ``y[0]`` along dim 1 (channels)."""
        return [torch.cat([x[0], y[0]], 1)]

    def forward(self, x, bboxes=None):
        """
        Args:
            x (list[Tensor]): single-pathway input passed to the stem.
            bboxes: unused; accepted for interface compatibility.
        Returns:
            tuple(Tensor, Tensor): ``(maps, scores)``. ``maps`` stacks the
            one-channel sigmoid map of every label branch along dim 1.
            ``scores`` has shape (N, len(labels) * T).
        """
        x1 = self.s1(x)
        x2 = self.s2(x1)
        x3 = self.s3(x2)
        x = self.s4(x3)
        # Bring s4 up to the s3 resolution and fuse it with the s3 skip.
        x = self.concat(x3, self.upsample(x))
        x = torch.cat(
            [self.forward_branch(x, x1, x2, label) for label in self.labels],
            1,
        )
        # NOTE(review): the x100 scale on the spatial mean looks hand-tuned
        # for the 1->1 linear head. Confirm before changing.
        out = x.mean([3, 4]).view(-1, 1) * 100
        out = self.linear(out)
        out = out.view(x.size(0), -1)
        return x, out

    def forward_branch(self, x, x1, x2, label):
        """
        Run the decoder branch registered for ``label``.

        Args:
            x (list[Tensor]): fused s3/s4 features.
            x1, x2 (list[Tensor]): s1 and s2 outputs used as skips.
            label (str): suffix selecting the branch modules.
        Returns:
            Tensor: output of ``conv1x1_<label>``.
        """
        t4 = getattr(self, f"t4_{label}")
        t3 = getattr(self, f"t3_{label}")
        conv1x1 = getattr(self, f"conv1x1_{label}")
        y = self.concat(x2, self.upsample([t4(x[0])]))
        y = self.concat(x1, [t3(y[0])])
        return conv1x1(y[0])
|
|
|
|
|
|
|
@MODEL_REGISTRY.register()
class ResUNetLight(nn.Module):
    """
    Lightweight U-Net style video model on a single-pathway ResNet encoder.

    Same layout as ``ResUNet`` (encoder ``s1``-``s4``, one decoder branch
    per label in ``["rotate", "light"]``), but it uses ``LightDecoderBlock``
    with narrower decoder widths. ``forward`` can optionally freeze the
    encoder. Each branch ends in a 1x1x1 conv plus sigmoid. A 1->1 linear
    head with sigmoid scores the spatially averaged maps.

    Encoder references:
    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf

    Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
    "Non-local neural networks."
    https://arxiv.org/pdf/1711.07971.pdf
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        super(ResUNetLight, self).__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        self.enable_jitter = cfg.JITTER.ENABLE
        self.num_pathways = 1
        # Train and test crops must match.
        assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
        self.image_size = cfg.DATA.TRAIN_CROP_SIZE
        self.clip_size = cfg.DATA.NUM_FRAMES
        self._construct_network(cfg)
        init_helper.init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _make_res_stage(self, cfg, idx, dim_in, dim_out, dim_inner,
                        temp_kernel_sizes, num_blocks, num_groups):
        """
        Build one single-pathway ResNet stage.

        Args:
            cfg (CfgNode): model building configs.
            idx (int): index into the per-stage cfg lists (0 -> s2,
                1 -> s3, 2 -> s4).
            dim_in, dim_out, dim_inner (int): channel widths of the stage.
            temp_kernel_sizes (list): temporal kernel sizes for the stage.
            num_blocks (int): number of residual blocks.
            num_groups (int): ``num_groups`` forwarded to ResStage.
        Returns:
            resnet_helper.ResStage: the constructed stage.
        """
        return resnet_helper.ResStage(
            dim_in=[dim_in],
            dim_out=[dim_out],
            dim_inner=[dim_inner],
            temp_kernel_sizes=temp_kernel_sizes,
            stride=cfg.RESNET.SPATIAL_STRIDES[idx],
            num_blocks=[num_blocks],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[idx],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[idx],
            nonlocal_group=cfg.NONLOCAL.GROUP[idx],
            nonlocal_pool=cfg.NONLOCAL.POOL[idx],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[idx],
            norm_module=self.norm_module,
        )

    def _construct_network(self, cfg):
        """
        Build the ResNet encoder (s1-s4) and the per-label decoder branches.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        assert cfg.MODEL.ARCH in _POOL1.keys()
        pool_size = _POOL1[cfg.MODEL.ARCH]
        self.cfg = cfg
        assert len({len(pool_size), self.num_pathways}) == 1
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()

        # The encoder stops after s4, so the s5 depth is not needed.
        (d2, d3, d4, _) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]

        num_groups = cfg.RESNET.NUM_GROUPS
        width_per_group = cfg.RESNET.WIDTH_PER_GROUP
        dim_inner = num_groups * width_per_group
        temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]

        # Modules are created in the original order so that random
        # initialisation and registration order are unchanged.
        self.s1 = stem_helper.VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[width_per_group],
            kernel=[temp_kernel[0][0] + [7, 7]],
            stride=[[1, 2, 2]],
            padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
            norm_module=self.norm_module,
        )
        self.s2 = self._make_res_stage(
            cfg, 0, dim_in=width_per_group, dim_out=width_per_group * 4,
            dim_inner=dim_inner, temp_kernel_sizes=temp_kernel[1],
            num_blocks=d2, num_groups=num_groups,
        )
        # Registered but never applied by forward(): no temporal pooling
        # happens between s2 and s3 in this model.
        for pathway in range(self.num_pathways):
            pool = nn.MaxPool3d(
                kernel_size=pool_size[pathway],
                stride=pool_size[pathway],
                padding=[0, 0, 0],
            )
            self.add_module("pathway{}_pool".format(pathway), pool)
        self.s3 = self._make_res_stage(
            cfg, 1, dim_in=width_per_group * 4, dim_out=width_per_group * 8,
            dim_inner=dim_inner * 2, temp_kernel_sizes=temp_kernel[2],
            num_blocks=d3, num_groups=num_groups,
        )
        self.s4 = self._make_res_stage(
            cfg, 2, dim_in=width_per_group * 8, dim_out=width_per_group * 16,
            dim_inner=dim_inner * 4, temp_kernel_sizes=temp_kernel[3],
            num_blocks=d4, num_groups=num_groups,
        )

        self.labels = ["rotate", "light"]
        # NOTE(review): LightDecoderBlock arguments are read here as
        # (upsampled channels, skip channels, output channels), matching the
        # concatenations in forward_branch. Confirm against unet_helper.
        self.dual_define(
            "t4",
            self.labels,
            LightDecoderBlock(
                width_per_group * 16, width_per_group * 8, width_per_group * 4
            ),
        )
        self.dual_define(
            "t3",
            self.labels,
            LightDecoderBlock(width_per_group * 4, width_per_group * 4, 128),
        )
        self.dual_define(
            "conv1x1",
            self.labels,
            nn.Sequential(
                # 128 t3 channels concatenated with the s1 features.
                nn.Conv3d(
                    128 + width_per_group, 1,
                    kernel_size=(1, 1, 1), stride=1, padding=0,
                ),
                nn.Sigmoid(),
            ),
        )
        self.linear = nn.Sequential(nn.Linear(1, 1), nn.Sigmoid())

    def forward_plus(self, x, y, net):
        """Return ``[net(x)[0] + y[0]]``; not used by forward()."""
        return [net(x)[0] + y[0]]

    def dual_define(self, name, labels, net):
        """
        Register a separate deep copy of ``net`` as ``<name>_<label>`` for
        every label, so the branches do not share parameters.
        """
        for label in labels:
            self.add_module(f"{name}_{label}", copy.deepcopy(net))

    def upsample(self, x, dims=("space",)):
        """
        Nearest-neighbour 2x resize of ``x[0]``.

        Args:
            x (list[Tensor]): tensors whose dims 2-4 are (T, H, W); only
                ``x[0]`` is used.
            dims (iterable of str): ``"space"`` doubles H and W, ``"time"``
                doubles T. Defaults to spatial-only upsampling.
        Returns:
            list[Tensor]: one-element list with the resized tensor.
        """
        t, h, w = x[0].shape[2:5]
        if "space" in dims:
            h, w = 2 * h, 2 * w
        if "time" in dims:
            t = 2 * t
        return [F.interpolate(x[0], (t, h, w))]

    def concat(self, x, y):
        """Concatenate ``x[0]`` and ``y[0]`` along dim 1 (channels)."""
        return [torch.cat([x[0], y[0]], 1)]

    def get_detach_var(self, x):
        """Return a list with every tensor of ``x`` detached from autograd."""
        return [t.detach() for t in x]

    def forward(self, x, freeze_backbone=False):
        """
        Args:
            x (list[Tensor]): single-pathway input passed to the stem.
            freeze_backbone (bool): if True, the encoder outputs are
                detached, so only the decoder branches and linear head
                receive gradients.
        Returns:
            tuple(Tensor, Tensor): ``(maps, scores)``. ``maps`` stacks the
            one-channel sigmoid map of every label branch along dim 1.
            ``scores`` has shape (N, len(labels) * T).
        """
        # Validate before running the encoder, whose norm layers may update
        # running statistics.
        assert isinstance(freeze_backbone, bool)
        x1 = self.s1(x)
        x2 = self.s2(x1)
        x3 = self.s3(x2)
        x = self.s4(x3)
        if freeze_backbone:
            x = self.get_detach_var(x)
            x1 = self.get_detach_var(x1)
            x2 = self.get_detach_var(x2)
            x3 = self.get_detach_var(x3)

        # Bring s4 up to the s3 resolution and fuse it with the s3 skip.
        x = self.concat(x3, self.upsample(x))
        x = torch.cat(
            [self.forward_branch(x, x1, x2, label) for label in self.labels],
            1,
        )
        # NOTE(review): the x100 scale on the spatial mean looks hand-tuned
        # for the 1->1 linear head. Confirm before changing.
        out = x.mean([3, 4]).view(-1, 1) * 100
        out = self.linear(out)
        out = out.view(x.size(0), -1)
        return x, out

    def forward_branch(self, x, x1, x2, label):
        """
        Run the decoder branch registered for ``label``.

        Args:
            x (list[Tensor]): fused s3/s4 features.
            x1, x2 (list[Tensor]): s1 and s2 outputs used as skips.
            label (str): suffix selecting the branch modules.
        Returns:
            Tensor: output of ``conv1x1_<label>``.
        """
        t4 = getattr(self, f"t4_{label}")
        t3 = getattr(self, f"t3_{label}")
        conv1x1 = getattr(self, f"conv1x1_{label}")
        y = self.concat(x2, self.upsample([t4(x[0])]))
        y = self.concat(x1, [t3(y[0])])
        return conv1x1(y[0])
|
|
|
|
|
|
|
@MODEL_REGISTRY.register()
class ResUNetLightFix(nn.Module):
    """
    Lightweight U-Net style video model with three label branches.

    Encoder: stem ``s1`` plus ResNet stages ``s2``-``s4``. Decoder: one
    ``LightDecoderBlock``-based branch per label in
    ``["rotate", "light", "skip"]``. Each branch ends in a
    conv-BN-ReLU-conv head producing one channel. The sigmoid is applied
    in ``forward`` to the stacked branch outputs, and again after the 1->1
    linear scoring head.

    Encoder references:
    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf

    Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
    "Non-local neural networks."
    https://arxiv.org/pdf/1711.07971.pdf
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        super(ResUNetLightFix, self).__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        self.enable_jitter = cfg.JITTER.ENABLE
        self.num_pathways = 1
        # Train and test crops must match.
        assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
        self.image_size = cfg.DATA.TRAIN_CROP_SIZE
        self.clip_size = cfg.DATA.NUM_FRAMES
        self._construct_network(cfg)
        init_helper.init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _make_res_stage(self, cfg, idx, dim_in, dim_out, dim_inner,
                        temp_kernel_sizes, num_blocks, num_groups):
        """
        Build one single-pathway ResNet stage.

        Args:
            cfg (CfgNode): model building configs.
            idx (int): index into the per-stage cfg lists (0 -> s2,
                1 -> s3, 2 -> s4).
            dim_in, dim_out, dim_inner (int): channel widths of the stage.
            temp_kernel_sizes (list): temporal kernel sizes for the stage.
            num_blocks (int): number of residual blocks.
            num_groups (int): ``num_groups`` forwarded to ResStage.
        Returns:
            resnet_helper.ResStage: the constructed stage.
        """
        return resnet_helper.ResStage(
            dim_in=[dim_in],
            dim_out=[dim_out],
            dim_inner=[dim_inner],
            temp_kernel_sizes=temp_kernel_sizes,
            stride=cfg.RESNET.SPATIAL_STRIDES[idx],
            num_blocks=[num_blocks],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[idx],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[idx],
            nonlocal_group=cfg.NONLOCAL.GROUP[idx],
            nonlocal_pool=cfg.NONLOCAL.POOL[idx],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[idx],
            norm_module=self.norm_module,
        )

    def _construct_network(self, cfg):
        """
        Build the ResNet encoder (s1-s4) and the per-label decoder branches.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        assert cfg.MODEL.ARCH in _POOL1.keys()
        pool_size = _POOL1[cfg.MODEL.ARCH]
        self.cfg = cfg
        assert len({len(pool_size), self.num_pathways}) == 1
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()

        # The encoder stops after s4, so the s5 depth is not needed.
        (d2, d3, d4, _) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]

        num_groups = cfg.RESNET.NUM_GROUPS
        width_per_group = cfg.RESNET.WIDTH_PER_GROUP
        dim_inner = num_groups * width_per_group
        temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]

        # Modules are created in the original order so that random
        # initialisation and registration order are unchanged.
        self.s1 = stem_helper.VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[width_per_group],
            kernel=[temp_kernel[0][0] + [7, 7]],
            stride=[[1, 2, 2]],
            padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
            norm_module=self.norm_module,
        )
        self.s2 = self._make_res_stage(
            cfg, 0, dim_in=width_per_group, dim_out=width_per_group * 4,
            dim_inner=dim_inner, temp_kernel_sizes=temp_kernel[1],
            num_blocks=d2, num_groups=num_groups,
        )
        # Registered but never applied by forward(): no temporal pooling
        # happens between s2 and s3 in this model.
        for pathway in range(self.num_pathways):
            pool = nn.MaxPool3d(
                kernel_size=pool_size[pathway],
                stride=pool_size[pathway],
                padding=[0, 0, 0],
            )
            self.add_module("pathway{}_pool".format(pathway), pool)
        self.s3 = self._make_res_stage(
            cfg, 1, dim_in=width_per_group * 4, dim_out=width_per_group * 8,
            dim_inner=dim_inner * 2, temp_kernel_sizes=temp_kernel[2],
            num_blocks=d3, num_groups=num_groups,
        )
        self.s4 = self._make_res_stage(
            cfg, 2, dim_in=width_per_group * 8, dim_out=width_per_group * 16,
            dim_inner=dim_inner * 4, temp_kernel_sizes=temp_kernel[3],
            num_blocks=d4, num_groups=num_groups,
        )

        self.labels = ["rotate", "light", "skip"]
        # NOTE(review): LightDecoderBlock arguments are read here as
        # (upsampled channels, skip channels, output channels), matching the
        # concatenations in forward_branch. Confirm against unet_helper.
        self.dual_define(
            "t4",
            self.labels,
            LightDecoderBlock(
                width_per_group * 16, width_per_group * 8, width_per_group * 4
            ),
        )
        self.dual_define(
            "t3",
            self.labels,
            LightDecoderBlock(width_per_group * 4, width_per_group * 4, 128),
        )
        # Un-activated head; the sigmoid is applied in forward().
        self.dual_define(
            "conv1x1",
            self.labels,
            nn.Sequential(
                nn.Conv3d(
                    128 + width_per_group, 64,
                    kernel_size=(1, 1, 1), stride=1, padding=0,
                ),
                nn.BatchNorm3d(64),
                nn.ReLU(),
                nn.Conv3d(64, 1, kernel_size=(1, 1, 1), stride=1, padding=0),
            ),
        )
        self.linear = nn.Sequential(nn.Linear(1, 1))

    def forward_plus(self, x, y, net):
        """Return ``[net(x)[0] + y[0]]``; not used by forward()."""
        return [net(x)[0] + y[0]]

    def dual_define(self, name, labels, net):
        """
        Register a separate deep copy of ``net`` as ``<name>_<label>`` for
        every label, so the branches do not share parameters.
        """
        for label in labels:
            self.add_module(f"{name}_{label}", copy.deepcopy(net))

    def upsample(self, x, dims=("space",)):
        """
        Nearest-neighbour 2x resize of ``x[0]``.

        Args:
            x (list[Tensor]): tensors whose dims 2-4 are (T, H, W); only
                ``x[0]`` is used.
            dims (iterable of str): ``"space"`` doubles H and W, ``"time"``
                doubles T. Defaults to spatial-only upsampling.
        Returns:
            list[Tensor]: one-element list with the resized tensor.
        """
        t, h, w = x[0].shape[2:5]
        if "space" in dims:
            h, w = 2 * h, 2 * w
        if "time" in dims:
            t = 2 * t
        return [F.interpolate(x[0], (t, h, w))]

    def concat(self, x, y):
        """Concatenate ``x[0]`` and ``y[0]`` along dim 1 (channels)."""
        return [torch.cat([x[0], y[0]], 1)]

    def get_detach_var(self, x):
        """Return a list with every tensor of ``x`` detached from autograd."""
        return [t.detach() for t in x]

    def forward(self, x, freeze_backbone=False):
        """
        Args:
            x (list[Tensor]): single-pathway input passed to the stem.
            freeze_backbone (bool): if True, the encoder outputs are
                detached, so only the decoder branches and linear head
                receive gradients.
        Returns:
            tuple(Tensor, Tensor): ``(maps, scores)``. ``maps`` stacks the
            sigmoid map of every label branch along dim 1. ``scores`` has
            shape (N, len(labels) * T) and is also passed through a sigmoid.
        """
        # Validate before running the encoder, whose norm layers may update
        # running statistics.
        assert isinstance(freeze_backbone, bool)
        x1 = self.s1(x)
        x2 = self.s2(x1)
        x3 = self.s3(x2)
        x = self.s4(x3)
        if freeze_backbone:
            x = self.get_detach_var(x)
            x1 = self.get_detach_var(x1)
            x2 = self.get_detach_var(x2)
            x3 = self.get_detach_var(x3)

        # Bring s4 up to the s3 resolution and fuse it with the s3 skip.
        x = self.concat(x3, self.upsample(x))
        x = torch.cat(
            [self.forward_branch(x, x1, x2, label) for label in self.labels],
            1,
        )
        x = torch.sigmoid(x)
        # NOTE(review): the x100 scale on the spatial mean looks hand-tuned
        # for the 1->1 linear head. Confirm before changing.
        out = x.mean([3, 4]).view(-1, 1) * 100
        out = self.linear(out)
        out = out.view(x.size(0), -1)
        out = torch.sigmoid(out)
        return x, out

    def forward_branch(self, x, x1, x2, label):
        """
        Run the decoder branch registered for ``label``.

        Args:
            x (list[Tensor]): fused s3/s4 features.
            x1, x2 (list[Tensor]): s1 and s2 outputs used as skips.
            label (str): suffix selecting the branch modules.
        Returns:
            Tensor: un-activated output of ``conv1x1_<label>``.
        """
        t4 = getattr(self, f"t4_{label}")
        t3 = getattr(self, f"t3_{label}")
        conv1x1 = getattr(self, f"conv1x1_{label}")
        y = self.concat(x2, self.upsample([t4(x[0])]))
        y = self.concat(x1, [t3(y[0])])
        return conv1x1(y[0])
|
|
|
|
|
|
|
@MODEL_REGISTRY.register()
class ResUNetContinus(nn.Module):
    """
    Lightweight U-Net style video model with a single ``"all"`` branch.

    Encoder: stem ``s1`` plus ResNet stages ``s2``-``s4``. Decoder: one
    ``LightDecoderBlock``-based branch that ends in a conv-BN-ReLU-conv head
    producing one channel. The sigmoid is applied in ``forward`` to the
    branch output, and again after the 1->1 linear scoring head.

    Encoder references:
    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf

    Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
    "Non-local neural networks."
    https://arxiv.org/pdf/1711.07971.pdf
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        super(ResUNetContinus, self).__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        self.enable_jitter = cfg.JITTER.ENABLE
        self.num_pathways = 1
        # Train and test crops must match.
        assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
        self.image_size = cfg.DATA.TRAIN_CROP_SIZE
        self.clip_size = cfg.DATA.NUM_FRAMES
        self._construct_network(cfg)
        init_helper.init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _make_res_stage(self, cfg, idx, dim_in, dim_out, dim_inner,
                        temp_kernel_sizes, num_blocks, num_groups):
        """
        Build one single-pathway ResNet stage.

        Args:
            cfg (CfgNode): model building configs.
            idx (int): index into the per-stage cfg lists (0 -> s2,
                1 -> s3, 2 -> s4).
            dim_in, dim_out, dim_inner (int): channel widths of the stage.
            temp_kernel_sizes (list): temporal kernel sizes for the stage.
            num_blocks (int): number of residual blocks.
            num_groups (int): ``num_groups`` forwarded to ResStage.
        Returns:
            resnet_helper.ResStage: the constructed stage.
        """
        return resnet_helper.ResStage(
            dim_in=[dim_in],
            dim_out=[dim_out],
            dim_inner=[dim_inner],
            temp_kernel_sizes=temp_kernel_sizes,
            stride=cfg.RESNET.SPATIAL_STRIDES[idx],
            num_blocks=[num_blocks],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[idx],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[idx],
            nonlocal_group=cfg.NONLOCAL.GROUP[idx],
            nonlocal_pool=cfg.NONLOCAL.POOL[idx],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[idx],
            norm_module=self.norm_module,
        )

    def _construct_network(self, cfg):
        """
        Build the ResNet encoder (s1-s4) and the decoder branch.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        assert cfg.MODEL.ARCH in _POOL1.keys()
        pool_size = _POOL1[cfg.MODEL.ARCH]
        self.cfg = cfg
        assert len({len(pool_size), self.num_pathways}) == 1
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()

        # The encoder stops after s4, so the s5 depth is not needed.
        (d2, d3, d4, _) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]

        num_groups = cfg.RESNET.NUM_GROUPS
        width_per_group = cfg.RESNET.WIDTH_PER_GROUP
        dim_inner = num_groups * width_per_group
        temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]

        # Modules are created in the original order so that random
        # initialisation and registration order are unchanged.
        self.s1 = stem_helper.VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[width_per_group],
            kernel=[temp_kernel[0][0] + [7, 7]],
            stride=[[1, 2, 2]],
            padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
            norm_module=self.norm_module,
        )
        self.s2 = self._make_res_stage(
            cfg, 0, dim_in=width_per_group, dim_out=width_per_group * 4,
            dim_inner=dim_inner, temp_kernel_sizes=temp_kernel[1],
            num_blocks=d2, num_groups=num_groups,
        )
        # Registered but never applied by forward(): no temporal pooling
        # happens between s2 and s3 in this model.
        for pathway in range(self.num_pathways):
            pool = nn.MaxPool3d(
                kernel_size=pool_size[pathway],
                stride=pool_size[pathway],
                padding=[0, 0, 0],
            )
            self.add_module("pathway{}_pool".format(pathway), pool)
        self.s3 = self._make_res_stage(
            cfg, 1, dim_in=width_per_group * 4, dim_out=width_per_group * 8,
            dim_inner=dim_inner * 2, temp_kernel_sizes=temp_kernel[2],
            num_blocks=d3, num_groups=num_groups,
        )
        self.s4 = self._make_res_stage(
            cfg, 2, dim_in=width_per_group * 8, dim_out=width_per_group * 16,
            dim_inner=dim_inner * 4, temp_kernel_sizes=temp_kernel[3],
            num_blocks=d4, num_groups=num_groups,
        )

        self.labels = ["all"]
        # NOTE(review): LightDecoderBlock arguments are read here as
        # (upsampled channels, skip channels, output channels), matching the
        # concatenations in forward_branch. Confirm against unet_helper.
        self.dual_define(
            "t4",
            self.labels,
            LightDecoderBlock(
                width_per_group * 16, width_per_group * 8, width_per_group * 4
            ),
        )
        self.dual_define(
            "t3",
            self.labels,
            LightDecoderBlock(width_per_group * 4, width_per_group * 4, 128),
        )
        # Un-activated head; the sigmoid is applied in forward().
        self.dual_define(
            "conv1x1",
            self.labels,
            nn.Sequential(
                nn.Conv3d(
                    128 + width_per_group, 64,
                    kernel_size=(1, 1, 1), stride=1, padding=0,
                ),
                nn.BatchNorm3d(64),
                nn.ReLU(),
                nn.Conv3d(64, 1, kernel_size=(1, 1, 1), stride=1, padding=0),
            ),
        )
        self.linear = nn.Sequential(nn.Linear(1, 1))

    def forward_plus(self, x, y, net):
        """Return ``[net(x)[0] + y[0]]``; not used by forward()."""
        return [net(x)[0] + y[0]]

    def dual_define(self, name, labels, net):
        """
        Register a separate deep copy of ``net`` as ``<name>_<label>`` for
        every label, so the branches do not share parameters.
        """
        for label in labels:
            self.add_module(f"{name}_{label}", copy.deepcopy(net))

    def upsample(self, x, dims=("space",)):
        """
        Nearest-neighbour 2x resize of ``x[0]``.

        Args:
            x (list[Tensor]): tensors whose dims 2-4 are (T, H, W); only
                ``x[0]`` is used.
            dims (iterable of str): ``"space"`` doubles H and W, ``"time"``
                doubles T. Defaults to spatial-only upsampling.
        Returns:
            list[Tensor]: one-element list with the resized tensor.
        """
        t, h, w = x[0].shape[2:5]
        if "space" in dims:
            h, w = 2 * h, 2 * w
        if "time" in dims:
            t = 2 * t
        return [F.interpolate(x[0], (t, h, w))]

    def concat(self, x, y):
        """Concatenate ``x[0]`` and ``y[0]`` along dim 1 (channels)."""
        return [torch.cat([x[0], y[0]], 1)]

    def get_detach_var(self, x):
        """Return a list with every tensor of ``x`` detached from autograd."""
        return [t.detach() for t in x]

    def forward(self, x, freeze_backbone=False):
        """
        Args:
            x (list[Tensor]): single-pathway input passed to the stem.
            freeze_backbone (bool): if True, the encoder outputs are
                detached, so only the decoder branch and linear head
                receive gradients.
        Returns:
            tuple(Tensor, Tensor): ``(maps, scores)``. ``maps`` is the
            sigmoid map of the branch, with labels stacked along dim 1.
            ``scores`` has shape (N, len(labels) * T) and is also passed
            through a sigmoid.
        """
        # Validate before running the encoder, whose norm layers may update
        # running statistics.
        assert isinstance(freeze_backbone, bool)
        x1 = self.s1(x)
        x2 = self.s2(x1)
        x3 = self.s3(x2)
        x = self.s4(x3)
        if freeze_backbone:
            x = self.get_detach_var(x)
            x1 = self.get_detach_var(x1)
            x2 = self.get_detach_var(x2)
            x3 = self.get_detach_var(x3)

        # Bring s4 up to the s3 resolution and fuse it with the s3 skip.
        x = self.concat(x3, self.upsample(x))
        x = torch.cat(
            [self.forward_branch(x, x1, x2, label) for label in self.labels],
            1,
        )
        x = torch.sigmoid(x)
        # NOTE(review): the x100 scale on the spatial mean looks hand-tuned
        # for the 1->1 linear head. Confirm before changing.
        out = x.mean([3, 4]).view(-1, 1) * 100
        out = self.linear(out)
        out = out.view(x.size(0), -1)
        out = torch.sigmoid(out)
        return x, out

    def forward_branch(self, x, x1, x2, label):
        """
        Run the decoder branch registered for ``label``.

        Args:
            x (list[Tensor]): fused s3/s4 features.
            x1, x2 (list[Tensor]): s1 and s2 outputs used as skips.
            label (str): suffix selecting the branch modules.
        Returns:
            Tensor: un-activated output of ``conv1x1_<label>``.
        """
        t4 = getattr(self, f"t4_{label}")
        t3 = getattr(self, f"t3_{label}")
        conv1x1 = getattr(self, f"conv1x1_{label}")
        y = self.concat(x2, self.upsample([t4(x[0])]))
        y = self.concat(x1, [t3(y[0])])
        return conv1x1(y[0])
|
|
|
|
|
|
|
|
|
@MODEL_REGISTRY.register()
class ResUNetCommon(nn.Module):
    """
    Lightweight U-Net style video model with config-driven label branches.

    Encoder: stem ``s1`` plus ResNet stages ``s2``-``s4``. Decoder: one
    ``LightDecoderBlock``-based branch per label in ``cfg.RESNET.LABELS``.
    Each branch ends in a conv-BN-ReLU-conv head producing one channel.
    ``forward`` returns sigmoid maps (regression output) plus the raw output
    of a 1->2 linear layer applied to the spatially averaged maps
    (classification output).

    Encoder references:
    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf

    Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
    "Non-local neural networks."
    https://arxiv.org/pdf/1711.07971.pdf
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        super(ResUNetCommon, self).__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        self.enable_jitter = cfg.JITTER.ENABLE
        self.num_pathways = 1
        # Train and test crops must match.
        assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
        self.image_size = cfg.DATA.TRAIN_CROP_SIZE
        self.clip_size = cfg.DATA.NUM_FRAMES
        self._construct_network(cfg)
        init_helper.init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _make_res_stage(self, cfg, idx, dim_in, dim_out, dim_inner,
                        temp_kernel_sizes, num_blocks, num_groups):
        """
        Build one single-pathway ResNet stage.

        Args:
            cfg (CfgNode): model building configs.
            idx (int): index into the per-stage cfg lists (0 -> s2,
                1 -> s3, 2 -> s4).
            dim_in, dim_out, dim_inner (int): channel widths of the stage.
            temp_kernel_sizes (list): temporal kernel sizes for the stage.
            num_blocks (int): number of residual blocks.
            num_groups (int): ``num_groups`` forwarded to ResStage.
        Returns:
            resnet_helper.ResStage: the constructed stage.
        """
        return resnet_helper.ResStage(
            dim_in=[dim_in],
            dim_out=[dim_out],
            dim_inner=[dim_inner],
            temp_kernel_sizes=temp_kernel_sizes,
            stride=cfg.RESNET.SPATIAL_STRIDES[idx],
            num_blocks=[num_blocks],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[idx],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[idx],
            nonlocal_group=cfg.NONLOCAL.GROUP[idx],
            nonlocal_pool=cfg.NONLOCAL.POOL[idx],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[idx],
            norm_module=self.norm_module,
        )

    def _construct_network(self, cfg):
        """
        Build the ResNet encoder (s1-s4) and the per-label decoder branches.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        assert cfg.MODEL.ARCH in _POOL1.keys()
        pool_size = _POOL1[cfg.MODEL.ARCH]
        self.cfg = cfg
        assert len({len(pool_size), self.num_pathways}) == 1
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()

        # The encoder stops after s4, so the s5 depth is not needed.
        (d2, d3, d4, _) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]

        num_groups = cfg.RESNET.NUM_GROUPS
        width_per_group = cfg.RESNET.WIDTH_PER_GROUP
        dim_inner = num_groups * width_per_group
        temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]

        # Modules are created in the original order so that random
        # initialisation and registration order are unchanged.
        self.s1 = stem_helper.VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[width_per_group],
            kernel=[temp_kernel[0][0] + [7, 7]],
            stride=[[1, 2, 2]],
            padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
            norm_module=self.norm_module,
        )
        self.s2 = self._make_res_stage(
            cfg, 0, dim_in=width_per_group, dim_out=width_per_group * 4,
            dim_inner=dim_inner, temp_kernel_sizes=temp_kernel[1],
            num_blocks=d2, num_groups=num_groups,
        )
        # Registered but never applied by forward(): no temporal pooling
        # happens between s2 and s3 in this model.
        for pathway in range(self.num_pathways):
            pool = nn.MaxPool3d(
                kernel_size=pool_size[pathway],
                stride=pool_size[pathway],
                padding=[0, 0, 0],
            )
            self.add_module("pathway{}_pool".format(pathway), pool)
        self.s3 = self._make_res_stage(
            cfg, 1, dim_in=width_per_group * 4, dim_out=width_per_group * 8,
            dim_inner=dim_inner * 2, temp_kernel_sizes=temp_kernel[2],
            num_blocks=d3, num_groups=num_groups,
        )
        self.s4 = self._make_res_stage(
            cfg, 2, dim_in=width_per_group * 8, dim_out=width_per_group * 16,
            dim_inner=dim_inner * 4, temp_kernel_sizes=temp_kernel[3],
            num_blocks=d4, num_groups=num_groups,
        )

        # Branch names come from the config; each label becomes a module
        # name suffix, e.g. ``t4_<label>``.
        self.labels = cfg.RESNET.LABELS
        # NOTE(review): LightDecoderBlock arguments are read here as
        # (upsampled channels, skip channels, output channels), matching the
        # concatenations in forward_branch. Confirm against unet_helper.
        self.dual_define(
            "t4",
            self.labels,
            LightDecoderBlock(
                width_per_group * 16, width_per_group * 8, width_per_group * 4
            ),
        )
        self.dual_define(
            "t3",
            self.labels,
            LightDecoderBlock(width_per_group * 4, width_per_group * 4, 128),
        )
        # Un-activated head; the sigmoid is applied in forward().
        self.dual_define(
            "conv1x1",
            self.labels,
            nn.Sequential(
                nn.Conv3d(
                    128 + width_per_group, 64,
                    kernel_size=(1, 1, 1), stride=1, padding=0,
                ),
                nn.BatchNorm3d(64),
                nn.ReLU(),
                nn.Conv3d(64, 1, kernel_size=(1, 1, 1), stride=1, padding=0),
            ),
        )
        self.linear = nn.Linear(1, 2)

    def forward_plus(self, x, y, net):
        """Return ``[net(x)[0] + y[0]]``; not used by forward()."""
        return [net(x)[0] + y[0]]

    def dual_define(self, name, labels, net):
        """
        Register a separate deep copy of ``net`` as ``<name>_<label>`` for
        every label, so the branches do not share parameters.
        """
        for label in labels:
            self.add_module(f"{name}_{label}", copy.deepcopy(net))

    def upsample(self, x, dims=("space",)):
        """
        Nearest-neighbour 2x resize of ``x[0]``.

        Args:
            x (list[Tensor]): tensors whose dims 2-4 are (T, H, W); only
                ``x[0]`` is used.
            dims (iterable of str): ``"space"`` doubles H and W, ``"time"``
                doubles T. Defaults to spatial-only upsampling.
        Returns:
            list[Tensor]: one-element list with the resized tensor.
        """
        t, h, w = x[0].shape[2:5]
        if "space" in dims:
            h, w = 2 * h, 2 * w
        if "time" in dims:
            t = 2 * t
        return [F.interpolate(x[0], (t, h, w))]

    def concat(self, x, y):
        """Concatenate ``x[0]`` and ``y[0]`` along dim 1 (channels)."""
        return [torch.cat([x[0], y[0]], 1)]

    def get_detach_var(self, x):
        """Return a list with every tensor of ``x`` detached from autograd."""
        return [t.detach() for t in x]

    def forward(self, x, freeze_backbone=False):
        """
        Args:
            x (list[Tensor]): single-pathway input passed to the stem.
            freeze_backbone (bool): if True, the encoder outputs are
                detached, so only the decoder branches and linear layer
                receive gradients.
        Returns:
            tuple(Tensor, Tensor): ``(reg_out, class_out)``. ``reg_out``
            stacks the sigmoid map of every label branch along dim 1.
            ``class_out`` is the raw output of the 1->2 linear layer,
            reshaped to (N, len(labels), 2 * T) with the two outputs of each
            time step adjacent.
        """
        # Validate before running the encoder, whose norm layers may update
        # running statistics.
        assert isinstance(freeze_backbone, bool)
        # Detach the inputs so no gradient is tracked back into them.
        x = self.get_detach_var(x)
        x1 = self.s1(x)
        x2 = self.s2(x1)
        x3 = self.s3(x2)
        feat = self.s4(x3)
        if freeze_backbone:
            feat = self.get_detach_var(feat)
            x1 = self.get_detach_var(x1)
            x2 = self.get_detach_var(x2)
            x3 = self.get_detach_var(x3)

        # Bring s4 up to the s3 resolution and fuse it with the s3 skip.
        feat = self.concat(x3, self.upsample(feat))
        reg_out = torch.cat(
            [
                self.forward_branch(feat, x1, x2, label)
                for label in self.labels
            ],
            1,
        )
        reg_out = torch.sigmoid(reg_out)
        # NOTE(review): the x100 scale on the spatial mean looks hand-tuned
        # for the linear head. Confirm before changing.
        class_out = reg_out.mean([3, 4]).view(-1, 1) * 100
        class_out = self.linear(class_out)
        class_out = class_out.view(reg_out.size(0), len(self.labels), -1)
        return reg_out, class_out

    def forward_branch(self, feat, x1, x2, label):
        """
        Run the decoder branch registered for ``label``.

        Args:
            feat (list[Tensor]): fused s3/s4 features.
            x1, x2 (list[Tensor]): s1 and s2 outputs used as skips.
            label (str): suffix selecting the branch modules.
        Returns:
            Tensor: un-activated output of ``conv1x1_<label>``.
        """
        t4 = getattr(self, f"t4_{label}")
        t3 = getattr(self, f"t3_{label}")
        conv1x1 = getattr(self, f"conv1x1_{label}")
        y = self.concat(x2, self.upsample([t4(feat[0])]))
        y = self.concat(x1, [t3(y[0])])
        return conv1x1(y[0])
|
|
|
|
|
|
|
|
|
@MODEL_REGISTRY.register()
class ResUNetCommon2(nn.Module):
    """
    Single-pathway ResNet encoder (stem + res2..res4) followed by one
    lightweight U-Net style decoder per label.

    The encoder follows the ResNet builder without lateral connections
    (C2D, I3D, Slow) from:

    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf

    Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
    "Non-local neural networks."
    https://arxiv.org/pdf/1711.07971.pdf

    Each entry of ``cfg.RESNET.LABELS`` gets its own decoder branch:
    ``t4_<label>`` -> ``t3_<label>`` -> ``conv1x1_<label>``, built from
    ``LightDecoderBlock``. Each branch produces a single-channel map.

    ``forward`` returns two things: the sigmoid of those maps, and a
    sigmoid per-label score computed from their mean over dims 3 and 4.
    """

    def __init__(self, cfg):
        """
        The `__init__` method of any subclass should also contain these
        arguments.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        super(ResUNetCommon2, self).__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        self.enable_jitter = cfg.JITTER.ENABLE
        self.num_pathways = 1
        # Only a single crop size (shared by train and test) is supported.
        assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
        self.image_size = cfg.DATA.TRAIN_CROP_SIZE
        self.clip_size = cfg.DATA.NUM_FRAMES
        self._construct_network(cfg)
        init_helper.init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _construct_network(self, cfg):
        """
        Builds the single-pathway ResNet encoder (s1-s4) and the per-label
        decoder branches.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        assert cfg.MODEL.ARCH in _POOL1.keys()
        pool_size = _POOL1[cfg.MODEL.ARCH]
        self.cfg = cfg
        assert len({len(pool_size), self.num_pathways}) == 1
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()

        # d5 is unused: the encoder stops at s4.
        (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]

        num_groups = cfg.RESNET.NUM_GROUPS
        width_per_group = cfg.RESNET.WIDTH_PER_GROUP
        dim_inner = num_groups * width_per_group

        temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]

        self.s1 = stem_helper.VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[width_per_group],
            kernel=[temp_kernel[0][0] + [7, 7]],
            stride=[[1, 2, 2]],
            padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
            norm_module=self.norm_module,
        )

        self.s2 = resnet_helper.ResStage(
            dim_in=[width_per_group],
            dim_out=[width_per_group * 4],
            dim_inner=[dim_inner],
            temp_kernel_sizes=temp_kernel[1],
            stride=cfg.RESNET.SPATIAL_STRIDES[0],
            num_blocks=[d2],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
            nonlocal_group=cfg.NONLOCAL.GROUP[0],
            nonlocal_pool=cfg.NONLOCAL.POOL[0],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
            norm_module=self.norm_module,
        )

        # Registered for parity with the plain ResNet builder; forward()
        # never applies this pool.
        for pathway in range(self.num_pathways):
            pool = nn.MaxPool3d(
                kernel_size=pool_size[pathway],
                stride=pool_size[pathway],
                padding=[0, 0, 0],
            )
            self.add_module("pathway{}_pool".format(pathway), pool)

        self.s3 = resnet_helper.ResStage(
            dim_in=[width_per_group * 4],
            dim_out=[width_per_group * 8],
            dim_inner=[dim_inner * 2],
            temp_kernel_sizes=temp_kernel[2],
            stride=cfg.RESNET.SPATIAL_STRIDES[1],
            num_blocks=[d3],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
            nonlocal_group=cfg.NONLOCAL.GROUP[1],
            nonlocal_pool=cfg.NONLOCAL.POOL[1],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
            norm_module=self.norm_module,
        )

        self.s4 = resnet_helper.ResStage(
            dim_in=[width_per_group * 8],
            dim_out=[width_per_group * 16],
            dim_inner=[dim_inner * 4],
            temp_kernel_sizes=temp_kernel[3],
            stride=cfg.RESNET.SPATIAL_STRIDES[2],
            num_blocks=[d4],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
            nonlocal_group=cfg.NONLOCAL.GROUP[2],
            nonlocal_pool=cfg.NONLOCAL.POOL[2],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
            norm_module=self.norm_module,
        )

        # Per-label decoders. The channel arithmetic below assumes that
        # LightDecoderBlock(a, b, out) consumes the concatenation of an
        # a-channel and a b-channel tensor and emits ``out`` channels. This
        # matches how forward()/forward_branch() wire them -- confirm against
        # LightDecoderBlock.
        self.labels = cfg.RESNET.LABELS
        t3_out = 128
        self.dual_define(
            "t4",
            self.labels,
            LightDecoderBlock(
                width_per_group * 16, width_per_group * 8, width_per_group * 4
            ),
        )
        self.dual_define(
            "t3",
            self.labels,
            LightDecoderBlock(width_per_group * 4, width_per_group * 4, t3_out),
        )
        # conv1x1 sees t3's output concatenated with the stem output x1
        # (width_per_group channels) and reduces it to a single channel.
        self.dual_define(
            "conv1x1",
            self.labels,
            nn.Sequential(
                nn.Conv3d(
                    t3_out + width_per_group,
                    64,
                    kernel_size=(1, 1, 1),
                    stride=1,
                    padding=0,
                ),
                nn.BatchNorm3d(64),
                nn.ReLU(),
                nn.Conv3d(64, 1, kernel_size=(1, 1, 1), stride=1, padding=0),
            ),
        )

        self.linear = nn.Linear(1, 1)

    def forward_plus(self, x, y, net):
        """Return ``[net(x)[0] + y[0]]`` (residual-style helper)."""
        return [net(x)[0] + y[0]]

    def dual_define(self, name, labels, net):
        """
        Register an independent deep copy of ``net`` for every label.

        Each copy is added as the submodule ``"<name>_<label>"``. The
        template ``net`` itself is not registered.
        """
        for label in labels:
            self.add_module(f"{name}_{label}", copy.deepcopy(net))

    def upsample(self, x, dims=("space",)):
        """
        Double the size of ``x[0]`` along the requested axes.

        Uses ``interpolate``'s default (nearest-neighbour) mode.

        Args:
            x (list[Tensor]): only ``x[0]`` is used. Its dims 2, 3 and 4 are
                treated as (T, H, W), so a 5-D tensor is assumed.
            dims (iterable of str): ``"space"`` doubles H and W; ``"time"``
                doubles T. The immutable tuple default cannot be mutated
                across calls.

        Returns:
            list[Tensor]: single-element list with the resized tensor.
        """
        t, h, w = x[0].shape[2:5]
        if "space" in dims:
            h, w = 2 * h, 2 * w
        if "time" in dims:
            t = 2 * t
        # ``nn.functional`` is used instead of ``F``: ``F`` is not among this
        # file's top-level imports, while ``nn`` is.
        return [nn.functional.interpolate(x[0], (t, h, w))]

    def concat(self, x, y):
        """Concatenate ``x[0]`` and ``y[0]`` along dim 1; returns a 1-list."""
        return [torch.cat([x[0], y[0]], 1)]

    def get_detach_var(self, x):
        """Return a list with every tensor in ``x`` detached from the graph."""
        return [t.detach() for t in x]

    def forward(self, x, freeze_backbone=False):
        """
        Run the encoder (s1-s4) and every per-label decoder branch.

        Args:
            x (list[Tensor]): input tensors (one pathway). They are detached
                first, so no gradient reaches the caller's tensors.
            freeze_backbone (bool): when True, all encoder outputs are
                detached, so s1-s4 receive no gradient.

        Returns:
            tuple(Tensor, Tensor):
                ``reg_out``: the sigmoid of the per-label branch outputs,
                concatenated along dim 1.
                ``class_out``: the sigmoid of ``self.linear`` applied to the
                x100-scaled mean of ``reg_out`` over dims 3 and 4, reshaped
                to ``(batch, len(self.labels), -1)``.
        """
        x = self.get_detach_var(x)
        x1 = self.s1(x)
        x2 = self.s2(x1)
        x3 = self.s3(x2)
        feat = self.s4(x3)

        assert isinstance(freeze_backbone, bool)
        if freeze_backbone:
            feat = self.get_detach_var(feat)
            x1 = self.get_detach_var(x1)
            x2 = self.get_detach_var(x2)
            x3 = self.get_detach_var(x3)

        # Upsample the deepest features and fuse them with the s3 skip.
        feat = self.upsample(feat)
        feat = self.concat(x3, feat)

        reg_out = [self.forward_branch(feat, x1, x2, label) for label in self.labels]
        reg_out = torch.sigmoid(torch.cat(reg_out, 1))

        class_out = reg_out.mean([3, 4]).view(-1, 1) * 100
        class_out = self.linear(class_out)
        class_out = class_out.view(reg_out.size(0), len(self.labels), -1)
        class_out = torch.sigmoid(class_out)
        return reg_out, class_out

    def forward_branch(self, feat, x1, x2, label):
        """
        Decode one label's single-channel output.

        Pipeline: ``t4_<label>`` -> upsample -> concat with ``x2`` ->
        ``t3_<label>`` -> concat with ``x1`` -> ``conv1x1_<label>``.
        """
        t4 = getattr(self, f"t4_{label}")
        feat = t4(feat[0])
        feat = self.upsample([feat])
        feat = self.concat(x2, feat)
        t3 = getattr(self, f"t3_{label}")
        feat = t3(feat[0])
        feat = self.concat(x1, [feat])
        conv1x1 = getattr(self, f"conv1x1_{label}")
        return conv1x1(feat[0])
|
|
|
|
|
|
|
@MODEL_REGISTRY.register()
class ResUNetStrong(nn.Module):
    """
    Single-pathway ResNet encoder (stem + res2..res4) followed by one
    residual U-Net style decoder per label.

    The encoder follows the ResNet builder without lateral connections
    (C2D, I3D, Slow) from:

    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf

    Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
    "Non-local neural networks."
    https://arxiv.org/pdf/1711.07971.pdf

    Each entry of ``cfg.RESNET.LABELS`` gets its own decoder branch:
    ``t4_<label>`` -> ``t3_<label>`` -> ``conv1x1_<label>``, built from
    ``ResDecoderBlock``. Each branch produces a single-channel map.

    ``forward`` returns two things: the sigmoid of those maps, and a
    sigmoid per-label score computed from their mean over dims 3 and 4.
    """

    def __init__(self, cfg):
        """
        The `__init__` method of any subclass should also contain these
        arguments.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        super(ResUNetStrong, self).__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        self.enable_jitter = cfg.JITTER.ENABLE
        self.num_pathways = 1
        # Only a single crop size (shared by train and test) is supported.
        assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
        self.image_size = cfg.DATA.TRAIN_CROP_SIZE
        self.clip_size = cfg.DATA.NUM_FRAMES
        self._construct_network(cfg)
        init_helper.init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _construct_network(self, cfg):
        """
        Builds the single-pathway ResNet encoder (s1-s4) and the per-label
        decoder branches.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        assert cfg.MODEL.ARCH in _POOL1.keys()
        pool_size = _POOL1[cfg.MODEL.ARCH]
        self.cfg = cfg
        assert len({len(pool_size), self.num_pathways}) == 1
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()

        # d5 is unused: the encoder stops at s4.
        (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]

        num_groups = cfg.RESNET.NUM_GROUPS
        width_per_group = cfg.RESNET.WIDTH_PER_GROUP
        dim_inner = num_groups * width_per_group

        temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]

        self.s1 = stem_helper.VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[width_per_group],
            kernel=[temp_kernel[0][0] + [7, 7]],
            stride=[[1, 2, 2]],
            padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
            norm_module=self.norm_module,
        )

        self.s2 = resnet_helper.ResStage(
            dim_in=[width_per_group],
            dim_out=[width_per_group * 4],
            dim_inner=[dim_inner],
            temp_kernel_sizes=temp_kernel[1],
            stride=cfg.RESNET.SPATIAL_STRIDES[0],
            num_blocks=[d2],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
            nonlocal_group=cfg.NONLOCAL.GROUP[0],
            nonlocal_pool=cfg.NONLOCAL.POOL[0],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
            norm_module=self.norm_module,
        )

        # Registered for parity with the plain ResNet builder; forward()
        # never applies this pool.
        for pathway in range(self.num_pathways):
            pool = nn.MaxPool3d(
                kernel_size=pool_size[pathway],
                stride=pool_size[pathway],
                padding=[0, 0, 0],
            )
            self.add_module("pathway{}_pool".format(pathway), pool)

        self.s3 = resnet_helper.ResStage(
            dim_in=[width_per_group * 4],
            dim_out=[width_per_group * 8],
            dim_inner=[dim_inner * 2],
            temp_kernel_sizes=temp_kernel[2],
            stride=cfg.RESNET.SPATIAL_STRIDES[1],
            num_blocks=[d3],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
            nonlocal_group=cfg.NONLOCAL.GROUP[1],
            nonlocal_pool=cfg.NONLOCAL.POOL[1],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
            norm_module=self.norm_module,
        )

        self.s4 = resnet_helper.ResStage(
            dim_in=[width_per_group * 8],
            dim_out=[width_per_group * 16],
            dim_inner=[dim_inner * 4],
            temp_kernel_sizes=temp_kernel[3],
            stride=cfg.RESNET.SPATIAL_STRIDES[2],
            num_blocks=[d4],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
            nonlocal_group=cfg.NONLOCAL.GROUP[2],
            nonlocal_pool=cfg.NONLOCAL.POOL[2],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
            norm_module=self.norm_module,
        )

        # Per-label decoders. The channel arithmetic below assumes that
        # ResDecoderBlock(a, b, out) consumes the concatenation of an
        # a-channel and a b-channel tensor and emits ``out`` channels. This
        # matches how forward()/forward_branch() wire them -- confirm against
        # ResDecoderBlock.
        self.labels = cfg.RESNET.LABELS
        t3_out = 256
        self.dual_define(
            "t4",
            self.labels,
            ResDecoderBlock(
                width_per_group * 16, width_per_group * 8, width_per_group * 8
            ),
        )
        self.dual_define(
            "t3",
            self.labels,
            ResDecoderBlock(width_per_group * 8, width_per_group * 4, t3_out),
        )
        # conv1x1 sees t3's output (t3_out channels) concatenated with the
        # stem output x1 (width_per_group channels). Its input width is
        # derived from t3_out so the two always agree.
        # ``width_per_group * 4 + width_per_group`` would only match t3's
        # 256-channel output when width_per_group == 64.
        self.dual_define(
            "conv1x1",
            self.labels,
            nn.Sequential(
                nn.Conv3d(
                    t3_out + width_per_group,
                    128,
                    kernel_size=(1, 1, 1),
                    stride=1,
                    padding=0,
                ),
                nn.BatchNorm3d(128),
                nn.ReLU(),
                nn.Conv3d(128, 1, kernel_size=(1, 1, 1), stride=1, padding=0),
            ),
        )

        self.linear = nn.Linear(1, 1)

    def forward_plus(self, x, y, net):
        """Return ``[net(x)[0] + y[0]]`` (residual-style helper)."""
        return [net(x)[0] + y[0]]

    def dual_define(self, name, labels, net):
        """
        Register an independent deep copy of ``net`` for every label.

        Each copy is added as the submodule ``"<name>_<label>"``. The
        template ``net`` itself is not registered.
        """
        for label in labels:
            self.add_module(f"{name}_{label}", copy.deepcopy(net))

    def upsample(self, x, dims=("space",)):
        """
        Double the size of ``x[0]`` along the requested axes.

        Uses ``interpolate``'s default (nearest-neighbour) mode.

        Args:
            x (list[Tensor]): only ``x[0]`` is used. Its dims 2, 3 and 4 are
                treated as (T, H, W), so a 5-D tensor is assumed.
            dims (iterable of str): ``"space"`` doubles H and W; ``"time"``
                doubles T. The immutable tuple default cannot be mutated
                across calls.

        Returns:
            list[Tensor]: single-element list with the resized tensor.
        """
        t, h, w = x[0].shape[2:5]
        if "space" in dims:
            h, w = 2 * h, 2 * w
        if "time" in dims:
            t = 2 * t
        # ``nn.functional`` is used instead of ``F``: ``F`` is not among this
        # file's top-level imports, while ``nn`` is.
        return [nn.functional.interpolate(x[0], (t, h, w))]

    def concat(self, x, y):
        """Concatenate ``x[0]`` and ``y[0]`` along dim 1; returns a 1-list."""
        return [torch.cat([x[0], y[0]], 1)]

    def get_detach_var(self, x):
        """Return a list with every tensor in ``x`` detached from the graph."""
        return [t.detach() for t in x]

    def forward(self, x, freeze_backbone=False):
        """
        Run the encoder (s1-s4) and every per-label decoder branch.

        Args:
            x (list[Tensor]): input tensors (one pathway). They are detached
                first, so no gradient reaches the caller's tensors.
            freeze_backbone (bool): when True, all encoder outputs are
                detached, so s1-s4 receive no gradient.

        Returns:
            tuple(Tensor, Tensor):
                ``reg_out``: the sigmoid of the per-label branch outputs,
                concatenated along dim 1.
                ``class_out``: the sigmoid of ``self.linear`` applied to the
                x100-scaled mean of ``reg_out`` over dims 3 and 4, reshaped
                to ``(batch, len(self.labels), -1)``.
        """
        x = self.get_detach_var(x)
        x1 = self.s1(x)
        x2 = self.s2(x1)
        x3 = self.s3(x2)
        feat = self.s4(x3)

        assert isinstance(freeze_backbone, bool)
        if freeze_backbone:
            feat = self.get_detach_var(feat)
            x1 = self.get_detach_var(x1)
            x2 = self.get_detach_var(x2)
            x3 = self.get_detach_var(x3)

        # Upsample the deepest features and fuse them with the s3 skip.
        feat = self.upsample(feat)
        feat = self.concat(x3, feat)

        reg_out = [self.forward_branch(feat, x1, x2, label) for label in self.labels]
        reg_out = torch.sigmoid(torch.cat(reg_out, 1))

        class_out = reg_out.mean([3, 4]).view(-1, 1) * 100
        class_out = self.linear(class_out)
        class_out = class_out.view(reg_out.size(0), len(self.labels), -1)
        class_out = torch.sigmoid(class_out)
        return reg_out, class_out

    def forward_branch(self, feat, x1, x2, label):
        """
        Decode one label's single-channel output.

        Pipeline: ``t4_<label>`` -> upsample -> concat with ``x2`` ->
        ``t3_<label>`` -> concat with ``x1`` -> ``conv1x1_<label>``.
        """
        t4 = getattr(self, f"t4_{label}")
        feat = t4(feat[0])
        feat = self.upsample([feat])
        feat = self.concat(x2, feat)
        t3 = getattr(self, f"t3_{label}")
        feat = t3(feat[0])
        feat = self.concat(x1, [feat])
        conv1x1 = getattr(self, f"conv1x1_{label}")
        return conv1x1(feat[0])
|
|